[Frontend] Add /classify endpoint (#17032)

Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
Frieda Huang 2025-05-11 03:57:07 -04:00 committed by GitHub
parent d1110f5b5a
commit 9cea90eab4
9 changed files with 972 additions and 173 deletions

View File

@ -140,6 +140,7 @@ Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints tha
- [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
- [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models.
- [Classification API](#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
- [Score API](#score-api) is similar to `LLM.score` for cross-encoder models.
## Matryoshka Embeddings

View File

@ -61,6 +61,8 @@ In addition, we have the following custom APIs:
- Applicable to any model with a tokenizer.
- [Pooling API](#pooling-api) (`/pooling`)
- Applicable to all [pooling models](../models/pooling_models.md).
- [Classification API](#classification-api) (`/classify`)
- Only applicable to [classification models](../models/pooling_models.md) (`--task classify`).
- [Score API](#score-api) (`/score`)
- Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
@ -443,6 +445,130 @@ The input format is the same as [Embeddings API](#embeddings-api), but the outpu
Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
(classification-api)=
### Classification API
Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
Code example: <gh-file:examples/online_serving/openai_classification_client.py>
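For offline inference, the same models can be used through `LLM.classify`. A minimal sketch (assuming the `jason9693/Qwen2.5-1.5B-apeach` model used in the examples below and a vLLM build that accepts `task="classify"`):
```python
# Offline sketch: classify a prompt without starting the API server.
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
outputs = llm.classify(["Loved the new café—coffee was great."])

for output in outputs:
    # probs holds the softmaxed per-class probabilities,
    # matching the "probs" field in the HTTP response below.
    print(output.outputs.probs)
```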
#### Example Requests
You can classify multiple texts by passing an array of strings:
Request:
```bash
curl -v "http://127.0.0.1:8000/classify" \
-H "Content-Type: application/json" \
-d '{
"model": "jason9693/Qwen2.5-1.5B-apeach",
"input": [
"Loved the new café—coffee was great.",
"This update broke everything. Frustrating."
]
}'
```
Response:
```json
{
"id": "classify-7c87cac407b749a6935d8c7ce2a8fba2",
"object": "list",
"created": 1745383065,
"model": "jason9693/Qwen2.5-1.5B-apeach",
"data": [
{
"index": 0,
"label": "Default",
"probs": [
0.565970778465271,
0.4340292513370514
],
"num_classes": 2
},
{
"index": 1,
"label": "Spoiled",
"probs": [
0.26448777318000793,
0.7355121970176697
],
"num_classes": 2
}
],
"usage": {
"prompt_tokens": 20,
"total_tokens": 20,
"completion_tokens": 0,
"prompt_tokens_details": null
}
}
```
You can also pass a string directly to the `input` field:
Request:
```bash
curl -v "http://127.0.0.1:8000/classify" \
-H "Content-Type: application/json" \
-d '{
"model": "jason9693/Qwen2.5-1.5B-apeach",
"input": "Loved the new café—coffee was great."
}'
```
Response:
```json
{
"id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
"object": "list",
"created": 1745383213,
"model": "jason9693/Qwen2.5-1.5B-apeach",
"data": [
{
"index": 0,
"label": "Default",
"probs": [
0.565970778465271,
0.4340292513370514
],
"num_classes": 2
}
],
"usage": {
"prompt_tokens": 10,
"total_tokens": 10,
"completion_tokens": 0,
"prompt_tokens_details": null
}
}
```
#### Extra parameters
The following [pooling parameters](#pooling-params) are supported.
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-classification-pooling-params
:end-before: end-classification-pooling-params
:::
The following extra parameters are supported:
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-classification-extra-params
:end-before: end-classification-extra-params
:::
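As a rough illustration (assuming a server started with `--task classify` on the default port), these parameters are passed as extra top-level fields in the request body, for example `truncate_prompt_tokens` and `priority`:
```python
# Sketch: send extra parameters along with the standard fields.
# priority only takes effect when the served model uses priority scheduling.
import requests

response = requests.post(
    "http://127.0.0.1:8000/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "Loved the new café—coffee was great.",
        "truncate_prompt_tokens": 5,
        "priority": 0,
    },
)
response.raise_for_status()
print(response.json())
```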
(score-api)=
### Score API

View File

@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import pprint
import requests
def post_http_request(payload: dict, api_url: str) -> requests.Response:
headers = {"User-Agent": "Test Client"}
response = requests.post(api_url, headers=headers, json=payload)
return response
def parse_args():
parse = argparse.ArgumentParser()
parse.add_argument("--host", type=str, default="localhost")
parse.add_argument("--port", type=int, default=8000)
parse.add_argument("--model",
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
return parse.parse_args()
def main(args):
host = args.host
port = args.port
model_name = args.model
api_url = f"http://{host}:{port}/classify"
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
payload = {
"model": model_name,
"input": prompts,
}
classify_response = post_http_request(payload=payload, api_url=api_url)
pprint.pprint(classify_response.json())
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import requests
from vllm.entrypoints.openai.protocol import ClassificationResponse
from ...utils import RemoteOpenAIServer
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32" # Use float32 to avoid NaN issue
@pytest.fixture(scope="module")
def server():
args = [
"--enforce-eager",
"--max-model-len",
"512",
"--dtype",
DTYPE,
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_single_input_classification(server: RemoteOpenAIServer,
model_name: str):
input_text = "This product was excellent and exceeded my expectations"
classification_response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"input": input_text
},
)
classification_response.raise_for_status()
output = ClassificationResponse.model_validate(
classification_response.json())
assert output.object == "list"
assert output.model == MODEL_NAME
assert len(output.data) == 1
assert hasattr(output.data[0], "label")
assert hasattr(output.data[0], "probs")
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_multiple_inputs_classification(server: RemoteOpenAIServer,
model_name: str):
input_texts = [
"The product arrived on time and works perfectly",
"I'm very satisfied with my purchase, would buy again",
"The customer service was helpful and resolved my issue quickly",
"This product broke after one week, terrible quality",
"I'm very disappointed with this purchase, complete waste of money",
"The customer service was rude and unhelpful",
]
classification_response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"input": input_texts
},
)
output = ClassificationResponse.model_validate(
classification_response.json())
assert len(output.data) == len(input_texts)
for i, item in enumerate(output.data):
assert item.index == i
assert hasattr(item, "label")
assert hasattr(item, "probs")
assert len(item.probs) == item.num_classes
assert item.label in ["Default", "Spoiled"]
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
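# "hello " * 600 is far longer than --max-model-len (512); the request only
# fits because truncate_prompt_tokens trims it down to 5 tokens.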
long_text = "hello " * 600
classification_response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"input": long_text,
"truncate_prompt_tokens": 5
},
)
classification_response.raise_for_status()
output = ClassificationResponse.model_validate(
classification_response.json())
assert len(output.data) == 1
assert output.data[0].index == 0
assert hasattr(output.data[0], "probs")
assert output.usage.prompt_tokens == 5
assert output.usage.total_tokens == 5
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
model_name: str):
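# truncate_prompt_tokens (513) exceeds --max-model-len (512), so the server
# should reject the request with HTTP 400.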
classification_response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"input": "test",
"truncate_prompt_tokens": 513
},
)
error = classification_response.json()
assert classification_response.status_code == 400
assert error["object"] == "error"
assert "truncate_prompt_tokens" in error["message"]
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
classification_response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"input": ""
},
)
error = classification_response.json()
assert classification_response.status_code == 400
assert error["object"] == "error"
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_batch_classification_empty_list(server: RemoteOpenAIServer,
model_name: str):
classification_response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"input": []
},
)
classification_response.raise_for_status()
output = ClassificationResponse.model_validate(
classification_response.json())
assert output.object == "list"
assert isinstance(output.data, list)
assert len(output.data) == 0

View File

@ -48,6 +48,8 @@ from vllm.entrypoints.openai.cli_args import (log_non_default_args,
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
ClassificationRequest,
ClassificationResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
@ -71,6 +73,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
UnloadLoRAAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_classification import (
ServingClassification)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing
@ -373,6 +377,10 @@ def score(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores
def classify(request: Request) -> Optional[ServingClassification]:
return request.app.state.openai_serving_classification
def rerank(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores
@ -405,6 +413,7 @@ async def get_server_load_metrics(request: Request):
# - /v1/audio/transcriptions
# - /v1/embeddings
# - /pooling
# - /classify
# - /score
# - /v1/score
# - /rerank
@ -572,6 +581,27 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
assert_never(generator)
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest,
raw_request: Request):
handler = classify(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Classification API")
generator = await handler.create_classify(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ClassificationResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@ -1001,6 +1031,12 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger) if model_config.task in (
"score", "embed", "pooling") else None
state.openai_serving_classification = ServingClassification(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.task == "classify" else None
state.jinaai_serving_reranking = ServingScores(
engine_client,
model_config,

View File

@ -1292,6 +1292,47 @@ class ScoreResponse(OpenAIBaseModel):
usage: UsageInfo
class ClassificationRequest(OpenAIBaseModel):
model: Optional[str] = None
input: Union[list[str], str]
truncate_prompt_tokens: Optional[int] = None
user: Optional[str] = None
# doc: begin-classification-pooling-params
additional_data: Optional[Any] = None
# doc: end-classification-pooling-params
# doc: begin-classification-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
# doc: end-classification-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class ClassificationData(OpenAIBaseModel):
index: int
label: Optional[str]
probs: list[float]
num_classes: int
class ClassificationResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: list[ClassificationData]
usage: UsageInfo
class FunctionCall(OpenAIBaseModel):
name: str
arguments: str

View File

@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
from http import HTTPStatus
from typing import Optional, Union, cast
import numpy as np
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ClassificationData,
ClassificationRequest,
ClassificationResponse,
ErrorResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
OpenAIServing,
ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing):
async def _preprocess(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
if isinstance(ctx.request.input, str) and not ctx.request.input:
return self.create_error_response(
"Input cannot be empty for classification",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
return None
try:
(
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
if ctx.prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported for classification models"
)
(
ctx.request_prompts,
ctx.engine_prompts,
) = await self._preprocess_completion(
ctx.request,
ctx.tokenizer,
ctx.request.input,
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
)
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ServeContext,
) -> Union[ClassificationResponse, ErrorResponse]:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
ctx = cast(ClassificationServeContext, ctx)
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput],
ctx.final_res_batch)
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
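# Map the argmax index to a human-readable label via the HF config's id2label
# mapping; label stays None when the config does not define one.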
label = getattr(self.model_config.hf_config, "id2label",
{}).get(predicted_index)
item = ClassificationData(
index=idx,
label=label,
probs=probs,
num_classes=len(probs),
)
items.append(item)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ClassificationResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
) -> None:
super().__init__(
engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
)
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> Union[ClassificationResponse, ErrorResponse]:
model_name = self._get_model_name(request.model)
request_id = (f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request)}")
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await super().handle(ctx) # type: ignore

View File

@ -1,14 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import time
from collections.abc import AsyncGenerator
from typing import Final, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from typing_extensions import assert_never, override
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@ -19,13 +16,13 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
OpenAIServing,
ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
@ -45,7 +42,99 @@ def _get_embedding(
assert_never(encoding_format)
class OpenAIServingEmbedding(OpenAIServing):
class EmbeddingMixin(OpenAIServing):
async def _preprocess(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
ctx = cast(EmbeddingServeContext, ctx)
try:
(
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
)
if ctx.prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(ctx.request, EmbeddingChatRequest):
(
_,
ctx.request_prompts,
ctx.engine_prompts,
) = await self._preprocess_chat(
ctx.request,
tokenizer,
ctx.request.messages,
chat_template=ctx.request.chat_template
or ctx.chat_template,
chat_template_content_format=ctx.
chat_template_content_format,
# In embedding requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
truncate_prompt_tokens=ctx.truncate_prompt_tokens,
add_special_tokens=ctx.request.add_special_tokens,
)
else:
(ctx.request_prompts,
ctx.engine_prompts) = await self._preprocess_completion(
ctx.request,
tokenizer,
ctx.request.input,
truncate_prompt_tokens=ctx.truncate_prompt_tokens,
add_special_tokens=ctx.request.add_special_tokens,
)
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ServeContext,
) -> Union[EmbeddingResponse, ErrorResponse]:
items: list[EmbeddingResponseData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput],
ctx.final_res_batch)
for idx, final_res in enumerate(final_res_batch_checked):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
item = EmbeddingResponseData(
index=idx,
embedding=_get_embedding(embedding_res.outputs,
ctx.request.encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
)
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
@ -76,164 +165,36 @@ class OpenAIServingEmbedding(OpenAIServing):
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
encoding_format = request.encoding_format
model_name = self._get_model_name(request.model)
request_id = f"embd-{self._base_request_id(raw_request)}"
created_time = int(time.time())
request_id = (f"{self.request_id_prefix}-"
f"{self._base_request_id(raw_request)}")
truncate_prompt_tokens = request.truncate_prompt_tokens
ctx = EmbeddingServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
pooling_params = request.to_pooling_params()
return await super().handle(ctx) # type: ignore
@override
def _validate_request(
self,
ctx: ServeContext[EmbeddingRequest],
) -> Optional[ErrorResponse]:
if error := super()._validate_request(ctx):
return error
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
pooling_params = ctx.request.to_pooling_params()
try:
pooling_params.verify(self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(request, EmbeddingChatRequest):
(
_,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
# In embedding requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: list[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(list[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_embedding_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_embedding_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["float", "base64"],
) -> EmbeddingResponse:
items: list[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
item = EmbeddingResponseData(
index=idx,
embedding=_get_embedding(embedding_res.outputs,
encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
return None

View File

@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import Iterable, Iterator, Mapping, Sequence
import time
from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
Sequence)
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import Annotated, Any, Callable, Optional, TypedDict, Union
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
TypedDict, TypeVar, Union)
from fastapi import Request
from pydantic import Field
from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers
import vllm.envs as envs
@ -24,15 +27,23 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
ClassificationRequest,
ClassificationResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ErrorResponse, RerankRequest,
ScoreRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
PoolingResponse, RerankRequest,
ScoreRequest, ScoreResponse,
TokenizeChatRequest,
TokenizeCompletionRequest,
TranscriptionRequest)
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
@ -40,6 +51,9 @@ from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
@ -47,13 +61,15 @@ from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of, make_async, random_uuid
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
random_uuid)
logger = init_logger(__name__)
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest, RerankRequest,
ScoreRequest, TokenizeCompletionRequest]
ClassificationRequest, ScoreRequest,
TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
@ -61,6 +77,17 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
TranscriptionRequest]
AnyResponse = Union[
CompletionResponse,
ChatCompletionResponse,
EmbeddingResponse,
TranscriptionResponse,
TokenizeResponse,
PoolingResponse,
ClassificationResponse,
ScoreResponse,
]
class TextTokensPrompt(TypedDict):
prompt: str
@ -69,8 +96,79 @@ class TextTokensPrompt(TypedDict):
RequestPrompt = Union[list[int], str, TextTokensPrompt]
RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel):
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Optional[Sequence[RequestPrompt]] = \
Field(default_factory=list)
engine_prompts: Optional[list[TokensPrompt]] = \
Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ResponseGenerationMixin(BaseModel):
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: Optional[AsyncGenerator[tuple[int, Union[
RequestOutput, PoolingRequestOutput]], None]] = None
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
Generic[RequestT]):
# Shared across all requests
request: RequestT
raw_request: Optional[Request] = None
model_name: str
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
lora_request: Optional[LoRARequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Shared across most requests
tokenizer: Optional[AnyTokenizer] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True,
)
ClassificationServeContext = ServeContext[ClassificationRequest]
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: Optional[str] = None
chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every request's ID (e.g. "embd", "classify")
so you can easily tell which endpoint a given request ID came from.
"""
def __init__(
self,
@ -100,6 +198,167 @@ class OpenAIServing:
self._tokenize_prompt_input_or_inputs,
executor=self._tokenizer_executor)
async def _preprocess(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""
Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.).
"""
return None
def _build_response(
self,
ctx: ServeContext,
) -> Union[AnyResponse, ErrorResponse]:
"""
Default response builder. Subclass may override this method
to return the appropriate response object.
"""
return self.create_error_response("unimplemented endpoint")
async def handle(
self,
ctx: ServeContext,
) -> Union[AnyResponse, ErrorResponse]:
generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
generation = self._pipeline(ctx)
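# Return the first response the pipeline yields; if a stage yields an
# error, the generator is not advanced further, so later stages never run.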
async for response in generation:
return response
return self.create_error_response("No response yielded from pipeline")
async def _pipeline(
self,
ctx: ServeContext,
) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
"""Execute the request processing pipeline yielding responses."""
if error := await self._check_model(ctx.request):
yield error
if error := self._validate_request(ctx):
yield error
preprocess_ret = await self._preprocess(ctx)
if isinstance(preprocess_ret, ErrorResponse):
yield preprocess_ret
generators_ret = await self._prepare_generators(ctx)
if isinstance(generators_ret, ErrorResponse):
yield generators_ret
collect_ret = await self._collect_batch(ctx)
if isinstance(collect_ret, ErrorResponse):
yield collect_ret
yield self._build_response(ctx)
def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is not None:
if truncate_prompt_tokens <= self.max_model_len:
ctx.truncate_prompt_tokens = truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
return None
async def _prepare_generators(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Schedule the request and get the result generator."""
generators: list[AsyncGenerator[Union[RequestOutput,
PoolingRequestOutput],
None]] = []
try:
trace_headers = (None if ctx.raw_request is None else await
self._get_trace_headers(ctx.raw_request.headers))
if not hasattr(ctx.request, "to_pooling_params"):
return self.create_error_response(
"Request type does not support pooling parameters")
pooling_params = ctx.request.to_pooling_params()
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}"
if ctx.request_prompts is None:
return self.create_error_response(
"Request prompts not available")
self._log_inputs(
request_id_item,
ctx.request_prompts[i],
params=pooling_params,
lora_request=ctx.lora_request,
prompt_adapter_request=ctx.prompt_adapter_request)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(generator)
ctx.result_generator = merge_async_iterators(*generators)
return None
except Exception as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def _collect_batch(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Collect batch results from the result generator."""
try:
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[Optional[Union[RequestOutput,
PoolingRequestOutput]]]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
return self.create_error_response(
"Result generator not available")
async for i, res in ctx.result_generator:
final_res_batch[i] = res
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts")
ctx.final_res_batch = [
res for res in final_res_batch if res is not None
]
return None
except Exception as e:
return self.create_error_response(str(e))
def create_error_response(
self,
message: str,
@ -183,6 +442,12 @@ class OpenAIServing:
if truncate_prompt_tokens is None:
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
elif truncate_prompt_tokens < 0:
# Negative means we cap at the model's max length
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len)
else:
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
@ -204,6 +469,8 @@ class OpenAIServing:
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
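# Negative means we cap at the model's max length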
input_ids = prompt_ids[-self.max_model_len:]
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
@ -219,13 +486,16 @@ class OpenAIServing:
) -> TextTokensPrompt:
token_num = len(input_ids)
# Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
# Note: EmbeddingRequest, ClassificationRequest,
# and ScoreRequest don't have max_tokens
if isinstance(request,
(EmbeddingChatRequest, EmbeddingCompletionRequest,
ScoreRequest, RerankRequest)):
ScoreRequest, RerankRequest, ClassificationRequest)):
operation = {
ScoreRequest: "score",
ClassificationRequest: "classification"
}.get(type(request), "embedding generation")
operation = "score" if isinstance(request, ScoreRequest) \
else "embedding generation"
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
@ -247,7 +517,7 @@ class OpenAIServing:
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
max_tokens = request.max_completion_tokens or request.max_tokens
else:
max_tokens = request.max_tokens
max_tokens = getattr(request, "max_tokens", None)
if max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(