[Bugfix] Refactor /invocations to be task-agnostic (#20764)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-07-11 18:20:54 +08:00 committed by GitHub
parent 7bd4c37ae7
commit cbd14ed561
9 changed files with 352 additions and 75 deletions
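For context before the per-file diffs: this change makes the /invocations endpoint dispatch on the shape of the request body instead of the server's --task setting, so a payload that works against a dedicated endpoint (chat completions, completions, embeddings, classification, score, rerank, pooling) also works against /invocations, provided the corresponding handler is enabled on the server. A minimal client-side sketch, with a hypothetical base URL and placeholder model name:

import requests

BASE_URL = "http://localhost:8000"  # assumed server address
MODEL = "my-model"  # placeholder model name

# A chat-style body is routed to the chat completions handler...
chat_body = {
    "model": MODEL,
    "messages": [{"role": "user", "content": "what is 1+1?"}],
    "max_completion_tokens": 5,
}
print(requests.post(f"{BASE_URL}/invocations", json=chat_body).json())

# ...while an embeddings-style body (on a server whose embeddings handler is
# enabled) is routed to the embeddings handler, with no task switching on the
# client side.
embed_body = {
    "model": MODEL,
    "input": ["The chef prepared a delicious meal."],
    "encoding_format": "float",
}
print(requests.post(f"{BASE_URL}/invocations", json=embed_body).json())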

View File

@@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
model_name: str):
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
openai_api_key = "EMPTY"
openai_api_base = f"http://localhost:{server.port}/v1"
@@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
messages=messages,
)
assert response.model == MODEL_NAME
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"max_completion_tokens": 5,
"temperature": 0.0,
"logprobs": False,
}
chat_completion = await client.chat.completions.create(**request_args)
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
chat_output = chat_completion.model_dump()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
assert chat_output["choices"] == invocation_output["choices"]

View File

@@ -155,3 +155,25 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
assert output.object == "list"
assert isinstance(output.data, list)
assert len(output.data) == 0
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
request_args = {
"model": MODEL_NAME,
"input": "This product was excellent and exceeded my expectations"
}
classification_response = requests.post(server.url_for("classify"),
json=request_args)
classification_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
classification_output = classification_response.json()
invocation_output = invocation_response.json()
assert classification_output.keys() == invocation_output.keys()
assert classification_output["data"] == invocation_output["data"]

View File

@@ -11,6 +11,7 @@ import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import regex as re
import requests
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
@@ -833,3 +834,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
assert content is not None and saying in content
else:
assert content is not None and saying not in content
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI):
request_args = {
"model": MODEL_NAME,
"prompt": "Hello, my name is",
"max_tokens": 5,
"temperature": 0.0,
"logprobs": None,
}
completion = await client.completions.create(**request_args)
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
completion_output = completion.model_dump()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
assert completion_output["choices"] == invocation_output["choices"]

View File

@@ -296,3 +296,63 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
assert "error" in response.object
assert "truncate_prompt_tokens value is greater than max_model_len. "\
"Please, select a smaller truncation size." in response.message
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI):
input_texts = [
"The chef prepared a delicious meal.",
]
request_args = {
"model": MODEL_NAME,
"input": input_texts,
"encoding_format": "float",
}
completion_response = await client.embeddings.create(**request_args)
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
completion_output = completion_response.model_dump()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
assert completion_output["data"] == invocation_output["data"]
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
messages = [{
"role": "user",
"content": "The cat sat on the mat.",
}, {
"role": "assistant",
"content": "A feline was resting on a rug.",
}, {
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
}]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"encoding_format": "float",
}
chat_response = requests.post(server.url_for("v1/embeddings"),
json=request_args)
chat_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
assert chat_output["data"] == invocation_output["data"]

View File

@@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
MODEL_NAME = "internlm/internlm2-1_8b-reward"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
@@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
def server():
args = [
"--task",
"classify",
"reward",
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--enforce-eager",
"--max-model-len",
"8192",
"512",
"--chat-template",
DUMMY_CHAT_TEMPLATE,
"--trust-remote-code",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 8
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 7
assert poolings.usage.total_tokens == 7
assert poolings.usage.prompt_tokens == 8
assert poolings.usage.total_tokens == 8
# test using token IDs
input_tokens = [1, 1, 1, 1, 1]
@@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 5
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 5
assert poolings.usage.total_tokens == 5
@@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 3
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 8
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 25
assert poolings.usage.total_tokens == 25
assert poolings.usage.prompt_tokens == 29
assert poolings.usage.total_tokens == 29
# test list[list[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
@@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None
assert len(poolings.data) == 4
assert len(poolings.data[0].data) == 2
assert len(poolings.data[0].data) == 5
assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 17
assert poolings.usage.total_tokens == 17
@@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
chat_response.raise_for_status()
chat_poolings = PoolingResponse.model_validate(chat_response.json())
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(
tokenizer_name=model_name,
tokenizer_mode="fast",
trust_remote_code=True,
)
prompt = tokenizer.apply_chat_template(
messages,
chat_template=DUMMY_CHAT_TEMPLATE,
@@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
)
float_response.raise_for_status()
responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [
np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
]
base64_response = requests.post(
server.url_for("pooling"),
@@ -224,8 +232,7 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
np.frombuffer(base64.b64decode(data.data),
dtype="float32").tolist())
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_float.data],
check_embeddings_close(embeddings_0_lst=float_data,
embeddings_1_lst=decoded_responses_base64_data,
name_0="float32",
name_1="base64")
@@ -240,9 +247,71 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
)
default_response.raise_for_status()
responses_default = PoolingResponse.model_validate(default_response.json())
default_data = [
np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
]
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_default.data],
embeddings_1_lst=[d.data for d in responses_default.data],
check_embeddings_close(embeddings_0_lst=float_data,
embeddings_1_lst=default_data,
name_0="float32",
name_1="base64")
name_1="default")
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
input_texts = [
"The chef prepared a delicious meal.",
]
request_args = {
"model": MODEL_NAME,
"input": input_texts,
"encoding_format": "float",
}
completion_response = requests.post(server.url_for("pooling"),
json=request_args)
completion_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
completion_output = completion_response.json()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
assert completion_output["data"] == invocation_output["data"]
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
messages = [{
"role": "user",
"content": "The cat sat on the mat.",
}, {
"role": "assistant",
"content": "A feline was resting on a rug.",
}, {
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
}]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"encoding_format": "float",
}
chat_response = requests.post(server.url_for("pooling"), json=request_args)
chat_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
assert chat_output["data"] == invocation_output["data"]

View File

@@ -94,3 +94,30 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
# Assert just a small fragments of the response
assert "Please reduce the length of the input." in \
rerank_response.text
def test_invocations(server: RemoteOpenAIServer):
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
request_args = {
"model": MODEL_NAME,
"query": query,
"documents": documents,
}
rerank_response = requests.post(server.url_for("rerank"),
json=request_args)
rerank_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
rerank_output = rerank_response.json()
invocation_output = invocation_response.json()
assert rerank_output.keys() == invocation_output.keys()
assert rerank_output["results"] == invocation_output["results"]

View File

@@ -191,3 +191,28 @@ class TestModel:
assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in \
score_response.text
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
Any]):
text_1 = "What is the capital of France?"
text_2 = "The capital of France is Paris."
request_args = {
"model": model["name"],
"text_1": text_1,
"text_2": text_2,
}
score_response = requests.post(server.url_for("score"),
json=request_args)
score_response.raise_for_status()
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
score_output = score_response.json()
invocation_output = invocation_response.json()
assert score_output.keys() == invocation_output.keys()
assert score_output["data"] == invocation_output["data"]

View File

@@ -18,7 +18,7 @@ from collections.abc import AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import Annotated, Any, Optional
from typing import Annotated, Any, Callable, Optional
import prometheus_client
import pydantic
@@ -61,13 +61,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
LoadLoRAAdapterRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
ResponsesRequest,
@@ -434,6 +430,7 @@ async def get_server_load_metrics(request: Request):
# - /v1/chat/completions
# - /v1/completions
# - /v1/audio/transcriptions
# - /v1/audio/translations
# - /v1/embeddings
# - /pooling
# - /classify
@@ -957,31 +954,6 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
TASK_HANDLERS: dict[str, dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
},
"embed": {
"messages": (EmbeddingChatRequest, create_embedding),
"default": (EmbeddingCompletionRequest, create_embedding),
},
"score": {
"default": (RerankRequest, do_rerank)
},
"rerank": {
"default": (RerankRequest, do_rerank)
},
"reward": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
"classify": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
}
if envs.VLLM_SERVER_DEV_MODE:
logger.warning("SECURITY WARNING: Development endpoints are enabled! "
"This should NOT be used in production!")
@@ -1033,6 +1005,30 @@ if envs.VLLM_SERVER_DEV_MODE:
return JSONResponse(content={"is_sleeping": is_sleeping})
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
GetHandlerFn = Callable[[Request], Optional[OpenAIServing]]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
(ChatCompletionRequest, (chat, create_chat_completion)),
(CompletionRequest, (completion, create_completion)),
(EmbeddingRequest, (embedding, create_embedding)),
(ClassificationRequest, (classify, create_classify)),
(ScoreRequest, (score, create_score)),
(RerankRequest, (rerank, do_rerank)),
(PoolingRequest, (pooling, create_pooling)),
]
# NOTE: Construct the TypeAdapters only once
INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
]
@router.post("/invocations",
dependencies=[Depends(validate_json_request)],
responses={
@@ -1047,32 +1043,34 @@ if envs.VLLM_SERVER_DEV_MODE:
},
})
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
"""
"""For SageMaker, routes requests based on the request type."""
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}") from e
task = raw_request.app.state.task
valid_endpoints = [(validator, endpoint)
for validator, (get_handler,
endpoint) in INVOCATION_VALIDATORS
if get_handler(raw_request) is not None]
if task not in TASK_HANDLERS:
raise HTTPException(
status_code=400,
detail=f"Unsupported task: '{task}' for '/invocations'. "
f"Expected one of {set(TASK_HANDLERS.keys())}")
for request_validator, endpoint in valid_endpoints:
try:
request = request_validator.validate_python(body)
except pydantic.ValidationError:
continue
handler_config = TASK_HANDLERS[task]
if "messages" in body:
request_model, handler = handler_config["messages"]
else:
request_model, handler = handler_config["default"]
return await endpoint(request, raw_request)
# this is required since we lose the FastAPI automatic casting
request = request_model.model_validate(body)
return await handler(request, raw_request)
type_names = [
t.__name__ if isinstance(t := validator._type, type) else str(t)
for validator, _ in valid_endpoints
]
msg = ("Cannot find suitable handler for request. "
f"Expected one of: {type_names}")
res = base(raw_request).create_error_response(message=msg)
return JSONResponse(content=res.model_dump(), status_code=res.code)
if envs.VLLM_TORCH_PROFILER_DIR:
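To summarize the api_server.py change above: the task-keyed TASK_HANDLERS table is replaced by an ordered list of (request schema, handler) pairs, and /invocations validates the body against each schema in turn, dispatching to the first one that parses among the handlers actually enabled on the server. A standalone sketch of that dispatch pattern, using simplified hypothetical request models rather than the real vLLM classes:

from typing import Union

import pydantic


class ChatReq(pydantic.BaseModel):
    model: str
    messages: list[dict]


class CompletionReq(pydantic.BaseModel):
    model: str
    prompt: Union[str, list[str]]


# Earlier entries take priority, mirroring the NOTE in the diff above.
VALIDATORS = [
    (pydantic.TypeAdapter(ChatReq), "chat handler"),
    (pydantic.TypeAdapter(CompletionReq), "completion handler"),
]


def dispatch(body: dict) -> str:
    for adapter, handler in VALIDATORS:
        try:
            adapter.validate_python(body)
        except pydantic.ValidationError:
            continue
        return handler
    raise ValueError("no request schema matched the body")


print(dispatch({"model": "m", "messages": [{"role": "user", "content": "hi"}]}))
print(dispatch({"model": "m", "prompt": "Hello, my name is"}))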

View File

@@ -11,6 +11,12 @@ from typing import Annotated, Any, ClassVar, Literal, Optional, Union
import regex as re
import torch
from fastapi import HTTPException, UploadFile
# yapf: disable
from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio)
from openai.types.chat.chat_completion_message import (
Annotation as OpenAIAnnotation)
# yapf: enable
from openai.types.responses import (ResponseInputParam, ResponseOutputItem,
ResponseOutputMessage, ResponsePrompt,
ResponseStatus, ResponseTextConfig)
@@ -1393,11 +1399,16 @@ class CompletionResponseChoice(OpenAIBaseModel):
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
object: Literal["text_completion"] = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[CompletionResponseChoice]
service_tier: Optional[Literal["auto", "default", "flex", "scale",
"priority"]] = None
system_fingerprint: Optional[str] = None
usage: UsageInfo
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
@@ -1549,10 +1560,16 @@ class ExtractedToolCallInformation(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
reasoning_content: Optional[str] = None
content: Optional[str] = None
refusal: Optional[str] = None
annotations: Optional[OpenAIAnnotation] = None
audio: Optional[OpenAIChatCompletionAudio] = None
function_call: Optional[FunctionCall] = None
tool_calls: list[ToolCall] = Field(default_factory=list)
# vLLM-specific fields that are not in OpenAI spec
reasoning_content: Optional[str] = None
class ChatCompletionLogProb(OpenAIBaseModel):
token: str
@@ -1587,7 +1604,12 @@ class ChatCompletionResponse(OpenAIBaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[ChatCompletionResponseChoice]
service_tier: Optional[Literal["auto", "default", "flex", "scale",
"priority"]] = None
system_fingerprint: Optional[str] = None
usage: UsageInfo
# vLLM-specific fields that are not in OpenAI spec
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
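The protocol.py changes above tighten CompletionResponse.object to a Literal, add OpenAI-spec response fields (service_tier, system_fingerprint, refusal, annotations, audio), and group the vLLM-specific extras under explicit comments. One practical effect of a Literal field, shown with small hypothetical models rather than the real classes: the value is filled in by default, and a mismatched value is rejected at validation time, which makes response types easier to tell apart.

from typing import Literal

import pydantic


class TextCompletion(pydantic.BaseModel):
    object: Literal["text_completion"] = "text_completion"


class ChatCompletion(pydantic.BaseModel):
    object: Literal["chat.completion"] = "chat.completion"


# The default is filled in automatically...
assert TextCompletion().object == "text_completion"

# ...and a payload carrying a different object string fails validation
# instead of being silently accepted.
try:
    TextCompletion.model_validate({"object": "chat.completion"})
except pydantic.ValidationError:
    print("rejected: object must be 'text_completion'")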