mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:05:01 +08:00
[Bugfix] Refactor /invocations to be task-agnostic (#20764)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
7bd4c37ae7
commit
cbd14ed561
@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
|
||||
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
|
||||
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = f"http://localhost:{server.port}/v1"
|
||||
|
||||
@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
|
||||
messages=messages,
|
||||
)
|
||||
assert response.model == MODEL_NAME
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"max_completion_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": False,
|
||||
}
|
||||
|
||||
chat_completion = await client.chat.completions.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_completion.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["choices"] == invocation_output["choices"]
|
||||
|
||||
@ -155,3 +155,25 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
|
||||
assert output.object == "list"
|
||||
assert isinstance(output.data, list)
|
||||
assert len(output.data) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": "This product was excellent and exceeded my expectations"
|
||||
}
|
||||
|
||||
classification_response = requests.post(server.url_for("classify"),
|
||||
json=request_args)
|
||||
classification_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
classification_output = classification_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert classification_output.keys() == invocation_output.keys()
|
||||
assert classification_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -11,6 +11,7 @@ import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import regex as re
|
||||
import requests
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from openai import BadRequestError
|
||||
@ -833,3 +834,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
|
||||
assert content is not None and saying in content
|
||||
else:
|
||||
assert content is not None and saying not in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Hello, my name is",
|
||||
"max_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": None,
|
||||
}
|
||||
|
||||
completion = await client.completions.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["choices"] == invocation_output["choices"]
|
||||
|
||||
@ -296,3 +296,63 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
|
||||
assert "error" in response.object
|
||||
assert "truncate_prompt_tokens value is greater than max_model_len. "\
|
||||
"Please, select a smaller truncation size." in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = await client.embeddings.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["data"] == invocation_output["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("v1/embeddings"),
|
||||
json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
|
||||
MODEL_NAME = "internlm/internlm2-1_8b-reward"
|
||||
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
|
||||
|
||||
|
||||
@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
|
||||
def server():
|
||||
args = [
|
||||
"--task",
|
||||
"classify",
|
||||
"reward",
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"512",
|
||||
"--chat-template",
|
||||
DUMMY_CHAT_TEMPLATE,
|
||||
"--trust-remote-code",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 7
|
||||
assert poolings.usage.total_tokens == 7
|
||||
assert poolings.usage.prompt_tokens == 8
|
||||
assert poolings.usage.total_tokens == 8
|
||||
|
||||
# test using token IDs
|
||||
input_tokens = [1, 1, 1, 1, 1]
|
||||
@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 5
|
||||
assert poolings.usage.total_tokens == 5
|
||||
@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 3
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 25
|
||||
assert poolings.usage.total_tokens == 25
|
||||
assert poolings.usage.prompt_tokens == 29
|
||||
assert poolings.usage.total_tokens == 29
|
||||
|
||||
# test list[list[int]]
|
||||
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||
@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 4
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 17
|
||||
assert poolings.usage.total_tokens == 17
|
||||
@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
|
||||
chat_response.raise_for_status()
|
||||
chat_poolings = PoolingResponse.model_validate(chat_response.json())
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(
|
||||
tokenizer_name=model_name,
|
||||
tokenizer_mode="fast",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
chat_template=DUMMY_CHAT_TEMPLATE,
|
||||
@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
)
|
||||
float_response.raise_for_status()
|
||||
responses_float = PoolingResponse.model_validate(float_response.json())
|
||||
float_data = [
|
||||
np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
|
||||
]
|
||||
|
||||
base64_response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
np.frombuffer(base64.b64decode(data.data),
|
||||
dtype="float32").tolist())
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[d.data for d in responses_float.data],
|
||||
embeddings_1_lst=decoded_responses_base64_data,
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
check_embeddings_close(embeddings_0_lst=float_data,
|
||||
embeddings_1_lst=decoded_responses_base64_data,
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
|
||||
# Default response is float32 decoded from base64 by OpenAI Client
|
||||
default_response = requests.post(
|
||||
@ -240,9 +247,71 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
)
|
||||
default_response.raise_for_status()
|
||||
responses_default = PoolingResponse.model_validate(default_response.json())
|
||||
default_data = [
|
||||
np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
|
||||
]
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[d.data for d in responses_default.data],
|
||||
embeddings_1_lst=[d.data for d in responses_default.data],
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
check_embeddings_close(embeddings_0_lst=float_data,
|
||||
embeddings_1_lst=default_data,
|
||||
name_0="float32",
|
||||
name_1="default")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = requests.post(server.url_for("pooling"),
|
||||
json=request_args)
|
||||
completion_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["data"] == invocation_output["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("pooling"), json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -94,3 +94,30 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
|
||||
# Assert just a small fragments of the response
|
||||
assert "Please reduce the length of the input." in \
|
||||
rerank_response.text
|
||||
|
||||
|
||||
def test_invocations(server: RemoteOpenAIServer):
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
}
|
||||
|
||||
rerank_response = requests.post(server.url_for("rerank"),
|
||||
json=request_args)
|
||||
rerank_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
rerank_output = rerank_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert rerank_output.keys() == invocation_output.keys()
|
||||
assert rerank_output["results"] == invocation_output["results"]
|
||||
|
||||
@ -191,3 +191,28 @@ class TestModel:
|
||||
assert score_response.status_code == 400
|
||||
assert "Please, select a smaller truncation size." in \
|
||||
score_response.text
|
||||
|
||||
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
|
||||
Any]):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
|
||||
request_args = {
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
}
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json=request_args)
|
||||
score_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
score_output = score_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert score_output.keys() == invocation_output.keys()
|
||||
assert score_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -18,7 +18,7 @@ from collections.abc import AsyncIterator, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, Optional
|
||||
from typing import Annotated, Any, Callable, Optional
|
||||
|
||||
import prometheus_client
|
||||
import pydantic
|
||||
@ -61,13 +61,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
RerankRequest, RerankResponse,
|
||||
ResponsesRequest,
|
||||
@ -434,6 +430,7 @@ async def get_server_load_metrics(request: Request):
|
||||
# - /v1/chat/completions
|
||||
# - /v1/completions
|
||||
# - /v1/audio/transcriptions
|
||||
# - /v1/audio/translations
|
||||
# - /v1/embeddings
|
||||
# - /pooling
|
||||
# - /classify
|
||||
@ -957,31 +954,6 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
TASK_HANDLERS: dict[str, dict[str, tuple]] = {
|
||||
"generate": {
|
||||
"messages": (ChatCompletionRequest, create_chat_completion),
|
||||
"default": (CompletionRequest, create_completion),
|
||||
},
|
||||
"embed": {
|
||||
"messages": (EmbeddingChatRequest, create_embedding),
|
||||
"default": (EmbeddingCompletionRequest, create_embedding),
|
||||
},
|
||||
"score": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"rerank": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"reward": {
|
||||
"messages": (PoolingChatRequest, create_pooling),
|
||||
"default": (PoolingCompletionRequest, create_pooling),
|
||||
},
|
||||
"classify": {
|
||||
"messages": (PoolingChatRequest, create_pooling),
|
||||
"default": (PoolingCompletionRequest, create_pooling),
|
||||
},
|
||||
}
|
||||
|
||||
if envs.VLLM_SERVER_DEV_MODE:
|
||||
logger.warning("SECURITY WARNING: Development endpoints are enabled! "
|
||||
"This should NOT be used in production!")
|
||||
@ -1033,6 +1005,30 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
return JSONResponse(content={"is_sleeping": is_sleeping})
|
||||
|
||||
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
RequestType = Any
|
||||
GetHandlerFn = Callable[[Request], Optional[OpenAIServing]]
|
||||
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
|
||||
|
||||
# NOTE: Items defined earlier take higher priority
|
||||
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
|
||||
(ChatCompletionRequest, (chat, create_chat_completion)),
|
||||
(CompletionRequest, (completion, create_completion)),
|
||||
(EmbeddingRequest, (embedding, create_embedding)),
|
||||
(ClassificationRequest, (classify, create_classify)),
|
||||
(ScoreRequest, (score, create_score)),
|
||||
(RerankRequest, (rerank, do_rerank)),
|
||||
(PoolingRequest, (pooling, create_pooling)),
|
||||
]
|
||||
|
||||
# NOTE: Construct the TypeAdapters only once
|
||||
INVOCATION_VALIDATORS = [
|
||||
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
|
||||
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
|
||||
]
|
||||
|
||||
|
||||
@router.post("/invocations",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
@ -1047,32 +1043,34 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
},
|
||||
})
|
||||
async def invocations(raw_request: Request):
|
||||
"""
|
||||
For SageMaker, routes requests to other handlers based on model `task`.
|
||||
"""
|
||||
"""For SageMaker, routes requests based on the request type."""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}") from e
|
||||
|
||||
task = raw_request.app.state.task
|
||||
valid_endpoints = [(validator, endpoint)
|
||||
for validator, (get_handler,
|
||||
endpoint) in INVOCATION_VALIDATORS
|
||||
if get_handler(raw_request) is not None]
|
||||
|
||||
if task not in TASK_HANDLERS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported task: '{task}' for '/invocations'. "
|
||||
f"Expected one of {set(TASK_HANDLERS.keys())}")
|
||||
for request_validator, endpoint in valid_endpoints:
|
||||
try:
|
||||
request = request_validator.validate_python(body)
|
||||
except pydantic.ValidationError:
|
||||
continue
|
||||
|
||||
handler_config = TASK_HANDLERS[task]
|
||||
if "messages" in body:
|
||||
request_model, handler = handler_config["messages"]
|
||||
else:
|
||||
request_model, handler = handler_config["default"]
|
||||
return await endpoint(request, raw_request)
|
||||
|
||||
# this is required since we lose the FastAPI automatic casting
|
||||
request = request_model.model_validate(body)
|
||||
return await handler(request, raw_request)
|
||||
type_names = [
|
||||
t.__name__ if isinstance(t := validator._type, type) else str(t)
|
||||
for validator, _ in valid_endpoints
|
||||
]
|
||||
msg = ("Cannot find suitable handler for request. "
|
||||
f"Expected one of: {type_names}")
|
||||
res = base(raw_request).create_error_response(message=msg)
|
||||
return JSONResponse(content=res.model_dump(), status_code=res.code)
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
|
||||
@ -11,6 +11,12 @@ from typing import Annotated, Any, ClassVar, Literal, Optional, Union
|
||||
import regex as re
|
||||
import torch
|
||||
from fastapi import HTTPException, UploadFile
|
||||
# yapf: disable
|
||||
from openai.types.chat.chat_completion_audio import (
|
||||
ChatCompletionAudio as OpenAIChatCompletionAudio)
|
||||
from openai.types.chat.chat_completion_message import (
|
||||
Annotation as OpenAIAnnotation)
|
||||
# yapf: enable
|
||||
from openai.types.responses import (ResponseInputParam, ResponseOutputItem,
|
||||
ResponseOutputMessage, ResponsePrompt,
|
||||
ResponseStatus, ResponseTextConfig)
|
||||
@ -1393,11 +1399,16 @@ class CompletionResponseChoice(OpenAIBaseModel):
|
||||
|
||||
class CompletionResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
object: str = "text_completion"
|
||||
object: Literal["text_completion"] = "text_completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[CompletionResponseChoice]
|
||||
service_tier: Optional[Literal["auto", "default", "flex", "scale",
|
||||
"priority"]] = None
|
||||
system_fingerprint: Optional[str] = None
|
||||
usage: UsageInfo
|
||||
|
||||
# vLLM-specific fields that are not in OpenAI spec
|
||||
kv_transfer_params: Optional[dict[str, Any]] = Field(
|
||||
default=None, description="KVTransfer parameters.")
|
||||
|
||||
@ -1549,10 +1560,16 @@ class ExtractedToolCallInformation(BaseModel):
|
||||
|
||||
class ChatMessage(OpenAIBaseModel):
|
||||
role: str
|
||||
reasoning_content: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
refusal: Optional[str] = None
|
||||
annotations: Optional[OpenAIAnnotation] = None
|
||||
audio: Optional[OpenAIChatCompletionAudio] = None
|
||||
function_call: Optional[FunctionCall] = None
|
||||
tool_calls: list[ToolCall] = Field(default_factory=list)
|
||||
|
||||
# vLLM-specific fields that are not in OpenAI spec
|
||||
reasoning_content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionLogProb(OpenAIBaseModel):
|
||||
token: str
|
||||
@ -1587,7 +1604,12 @@ class ChatCompletionResponse(OpenAIBaseModel):
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[ChatCompletionResponseChoice]
|
||||
service_tier: Optional[Literal["auto", "default", "flex", "scale",
|
||||
"priority"]] = None
|
||||
system_fingerprint: Optional[str] = None
|
||||
usage: UsageInfo
|
||||
|
||||
# vLLM-specific fields that are not in OpenAI spec
|
||||
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
|
||||
kv_transfer_params: Optional[dict[str, Any]] = Field(
|
||||
default=None, description="KVTransfer parameters.")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user