From cbd14ed5613c6c20b3225e81f32a9bfe3a0d32ac Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 11 Jul 2025 18:20:54 +0800 Subject: [PATCH] [Bugfix] Refactor `/invocations` to be task-agnostic (#20764) Signed-off-by: DarkLight1337 --- tests/entrypoints/openai/test_chat.py | 37 +++++- .../entrypoints/openai/test_classification.py | 22 ++++ tests/entrypoints/openai/test_completion.py | 25 ++++ tests/entrypoints/openai/test_embedding.py | 60 ++++++++++ tests/entrypoints/openai/test_pooling.py | 113 ++++++++++++++---- tests/entrypoints/openai/test_rerank.py | 27 +++++ tests/entrypoints/openai/test_score.py | 25 ++++ vllm/entrypoints/openai/api_server.py | 92 +++++++------- vllm/entrypoints/openai/protocol.py | 26 +++- 9 files changed, 352 insertions(+), 75 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index dab947b21b28..e7c3ffaa6a9f 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer): @pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME, ""]) -async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer, - model_name: str): - +async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer): openai_api_key = "EMPTY" openai_api_base = f"http://localhost:{server.port}/v1" @@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer, messages=messages, ) assert response.model == MODEL_NAME + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer, + client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + request_args = { + "model": MODEL_NAME, + "messages": messages, + "max_completion_tokens": 5, + "temperature": 0.0, + "logprobs": False, + } + + chat_completion = await client.chat.completions.create(**request_args) + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + chat_output = chat_completion.model_dump() + invocation_output = invocation_response.json() + + assert chat_output.keys() == invocation_output.keys() + assert chat_output["choices"] == invocation_output["choices"] diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 6d5f925152c3..330c7ff5c92f 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -155,3 +155,25 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer, assert output.object == "list" assert isinstance(output.data, list) assert len(output.data) == 0 + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer): + request_args = { + "model": MODEL_NAME, + "input": "This product was excellent and exceeded my expectations" + } + + classification_response = requests.post(server.url_for("classify"), + json=request_args) + classification_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + classification_output = classification_response.json() + invocation_output = invocation_response.json() + + assert classification_output.keys() == invocation_output.keys() + assert classification_output["data"] == invocation_output["data"] diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 7933ca5cd6c6..df9586ee84de 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -11,6 +11,7 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio import regex as re +import requests # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -833,3 +834,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI, assert content is not None and saying in content else: assert content is not None and saying not in content + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer, + client: openai.AsyncOpenAI): + request_args = { + "model": MODEL_NAME, + "prompt": "Hello, my name is", + "max_tokens": 5, + "temperature": 0.0, + "logprobs": None, + } + + completion = await client.completions.create(**request_args) + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + completion_output = completion.model_dump() + invocation_output = invocation_response.json() + + assert completion_output.keys() == invocation_output.keys() + assert completion_output["choices"] == invocation_output["choices"] diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index adb094127e40..143999edeafa 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -296,3 +296,63 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, assert "error" in response.object assert "truncate_prompt_tokens value is greater than max_model_len. "\ "Please, select a smaller truncation size." in response.message + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer, + client: openai.AsyncOpenAI): + input_texts = [ + "The chef prepared a delicious meal.", + ] + + request_args = { + "model": MODEL_NAME, + "input": input_texts, + "encoding_format": "float", + } + + completion_response = await client.embeddings.create(**request_args) + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + completion_output = completion_response.model_dump() + invocation_output = invocation_response.json() + + assert completion_output.keys() == invocation_output.keys() + assert completion_output["data"] == invocation_output["data"] + + +@pytest.mark.asyncio +async def test_invocations_conversation(server: RemoteOpenAIServer): + messages = [{ + "role": "user", + "content": "The cat sat on the mat.", + }, { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }] + + request_args = { + "model": MODEL_NAME, + "messages": messages, + "encoding_format": "float", + } + + chat_response = requests.post(server.url_for("v1/embeddings"), + json=request_args) + chat_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + chat_output = chat_response.json() + invocation_output = invocation_response.json() + + assert chat_output.keys() == invocation_output.keys() + assert chat_output["data"] == invocation_output["data"] diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/openai/test_pooling.py index 41c30e71684b..8752b128d54c 100644 --- a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/openai/test_pooling.py @@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer -MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" +MODEL_NAME = "internlm/internlm2-1_8b-reward" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 @@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + def server(): args = [ "--task", - "classify", + "reward", # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", "--enforce-eager", "--max-model-len", - "8192", + "512", "--chat-template", DUMMY_CHAT_TEMPLATE, + "--trust-remote-code", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 1 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 8 assert poolings.usage.completion_tokens == 0 - assert poolings.usage.prompt_tokens == 7 - assert poolings.usage.total_tokens == 7 + assert poolings.usage.prompt_tokens == 8 + assert poolings.usage.total_tokens == 8 # test using token IDs input_tokens = [1, 1, 1, 1, 1] @@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 1 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 5 assert poolings.usage.completion_tokens == 0 assert poolings.usage.prompt_tokens == 5 assert poolings.usage.total_tokens == 5 @@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 3 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 8 assert poolings.usage.completion_tokens == 0 - assert poolings.usage.prompt_tokens == 25 - assert poolings.usage.total_tokens == 25 + assert poolings.usage.prompt_tokens == 29 + assert poolings.usage.total_tokens == 29 # test list[list[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], @@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 4 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 5 assert poolings.usage.completion_tokens == 0 assert poolings.usage.prompt_tokens == 17 assert poolings.usage.total_tokens == 17 @@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, chat_response.raise_for_status() chat_poolings = PoolingResponse.model_validate(chat_response.json()) - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + tokenizer = get_tokenizer( + tokenizer_name=model_name, + tokenizer_mode="fast", + trust_remote_code=True, + ) prompt = tokenizer.apply_chat_template( messages, chat_template=DUMMY_CHAT_TEMPLATE, @@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ) float_response.raise_for_status() responses_float = PoolingResponse.model_validate(float_response.json()) + float_data = [ + np.array(d.data).squeeze(-1).tolist() for d in responses_float.data + ] base64_response = requests.post( server.url_for("pooling"), @@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, np.frombuffer(base64.b64decode(data.data), dtype="float32").tolist()) - check_embeddings_close( - embeddings_0_lst=[d.data for d in responses_float.data], - embeddings_1_lst=decoded_responses_base64_data, - name_0="float32", - name_1="base64") + check_embeddings_close(embeddings_0_lst=float_data, + embeddings_1_lst=decoded_responses_base64_data, + name_0="float32", + name_1="base64") # Default response is float32 decoded from base64 by OpenAI Client default_response = requests.post( @@ -240,9 +247,71 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ) default_response.raise_for_status() responses_default = PoolingResponse.model_validate(default_response.json()) + default_data = [ + np.array(d.data).squeeze(-1).tolist() for d in responses_default.data + ] - check_embeddings_close( - embeddings_0_lst=[d.data for d in responses_default.data], - embeddings_1_lst=[d.data for d in responses_default.data], - name_0="float32", - name_1="base64") + check_embeddings_close(embeddings_0_lst=float_data, + embeddings_1_lst=default_data, + name_0="float32", + name_1="default") + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer): + input_texts = [ + "The chef prepared a delicious meal.", + ] + + request_args = { + "model": MODEL_NAME, + "input": input_texts, + "encoding_format": "float", + } + + completion_response = requests.post(server.url_for("pooling"), + json=request_args) + completion_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + completion_output = completion_response.json() + invocation_output = invocation_response.json() + + assert completion_output.keys() == invocation_output.keys() + assert completion_output["data"] == invocation_output["data"] + + +@pytest.mark.asyncio +async def test_invocations_conversation(server: RemoteOpenAIServer): + messages = [{ + "role": "user", + "content": "The cat sat on the mat.", + }, { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }] + + request_args = { + "model": MODEL_NAME, + "messages": messages, + "encoding_format": "float", + } + + chat_response = requests.post(server.url_for("pooling"), json=request_args) + chat_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + chat_output = chat_response.json() + invocation_output = invocation_response.json() + + assert chat_output.keys() == invocation_output.keys() + assert chat_output["data"] == invocation_output["data"] diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index e40bbca9a8ad..16a947bc3fea 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -94,3 +94,30 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): # Assert just a small fragments of the response assert "Please reduce the length of the input." in \ rerank_response.text + + +def test_invocations(server: RemoteOpenAIServer): + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + request_args = { + "model": MODEL_NAME, + "query": query, + "documents": documents, + } + + rerank_response = requests.post(server.url_for("rerank"), + json=request_args) + rerank_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + rerank_output = rerank_response.json() + invocation_output = invocation_response.json() + + assert rerank_output.keys() == invocation_output.keys() + assert rerank_output["results"] == invocation_output["results"] diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 8927fe771809..4d3bbd9decc0 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -191,3 +191,28 @@ class TestModel: assert score_response.status_code == 400 assert "Please, select a smaller truncation size." in \ score_response.text + + def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, + Any]): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + + request_args = { + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + } + + score_response = requests.post(server.url_for("score"), + json=request_args) + score_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + score_output = score_response.json() + invocation_output = invocation_response.json() + + assert score_output.keys() == invocation_output.keys() + assert score_output["data"] == invocation_output["data"] diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f0c486317c23..2f53357e1d4c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,7 +18,7 @@ from collections.abc import AsyncIterator, Awaitable from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import Annotated, Any, Optional +from typing import Annotated, Any, Callable, Optional import prometheus_client import pydantic @@ -61,13 +61,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionResponse, DetokenizeRequest, DetokenizeResponse, - EmbeddingChatRequest, - EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, ErrorResponse, LoadLoRAAdapterRequest, - PoolingChatRequest, - PoolingCompletionRequest, PoolingRequest, PoolingResponse, RerankRequest, RerankResponse, ResponsesRequest, @@ -434,6 +430,7 @@ async def get_server_load_metrics(request: Request): # - /v1/chat/completions # - /v1/completions # - /v1/audio/transcriptions + # - /v1/audio/translations # - /v1/embeddings # - /pooling # - /classify @@ -957,31 +954,6 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -TASK_HANDLERS: dict[str, dict[str, tuple]] = { - "generate": { - "messages": (ChatCompletionRequest, create_chat_completion), - "default": (CompletionRequest, create_completion), - }, - "embed": { - "messages": (EmbeddingChatRequest, create_embedding), - "default": (EmbeddingCompletionRequest, create_embedding), - }, - "score": { - "default": (RerankRequest, do_rerank) - }, - "rerank": { - "default": (RerankRequest, do_rerank) - }, - "reward": { - "messages": (PoolingChatRequest, create_pooling), - "default": (PoolingCompletionRequest, create_pooling), - }, - "classify": { - "messages": (PoolingChatRequest, create_pooling), - "default": (PoolingCompletionRequest, create_pooling), - }, -} - if envs.VLLM_SERVER_DEV_MODE: logger.warning("SECURITY WARNING: Development endpoints are enabled! " "This should NOT be used in production!") @@ -1033,6 +1005,30 @@ if envs.VLLM_SERVER_DEV_MODE: return JSONResponse(content={"is_sleeping": is_sleeping}) +# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers +# (requires typing_extensions >= 4.13) +RequestType = Any +GetHandlerFn = Callable[[Request], Optional[OpenAIServing]] +EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] + +# NOTE: Items defined earlier take higher priority +INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [ + (ChatCompletionRequest, (chat, create_chat_completion)), + (CompletionRequest, (completion, create_completion)), + (EmbeddingRequest, (embedding, create_embedding)), + (ClassificationRequest, (classify, create_classify)), + (ScoreRequest, (score, create_score)), + (RerankRequest, (rerank, do_rerank)), + (PoolingRequest, (pooling, create_pooling)), +] + +# NOTE: Construct the TypeAdapters only once +INVOCATION_VALIDATORS = [ + (pydantic.TypeAdapter(request_type), (get_handler, endpoint)) + for request_type, (get_handler, endpoint) in INVOCATION_TYPES +] + + @router.post("/invocations", dependencies=[Depends(validate_json_request)], responses={ @@ -1047,32 +1043,34 @@ if envs.VLLM_SERVER_DEV_MODE: }, }) async def invocations(raw_request: Request): - """ - For SageMaker, routes requests to other handlers based on model `task`. - """ + """For SageMaker, routes requests based on the request type.""" try: body = await raw_request.json() except json.JSONDecodeError as e: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}") from e - task = raw_request.app.state.task + valid_endpoints = [(validator, endpoint) + for validator, (get_handler, + endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None] - if task not in TASK_HANDLERS: - raise HTTPException( - status_code=400, - detail=f"Unsupported task: '{task}' for '/invocations'. " - f"Expected one of {set(TASK_HANDLERS.keys())}") + for request_validator, endpoint in valid_endpoints: + try: + request = request_validator.validate_python(body) + except pydantic.ValidationError: + continue - handler_config = TASK_HANDLERS[task] - if "messages" in body: - request_model, handler = handler_config["messages"] - else: - request_model, handler = handler_config["default"] + return await endpoint(request, raw_request) - # this is required since we lose the FastAPI automatic casting - request = request_model.model_validate(body) - return await handler(request, raw_request) + type_names = [ + t.__name__ if isinstance(t := validator._type, type) else str(t) + for validator, _ in valid_endpoints + ] + msg = ("Cannot find suitable handler for request. " + f"Expected one of: {type_names}") + res = base(raw_request).create_error_response(message=msg) + return JSONResponse(content=res.model_dump(), status_code=res.code) if envs.VLLM_TORCH_PROFILER_DIR: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index bfebe0ec0dc9..26c23a48e1d8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -11,6 +11,12 @@ from typing import Annotated, Any, ClassVar, Literal, Optional, Union import regex as re import torch from fastapi import HTTPException, UploadFile +# yapf: disable +from openai.types.chat.chat_completion_audio import ( + ChatCompletionAudio as OpenAIChatCompletionAudio) +from openai.types.chat.chat_completion_message import ( + Annotation as OpenAIAnnotation) +# yapf: enable from openai.types.responses import (ResponseInputParam, ResponseOutputItem, ResponseOutputMessage, ResponsePrompt, ResponseStatus, ResponseTextConfig) @@ -1393,11 +1399,16 @@ class CompletionResponseChoice(OpenAIBaseModel): class CompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") - object: str = "text_completion" + object: Literal["text_completion"] = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[CompletionResponseChoice] + service_tier: Optional[Literal["auto", "default", "flex", "scale", + "priority"]] = None + system_fingerprint: Optional[str] = None usage: UsageInfo + + # vLLM-specific fields that are not in OpenAI spec kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters.") @@ -1549,10 +1560,16 @@ class ExtractedToolCallInformation(BaseModel): class ChatMessage(OpenAIBaseModel): role: str - reasoning_content: Optional[str] = None content: Optional[str] = None + refusal: Optional[str] = None + annotations: Optional[OpenAIAnnotation] = None + audio: Optional[OpenAIChatCompletionAudio] = None + function_call: Optional[FunctionCall] = None tool_calls: list[ToolCall] = Field(default_factory=list) + # vLLM-specific fields that are not in OpenAI spec + reasoning_content: Optional[str] = None + class ChatCompletionLogProb(OpenAIBaseModel): token: str @@ -1587,7 +1604,12 @@ class ChatCompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[ChatCompletionResponseChoice] + service_tier: Optional[Literal["auto", "default", "flex", "scale", + "priority"]] = None + system_fingerprint: Optional[str] = None usage: UsageInfo + + # vLLM-specific fields that are not in OpenAI spec prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters.")