mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 16:15:34 +08:00
[Frontend] Add rerank support to run_batch endpoint (#16278)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
parent
7782464a17
commit
dff80b0e42
@ -4,6 +4,8 @@ import json
|
|||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import BatchRequestOutput
|
from vllm.entrypoints.openai.protocol import BatchRequestOutput
|
||||||
|
|
||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
@ -23,9 +25,13 @@ INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "
|
|||||||
{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}
|
{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}
|
||||||
{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
|
{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
|
||||||
|
|
||||||
INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
|
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
|
||||||
|
|
||||||
|
INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||||
|
{"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||||
|
{"custom_id": "request-2", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
|
||||||
|
|
||||||
|
|
||||||
def test_empty_file():
|
def test_empty_file():
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(
|
||||||
@ -105,11 +111,13 @@ def test_embeddings():
|
|||||||
BatchRequestOutput.model_validate_json(line)
|
BatchRequestOutput.model_validate_json(line)
|
||||||
|
|
||||||
|
|
||||||
def test_score():
|
@pytest.mark.parametrize("input_batch",
|
||||||
|
[INPUT_SCORE_BATCH, INPUT_RERANK_BATCH])
|
||||||
|
def test_score(input_batch):
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(
|
||||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||||
"r") as output_file:
|
"r") as output_file:
|
||||||
input_file.write(INPUT_SCORE_BATCH)
|
input_file.write(input_batch)
|
||||||
input_file.flush()
|
input_file.flush()
|
||||||
proc = subprocess.Popen([
|
proc = subprocess.Popen([
|
||||||
"vllm",
|
"vllm",
|
||||||
|
|||||||
@ -1481,6 +1481,10 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
|
|||||||
usage: Optional[UsageInfo] = Field(default=None)
|
usage: Optional[UsageInfo] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
|
||||||
|
ScoreRequest, RerankRequest]
|
||||||
|
|
||||||
|
|
||||||
class BatchRequestInput(OpenAIBaseModel):
|
class BatchRequestInput(OpenAIBaseModel):
|
||||||
"""
|
"""
|
||||||
The per-line object of the batch input file.
|
The per-line object of the batch input file.
|
||||||
@ -1501,21 +1505,22 @@ class BatchRequestInput(OpenAIBaseModel):
|
|||||||
url: str
|
url: str
|
||||||
|
|
||||||
# The parameters of the request.
|
# The parameters of the request.
|
||||||
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
|
body: BatchRequestInputBody
|
||||||
|
|
||||||
@field_validator('body', mode='plain')
|
@field_validator('body', mode='plain')
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
||||||
# Use url to disambiguate models
|
# Use url to disambiguate models
|
||||||
url = info.data['url']
|
url: str = info.data["url"]
|
||||||
if url == "/v1/chat/completions":
|
if url == "/v1/chat/completions":
|
||||||
return ChatCompletionRequest.model_validate(value)
|
return ChatCompletionRequest.model_validate(value)
|
||||||
if url == "/v1/embeddings":
|
if url == "/v1/embeddings":
|
||||||
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
||||||
if url == "/v1/score":
|
if url.endswith("/score"):
|
||||||
return ScoreRequest.model_validate(value)
|
return ScoreRequest.model_validate(value)
|
||||||
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
|
if url.endswith("/rerank"):
|
||||||
ScoreRequest]).validate_python(value)
|
return RerankRequest.model_validate(value)
|
||||||
|
return TypeAdapter(BatchRequestInputBody).validate_python(value)
|
||||||
|
|
||||||
|
|
||||||
class BatchResponseData(OpenAIBaseModel):
|
class BatchResponseData(OpenAIBaseModel):
|
||||||
@ -1527,7 +1532,7 @@ class BatchResponseData(OpenAIBaseModel):
|
|||||||
|
|
||||||
# The body of the response.
|
# The body of the response.
|
||||||
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
|
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
|
||||||
ScoreResponse]] = None
|
ScoreResponse, RerankResponse]] = None
|
||||||
|
|
||||||
|
|
||||||
class BatchRequestOutput(OpenAIBaseModel):
|
class BatchRequestOutput(OpenAIBaseModel):
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
|||||||
BatchResponseData,
|
BatchResponseData,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
EmbeddingResponse, ErrorResponse,
|
EmbeddingResponse, ErrorResponse,
|
||||||
ScoreResponse)
|
RerankResponse, ScoreResponse)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
@ -274,8 +274,11 @@ async def run_request(serving_engine_func: Callable,
|
|||||||
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
||||||
response = await serving_engine_func(request.body)
|
response = await serving_engine_func(request.body)
|
||||||
|
|
||||||
if isinstance(response,
|
if isinstance(
|
||||||
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
|
response,
|
||||||
|
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
|
||||||
|
RerankResponse),
|
||||||
|
):
|
||||||
batch_output = BatchRequestOutput(
|
batch_output = BatchRequestOutput(
|
||||||
id=f"vllm-{random_uuid()}",
|
id=f"vllm-{random_uuid()}",
|
||||||
custom_id=request.custom_id,
|
custom_id=request.custom_id,
|
||||||
@ -397,7 +400,7 @@ async def main(args):
|
|||||||
response_futures.append(
|
response_futures.append(
|
||||||
run_request(embed_handler_fn, request, tracker))
|
run_request(embed_handler_fn, request, tracker))
|
||||||
tracker.submitted()
|
tracker.submitted()
|
||||||
elif request.url == "/v1/score":
|
elif request.url.endswith("/score"):
|
||||||
score_handler_fn = openai_serving_scores.create_score if \
|
score_handler_fn = openai_serving_scores.create_score if \
|
||||||
openai_serving_scores is not None else None
|
openai_serving_scores is not None else None
|
||||||
if score_handler_fn is None:
|
if score_handler_fn is None:
|
||||||
@ -411,13 +414,29 @@ async def main(args):
|
|||||||
response_futures.append(
|
response_futures.append(
|
||||||
run_request(score_handler_fn, request, tracker))
|
run_request(score_handler_fn, request, tracker))
|
||||||
tracker.submitted()
|
tracker.submitted()
|
||||||
|
elif request.url.endswith("/rerank"):
|
||||||
|
rerank_handler_fn = openai_serving_scores.do_rerank if \
|
||||||
|
openai_serving_scores is not None else None
|
||||||
|
if rerank_handler_fn is None:
|
||||||
|
response_futures.append(
|
||||||
|
make_async_error_request_output(
|
||||||
|
request,
|
||||||
|
error_msg="The model does not support Rerank API",
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_futures.append(
|
||||||
|
run_request(rerank_handler_fn, request, tracker))
|
||||||
|
tracker.submitted()
|
||||||
else:
|
else:
|
||||||
response_futures.append(
|
response_futures.append(
|
||||||
make_async_error_request_output(
|
make_async_error_request_output(
|
||||||
request,
|
request,
|
||||||
error_msg=
|
error_msg=f"URL {request.url} was used. "
|
||||||
"Only /v1/chat/completions, /v1/embeddings, and /v1/score "
|
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
|
||||||
"are supported in the batch endpoint.",
|
" /score, /rerank ."
|
||||||
|
"See vllm/entrypoints/openai/api_server.py for supported "
|
||||||
|
"score/rerank versions.",
|
||||||
))
|
))
|
||||||
|
|
||||||
with tracker.pbar():
|
with tracker.pbar():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user