[Frontend] Add rerank support to run_batch endpoint (#16278)

Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>

parent 7782464a17
commit dff80b0e42
tests/entrypoints/openai/test_run_batch.py
@@ -4,6 +4,8 @@ import json
 import subprocess
 import tempfile
 
+import pytest
+
 from vllm.entrypoints.openai.protocol import BatchRequestOutput
 
 # ruff: noqa: E501
@@ -23,9 +25,13 @@ INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "
 {"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}
 {"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
 
-INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
+INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
 {"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
 
+INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
+{"custom_id": "request-2", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
+
 
 def test_empty_file():
     with tempfile.NamedTemporaryFile(
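Note: each batch line is a self-contained JSON request with custom_id, method, url, and body. A minimal sketch of generating such a rerank input file programmatically (the file name and loop bounds are illustrative, not part of this change):

# Sketch: generate a rerank batch input file with the same line shape as
# INPUT_RERANK_BATCH above. File name and loop bounds are illustrative.
import json

requests = [{
    "custom_id": f"request-{i}",
    "method": "POST",
    "url": "/v1/rerank",
    "body": {
        "model": "BAAI/bge-reranker-v2-m3",
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
} for i in range(1, 3)]

with open("rerank_batch.jsonl", "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")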
@@ -105,11 +111,13 @@ def test_embeddings():
             BatchRequestOutput.model_validate_json(line)
 
 
-def test_score():
+@pytest.mark.parametrize("input_batch",
+                         [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH])
+def test_score(input_batch):
     with tempfile.NamedTemporaryFile(
             "w") as input_file, tempfile.NamedTemporaryFile(
                 "r") as output_file:
-        input_file.write(INPUT_SCORE_BATCH)
+        input_file.write(input_batch)
         input_file.flush()
         proc = subprocess.Popen([
             "vllm",
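Note: the test drives the runner through the vllm CLI, but the full argument list is truncated in the hunk above, so the sketch below uses the module entry point documented for vLLM's batch runner instead. File names carry over from the previous sketch and are illustrative:

# Sketch: run the batch over the file written above via the module entry
# point (the test's exact "vllm" CLI arguments are truncated in the hunk).
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "-m", "vllm.entrypoints.openai.run_batch",
        "-i", "rerank_batch.jsonl",
        "-o", "rerank_results.jsonl",
        "--model", "BAAI/bge-reranker-v2-m3",
    ],
    check=True,
)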
vllm/entrypoints/openai/protocol.py
@@ -1481,6 +1481,10 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
+BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
+                              ScoreRequest, RerankRequest]
+
+
 class BatchRequestInput(OpenAIBaseModel):
     """
     The per-line object of the batch input file.
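Note: BatchRequestInputBody is validated through pydantic's TypeAdapter in the fallback branch below. A standalone sketch of that union-validation mechanism, with toy models in place of the real request types:

# Sketch: union validation via pydantic's TypeAdapter, the mechanism behind
# TypeAdapter(BatchRequestInputBody).validate_python(...). Toy models only.
from typing import Union

from pydantic import BaseModel, TypeAdapter


class A(BaseModel):
    x: int


class B(BaseModel):
    y: str


# TypeAdapter picks whichever union member the payload validates as.
adapter = TypeAdapter(Union[A, B])
assert isinstance(adapter.validate_python({"x": 1}), A)
assert isinstance(adapter.validate_python({"y": "hi"}), B)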
@@ -1501,21 +1505,22 @@ class BatchRequestInput(OpenAIBaseModel):
     url: str
 
     # The parameters of the request.
-    body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
+    body: BatchRequestInputBody
 
     @field_validator('body', mode='plain')
     @classmethod
     def check_type_for_url(cls, value: Any, info: ValidationInfo):
         # Use url to disambiguate models
-        url = info.data['url']
+        url: str = info.data["url"]
         if url == "/v1/chat/completions":
             return ChatCompletionRequest.model_validate(value)
         if url == "/v1/embeddings":
             return TypeAdapter(EmbeddingRequest).validate_python(value)
-        if url == "/v1/score":
+        if url.endswith("/score"):
             return ScoreRequest.model_validate(value)
-        return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
-                                 ScoreRequest]).validate_python(value)
+        if url.endswith("/rerank"):
+            return RerankRequest.model_validate(value)
+        return TypeAdapter(BatchRequestInputBody).validate_python(value)
 
 
 class BatchResponseData(OpenAIBaseModel):
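Note: the plain-mode validator replaces pydantic's own validation of body, and it can read url from info.data because url is declared before body on the model. A self-contained sketch of the same pattern, with stand-in models rather than vLLM's request types:

# Sketch of the URL-dispatched body validation pattern used above.
# ScoreBody/RerankBody are stand-ins for vLLM's ScoreRequest/RerankRequest.
from typing import Any, Union

from pydantic import BaseModel, ValidationInfo, field_validator


class ScoreBody(BaseModel):
    text_1: str
    text_2: list[str]


class RerankBody(BaseModel):
    query: str
    documents: list[str]


class Line(BaseModel):
    url: str  # declared first, so present in info.data when body validates
    body: Union[ScoreBody, RerankBody]

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        # mode="plain" skips pydantic's own union resolution entirely;
        # the model is chosen from the URL suffix instead.
        if info.data["url"].endswith("/rerank"):
            return RerankBody.model_validate(value)
        return ScoreBody.model_validate(value)


line = Line.model_validate({
    "url": "/v2/rerank",
    "body": {"query": "q", "documents": ["d1", "d2"]},
})
assert isinstance(line.body, RerankBody)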
@@ -1527,7 +1532,7 @@ class BatchResponseData(OpenAIBaseModel):
 
     # The body of the response.
     body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
-                         ScoreResponse]] = None
+                         ScoreResponse, RerankResponse]] = None
 
 
 class BatchRequestOutput(OpenAIBaseModel):
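Note: each output line deserializes back into BatchRequestOutput, which is exactly what test_score asserts line by line. A short sketch of consuming an output file (the file name carries over from the earlier sketches):

# Sketch: validate each output line the way test_score does.
from vllm.entrypoints.openai.protocol import BatchRequestOutput

with open("rerank_results.jsonl") as f:
    for line in f:
        line = line.strip()
        if line:
            # Raises pydantic.ValidationError on a malformed line.
            output = BatchRequestOutput.model_validate_json(line)
            print(output.custom_id)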
vllm/entrypoints/openai/run_batch.py
@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchResponseData,
                                               ChatCompletionResponse,
                                               EmbeddingResponse, ErrorResponse,
-                                              ScoreResponse)
+                                              RerankResponse, ScoreResponse)
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -274,8 +274,11 @@ async def run_request(serving_engine_func: Callable,
                       tracker: BatchProgressTracker) -> BatchRequestOutput:
     response = await serving_engine_func(request.body)
 
-    if isinstance(response,
-                  (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
+    if isinstance(
+            response,
+            (ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
+             RerankResponse),
+    ):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
@@ -397,7 +400,7 @@ async def main(args):
             response_futures.append(
                 run_request(embed_handler_fn, request, tracker))
             tracker.submitted()
-        elif request.url == "/v1/score":
+        elif request.url.endswith("/score"):
             score_handler_fn = openai_serving_scores.create_score if \
                 openai_serving_scores is not None else None
             if score_handler_fn is None:
@@ -411,13 +414,29 @@ async def main(args):
             response_futures.append(
                 run_request(score_handler_fn, request, tracker))
             tracker.submitted()
+        elif request.url.endswith("/rerank"):
+            rerank_handler_fn = openai_serving_scores.do_rerank if \
+                openai_serving_scores is not None else None
+            if rerank_handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg="The model does not support Rerank API",
+                    ))
+                continue
+
+            response_futures.append(
+                run_request(rerank_handler_fn, request, tracker))
+            tracker.submitted()
         else:
             response_futures.append(
                 make_async_error_request_output(
                     request,
-                    error_msg=
-                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
-                    "are supported in the batch endpoint.",
+                    error_msg=f"URL {request.url} was used. "
+                    "Supported endpoints: /v1/chat/completions, /v1/embeddings,"
+                    " /score, /rerank. "
+                    "See vllm/entrypoints/openai/api_server.py for supported "
+                    "score/rerank versions.",
                 ))
 
     with tracker.pbar():
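Note: dispatch now matches URL suffixes rather than exact paths, so /score, /v1/score, and /v2/rerank all reach the right handler. A toy sketch of the routing rule in isolation (handler labels are placeholders, not vLLM's serving objects):

# Toy sketch of the suffix-based routing used in main() above.
def route(url: str) -> str:
    if url == "/v1/chat/completions":
        return "chat"
    if url == "/v1/embeddings":
        return "embeddings"
    if url.endswith("/score"):
        return "score"
    if url.endswith("/rerank"):
        return "rerank"
    raise ValueError(f"URL {url} was used. Supported endpoints: "
                     "/v1/chat/completions, /v1/embeddings, /score, /rerank.")


assert route("/score") == route("/v1/score") == "score"
assert route("/v2/rerank") == "rerank"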