diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index 27802945a216..99639ce51aa7 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -4,6 +4,8 @@ import json import subprocess import tempfile +import pytest + from vllm.entrypoints.openai.protocol import BatchRequestOutput # ruff: noqa: E501 @@ -23,9 +25,13 @@ INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": " {"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}} {"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}""" -INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} +INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" +INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital 
of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-3", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" + def test_empty_file(): with tempfile.NamedTemporaryFile( @@ -105,11 +111,13 @@ def test_embeddings(): BatchRequestOutput.model_validate_json(line) -def test_score(): +@pytest.mark.parametrize("input_batch", + [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) +def test_score(input_batch): with tempfile.NamedTemporaryFile( "w") as input_file, tempfile.NamedTemporaryFile( "r") as output_file: - input_file.write(INPUT_SCORE_BATCH) + input_file.write(input_batch) input_file.flush() proc = subprocess.Popen([ "vllm", diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a7f85e9eef39..2f641079e584 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1481,6 +1481,10 @@ class TranscriptionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, + ScoreRequest, RerankRequest] + + class BatchRequestInput(OpenAIBaseModel): """ The per-line object of the batch input file. @@ -1501,21 +1505,22 @@ class BatchRequestInput(OpenAIBaseModel): url: str # The parameters of the request. 
- body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] + body: BatchRequestInputBody @field_validator('body', mode='plain') @classmethod def check_type_for_url(cls, value: Any, info: ValidationInfo): # Use url to disambiguate models - url = info.data['url'] + url: str = info.data["url"] if url == "/v1/chat/completions": return ChatCompletionRequest.model_validate(value) if url == "/v1/embeddings": return TypeAdapter(EmbeddingRequest).validate_python(value) - if url == "/v1/score": + if url.endswith("/score"): return ScoreRequest.model_validate(value) - return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest, - ScoreRequest]).validate_python(value) + if url.endswith("/rerank"): + return RerankRequest.model_validate(value) + return TypeAdapter(BatchRequestInputBody).validate_python(value) class BatchResponseData(OpenAIBaseModel): @@ -1527,7 +1532,7 @@ class BatchResponseData(OpenAIBaseModel): # The body of the response. body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, - ScoreResponse]] = None + ScoreResponse, RerankResponse]] = None class BatchRequestOutput(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index f38465b22bcc..ac250b3cb4fb 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -21,7 +21,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput, BatchResponseData, ChatCompletionResponse, EmbeddingResponse, ErrorResponse, - ScoreResponse) + RerankResponse, ScoreResponse) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -274,8 +274,11 @@ async def run_request(serving_engine_func: Callable, tracker: BatchProgressTracker) -> BatchRequestOutput: response = await serving_engine_func(request.body) - if isinstance(response, - (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)): + if isinstance( + 
response, + (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, + RerankResponse), + ): batch_output = BatchRequestOutput( id=f"vllm-{random_uuid()}", custom_id=request.custom_id, @@ -397,7 +400,7 @@ async def main(args): response_futures.append( run_request(embed_handler_fn, request, tracker)) tracker.submitted() - elif request.url == "/v1/score": + elif request.url.endswith("/score"): score_handler_fn = openai_serving_scores.create_score if \ openai_serving_scores is not None else None if score_handler_fn is None: @@ -411,13 +414,29 @@ async def main(args): response_futures.append( run_request(score_handler_fn, request, tracker)) tracker.submitted() + elif request.url.endswith("/rerank"): + rerank_handler_fn = openai_serving_scores.do_rerank if \ + openai_serving_scores is not None else None + if rerank_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Rerank API", + )) + continue + + response_futures.append( + run_request(rerank_handler_fn, request, tracker)) + tracker.submitted() else: response_futures.append( make_async_error_request_output( request, - error_msg= - "Only /v1/chat/completions, /v1/embeddings, and /v1/score " - "are supported in the batch endpoint.", + error_msg=f"URL {request.url} was used. " + "Supported endpoints: /v1/chat/completions, /v1/embeddings," + " /score, /rerank. " + "See vllm/entrypoints/openai/api_server.py for supported " + "score/rerank versions.", )) with tracker.pbar():