[Frontend] Add rerank support to run_batch endpoint (#16278)

Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
Pooya Davoodi 2025-05-31 00:40:01 -07:00 committed by GitHub
parent 7782464a17
commit dff80b0e42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 48 additions and 16 deletions

View File

@ -4,6 +4,8 @@ import json
import subprocess
import tempfile
import pytest
from vllm.entrypoints.openai.protocol import BatchRequestOutput
# ruff: noqa: E501
@ -23,9 +25,13 @@ INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "
{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}
{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
{"custom_id": "request-2", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
def test_empty_file():
with tempfile.NamedTemporaryFile(
@ -105,11 +111,13 @@ def test_embeddings():
BatchRequestOutput.model_validate_json(line)
def test_score():
@pytest.mark.parametrize("input_batch",
[INPUT_SCORE_BATCH, INPUT_RERANK_BATCH])
def test_score(input_batch):
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
input_file.write(INPUT_SCORE_BATCH)
input_file.write(input_batch)
input_file.flush()
proc = subprocess.Popen([
"vllm",

View File

@ -1481,6 +1481,10 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest, RerankRequest]
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
@ -1501,21 +1505,22 @@ class BatchRequestInput(OpenAIBaseModel):
url: str
# The parameters of the request.
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
body: BatchRequestInputBody
@field_validator('body', mode='plain')
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url = info.data['url']
url: str = info.data["url"]
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url == "/v1/score":
if url.endswith("/score"):
return ScoreRequest.model_validate(value)
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest]).validate_python(value)
if url.endswith("/rerank"):
return RerankRequest.model_validate(value)
return TypeAdapter(BatchRequestInputBody).validate_python(value)
class BatchResponseData(OpenAIBaseModel):
@ -1527,7 +1532,7 @@ class BatchResponseData(OpenAIBaseModel):
# The body of the response.
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
ScoreResponse]] = None
ScoreResponse, RerankResponse]] = None
class BatchRequestOutput(OpenAIBaseModel):

View File

@ -21,7 +21,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchResponseData,
ChatCompletionResponse,
EmbeddingResponse, ErrorResponse,
ScoreResponse)
RerankResponse, ScoreResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@ -274,8 +274,11 @@ async def run_request(serving_engine_func: Callable,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
if isinstance(
response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
RerankResponse),
):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
@ -397,7 +400,7 @@ async def main(args):
response_futures.append(
run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/score":
elif request.url.endswith("/score"):
score_handler_fn = openai_serving_scores.create_score if \
openai_serving_scores is not None else None
if score_handler_fn is None:
@ -411,13 +414,29 @@ async def main(args):
response_futures.append(
run_request(score_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/rerank"):
rerank_handler_fn = openai_serving_scores.do_rerank if \
openai_serving_scores is not None else None
if rerank_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Rerank API",
))
continue
response_futures.append(
run_request(rerank_handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=
"Only /v1/chat/completions, /v1/embeddings, and /v1/score "
"are supported in the batch endpoint.",
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /score, /rerank ."
"See vllm/entrypoints/openai/api_server.py for supported "
"score/rerank versions.",
))
with tracker.pbar():