[Frontend] Add /classify endpoint (#17032)
Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
parent d1110f5b5a
commit 9cea90eab4
@@ -140,6 +140,7 @@ Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints tha

- [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
- [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models.
- [Classification API](#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
- [Score API](#score-api) is similar to `LLM.score` for cross-encoder models.

## Matryoshka Embeddings
@@ -61,6 +61,8 @@ In addition, we have the following custom APIs:

  - Applicable to any model with a tokenizer.
- [Pooling API](#pooling-api) (`/pooling`)
  - Applicable to all [pooling models](../models/pooling_models.md).
- [Classification API](#classification-api) (`/classify`)
  - Only applicable to [classification models](../models/pooling_models.md) (`--task classify`).
- [Score API](#score-api) (`/score`)
  - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
@@ -443,6 +445,130 @@ The input format is the same as [Embeddings API](#embeddings-api), but the outpu

Code example: <gh-file:examples/online_serving/openai_pooling_client.py>

(classification-api)=

### Classification API

Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).

We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.

Code example: <gh-file:examples/online_serving/openai_classification_client.py>
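The same per-class probabilities can also be sanity-checked without starting a server, via the offline `LLM.classify` entrypoint mentioned above. A minimal sketch, assuming a vLLM version where `LLM(..., task="classify")` and `ClassificationOutput.probs` are available:

```python
from vllm import LLM

# Offline counterpart of the /classify endpoint (illustrative only).
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")

# `classify` applies last-token pooling plus softmax, so each output
# carries one probability per class.
(output,) = llm.classify(
    ["This product was excellent and exceeded my expectations"])
print(output.outputs.probs)
```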
#### Example Requests

You can classify multiple texts by passing an array of strings:

Request:

```bash
curl -v "http://127.0.0.1:8000/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": [
      "Loved the new café—coffee was great.",
      "This update broke everything. Frustrating."
    ]
  }'
```

Response:

```json
{
  "id": "classify-7c87cac407b749a6935d8c7ce2a8fba2",
  "object": "list",
  "created": 1745383065,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [
        0.565970778465271,
        0.4340292513370514
      ],
      "num_classes": 2
    },
    {
      "index": 1,
      "label": "Spoiled",
      "probs": [
        0.26448777318000793,
        0.7355121970176697
      ],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 20,
    "total_tokens": 20,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
```

You can also pass a string directly to the `input` field:

Request:

```bash
curl -v "http://127.0.0.1:8000/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": "Loved the new café—coffee was great."
  }'
```

Response:

```json
{
  "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
  "object": "list",
  "created": 1745383213,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [
        0.565970778465271,
        0.4340292513370514
      ],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "total_tokens": 10,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
```

#### Extra parameters

The following [pooling parameters](#pooling-params) are supported.

:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-classification-pooling-params
:end-before: end-classification-pooling-params
:::

The following extra parameters are supported:

:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-classification-extra-params
:end-before: end-classification-extra-params
:::
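As a rough illustration of the extra parameters, the sketch below passes `priority` alongside the input. It assumes a locally running server; any non-zero priority is rejected unless the server uses priority scheduling:

```python
import requests

# Hypothetical local deployment; adjust URL and model name as needed.
response = requests.post(
    "http://127.0.0.1:8000/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "This product was excellent and exceeded my expectations",
        "priority": 0,  # extra parameter; non-zero needs priority scheduling
    },
)
print(response.json()["data"][0]["label"])
```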
(score-api)=

### Score API
examples/online_serving/openai_classification_client.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import pprint

import requests


def post_http_request(payload: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=payload)
    return response


def parse_args():
    parse = argparse.ArgumentParser()
    parse.add_argument("--host", type=str, default="localhost")
    parse.add_argument("--port", type=int, default=8000)
    parse.add_argument("--model",
                       type=str,
                       default="jason9693/Qwen2.5-1.5B-apeach")
    return parse.parse_args()


def main(args):
    host = args.host
    port = args.port
    model_name = args.model

    api_url = f"http://{host}:{port}/classify"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    payload = {
        "model": model_name,
        "input": prompts,
    }

    classify_response = post_http_request(payload=payload, api_url=api_url)
    pprint.pprint(classify_response.json())


if __name__ == "__main__":
    args = parse_args()
    main(args)
tests/entrypoints/openai/test_classification.py (new file, 156 lines)
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import requests

from vllm.entrypoints.openai.protocol import ClassificationResponse

from ...utils import RemoteOpenAIServer

MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32"  # Use float32 to avoid NaN issue


@pytest.fixture(scope="module")
def server():
    args = [
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--dtype",
        DTYPE,
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_single_input_classification(server: RemoteOpenAIServer,
                                     model_name: str):
    input_text = "This product was excellent and exceeded my expectations"

    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": input_text
        },
    )

    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert output.object == "list"
    assert output.model == MODEL_NAME
    assert len(output.data) == 1
    assert hasattr(output.data[0], "label")
    assert hasattr(output.data[0], "probs")


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_multiple_inputs_classification(server: RemoteOpenAIServer,
                                        model_name: str):
    input_texts = [
        "The product arrived on time and works perfectly",
        "I'm very satisfied with my purchase, would buy again",
        "The customer service was helpful and resolved my issue quickly",
        "This product broke after one week, terrible quality",
        "I'm very disappointed with this purchase, complete waste of money",
        "The customer service was rude and unhelpful",
    ]

    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": input_texts
        },
    )
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert len(output.data) == len(input_texts)
    for i, item in enumerate(output.data):
        assert item.index == i
        assert hasattr(item, "label")
        assert hasattr(item, "probs")
        assert len(item.probs) == item.num_classes
        assert item.label in ["Default", "Spoiled"]


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
    long_text = "hello " * 600

    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": long_text,
            "truncate_prompt_tokens": 5
        },
    )

    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert len(output.data) == 1
    assert output.data[0].index == 0
    assert hasattr(output.data[0], "probs")
    assert output.usage.prompt_tokens == 5
    assert output.usage.total_tokens == 5


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
                                              model_name: str):
    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": "test",
            "truncate_prompt_tokens": 513
        },
    )

    error = classification_response.json()
    assert classification_response.status_code == 400
    assert error["object"] == "error"
    assert "truncate_prompt_tokens" in error["message"]


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": ""
        },
    )

    error = classification_response.json()
    assert classification_response.status_code == 400
    assert error["object"] == "error"


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_batch_classification_empty_list(server: RemoteOpenAIServer,
                                         model_name: str):
    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": []
        },
    )
    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert output.object == "list"
    assert isinstance(output.data, list)
    assert len(output.data) == 0
@@ -48,6 +48,8 @@ from vllm.entrypoints.openai.cli_args import (log_non_default_args,
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
                                              CompletionRequest,
                                              CompletionResponse,
                                              DetokenizeRequest,
@@ -71,6 +73,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              UnloadLoRAAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_classification import (
    ServingClassification)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing
@@ -373,6 +377,10 @@ def score(request: Request) -> Optional[ServingScores]:
    return request.app.state.openai_serving_scores


def classify(request: Request) -> Optional[ServingClassification]:
    return request.app.state.openai_serving_classification


def rerank(request: Request) -> Optional[ServingScores]:
    return request.app.state.openai_serving_scores

@@ -405,6 +413,7 @@ async def get_server_load_metrics(request: Request):
# - /v1/audio/transcriptions
# - /v1/embeddings
# - /pooling
# - /classify
# - /score
# - /v1/score
# - /rerank
@@ -572,6 +581,27 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
    assert_never(generator)


@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest,
                          raw_request: Request):
    handler = classify(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Classification API")

    generator = await handler.create_classify(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    elif isinstance(generator, ClassificationResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@@ -1001,6 +1031,12 @@ async def init_app_state(
        state.openai_serving_models,
        request_logger=request_logger) if model_config.task in (
            "score", "embed", "pooling") else None
    state.openai_serving_classification = ServingClassification(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
    ) if model_config.task == "classify" else None
    state.jinaai_serving_reranking = ServingScores(
        engine_client,
        model_config,
@@ -1292,6 +1292,47 @@ class ScoreResponse(OpenAIBaseModel):
    usage: UsageInfo


class ClassificationRequest(OpenAIBaseModel):
    model: Optional[str] = None
    input: Union[list[str], str]
    truncate_prompt_tokens: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-classification-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-classification-pooling-params

    # doc: begin-classification-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    # doc: end-classification-extra-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class ClassificationData(OpenAIBaseModel):
    index: int
    label: Optional[str]
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str
vllm/entrypoints/openai/serving_classification.py (new file, 159 lines)
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0

from http import HTTPStatus
from typing import Optional, Union, cast

import numpy as np
from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ClassificationData,
                                              ClassificationRequest,
                                              ClassificationResponse,
                                              ErrorResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
                                                    OpenAIServing,
                                                    ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput

logger = init_logger(__name__)


class ClassificationMixin(OpenAIServing):

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Process classification inputs: tokenize text, resolve adapters,
        and prepare model-specific inputs.
        """
        ctx = cast(ClassificationServeContext, ctx)
        if isinstance(ctx.request.input, str) and not ctx.request.input:
            return self.create_error_response(
                "Input cannot be empty for classification",
                status_code=HTTPStatus.BAD_REQUEST,
            )

        if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
            return None

        try:
            (
                ctx.lora_request,
                ctx.prompt_adapter_request,
            ) = self._maybe_get_adapters(ctx.request)

            ctx.tokenizer = await self.engine_client.get_tokenizer(
                ctx.lora_request)

            if ctx.prompt_adapter_request is not None:
                raise NotImplementedError(
                    "Prompt adapter is not supported for classification models"
                )

            (
                ctx.request_prompts,
                ctx.engine_prompts,
            ) = await self._preprocess_completion(
                ctx.request,
                ctx.tokenizer,
                ctx.request.input,
                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
            )

            return None

        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[ClassificationResponse, ErrorResponse]:
        """
        Convert model outputs to a formatted classification response
        with probabilities and labels.
        """
        ctx = cast(ClassificationServeContext, ctx)
        items: list[ClassificationData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput],
                                       ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            classify_res = ClassificationOutput.from_base(final_res.outputs)

            probs = classify_res.probs
            predicted_index = int(np.argmax(probs))
            label = getattr(self.model_config.hf_config, "id2label",
                            {}).get(predicted_index)

            item = ClassificationData(
                index=idx,
                label=label,
                probs=probs,
                num_classes=len(probs),
            )

            items.append(item)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )


class ServingClassification(ClassificationMixin):
    request_id_prefix = "classify"

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
        )

    async def create_classify(
        self,
        request: ClassificationRequest,
        raw_request: Request,
    ) -> Union[ClassificationResponse, ErrorResponse]:
        model_name = self._get_model_name(request.model)
        request_id = (f"{self.request_id_prefix}-"
                      f"{self._base_request_id(raw_request)}")

        ctx = ClassificationServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
        )

        return await super().handle(ctx)  # type: ignore
@@ -1,14 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import base64
import time
from collections.abc import AsyncGenerator
from typing import Final, Literal, Optional, Union, cast

import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from typing_extensions import assert_never, override

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@@ -19,13 +16,13 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                    OpenAIServing,
                                                    ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput)
from vllm.utils import merge_async_iterators

logger = init_logger(__name__)

@@ -45,7 +42,99 @@ def _get_embedding(
    assert_never(encoding_format)


class OpenAIServingEmbedding(OpenAIServing):
class EmbeddingMixin(OpenAIServing):

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            (
                ctx.lora_request,
                ctx.prompt_adapter_request,
            ) = self._maybe_get_adapters(ctx.request)

            tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
                                                               )

            if ctx.prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(ctx.request, EmbeddingChatRequest):
                (
                    _,
                    ctx.request_prompts,
                    ctx.engine_prompts,
                ) = await self._preprocess_chat(
                    ctx.request,
                    tokenizer,
                    ctx.request.messages,
                    chat_template=ctx.request.chat_template
                    or ctx.chat_template,
                    chat_template_content_format=ctx.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
            else:
                (ctx.request_prompts,
                 ctx.engine_prompts) = await self._preprocess_completion(
                     ctx.request,
                     tokenizer,
                     ctx.request.input,
                     truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput],
                                       ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(embedding_res.outputs,
                                         ctx.request.encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )


class OpenAIServingEmbedding(EmbeddingMixin):
    request_id_prefix = "embd"

    def __init__(
        self,
@@ -76,164 +165,36 @@ class OpenAIServingEmbedding(OpenAIServing):
        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format

        model_name = self._get_model_name(request.model)
        request_id = f"embd-{self._base_request_id(raw_request)}"
        created_time = int(time.time())
        request_id = (f"{self.request_id_prefix}-"
                      f"{self._base_request_id(raw_request)}")

        truncate_prompt_tokens = request.truncate_prompt_tokens
        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
        )

        pooling_params = request.to_pooling_params()
        return await super().handle(ctx)  # type: ignore

    @override
    def _validate_request(
        self,
        ctx: ServeContext[EmbeddingRequest],
    ) -> Optional[ErrorResponse]:
        if error := super()._validate_request(ctx):
            return error

        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens

        pooling_params = ctx.request.to_pooling_params()

        try:
            pooling_params.verify(self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        try:
            truncate_prompt_tokens = _validate_truncation_size(
                self.max_model_len, truncate_prompt_tokens)
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(request, EmbeddingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: list[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(list[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_embedding_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_embedding_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> EmbeddingResponse:
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(embedding_res.outputs,
                                         encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
        return None
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

import json
from collections.abc import Iterable, Iterator, Mapping, Sequence
import time
from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
                             Sequence)
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import Annotated, Any, Callable, Optional, TypedDict, Union
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
                    TypedDict, TypeVar, Union)

from fastapi import Request
from pydantic import Field
from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers

import vllm.envs as envs
@@ -24,15 +27,23 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
                                              CompletionRequest,
                                              CompletionResponse,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              PoolingResponse, RerankRequest,
                                              ScoreRequest, ScoreResponse,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest,
                                              TranscriptionRequest)
                                              TokenizeResponse,
                                              TranscriptionRequest,
                                              TranscriptionResponse)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
@@ -40,6 +51,9 @@ from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
    MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -47,13 +61,15 @@ from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of, make_async, random_uuid
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
                        random_uuid)

logger = init_logger(__name__)

CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ScoreRequest, TokenizeCompletionRequest]
                              ClassificationRequest, ScoreRequest,
                              TokenizeCompletionRequest]

ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]
@@ -61,6 +77,17 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
                   TranscriptionRequest]

AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]


class TextTokensPrompt(TypedDict):
    prompt: str
@@ -69,8 +96,79 @@ class TextTokensPrompt(TypedDict):

RequestPrompt = Union[list[int], str, TextTokensPrompt]

RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """
    request_prompts: Optional[Sequence[RequestPrompt]] = \
        Field(default_factory=list)
    engine_prompts: Optional[list[TokensPrompt]] = \
        Field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
                   Generic[RequestT]):
    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    chat_template: Optional[str] = None
    chat_template_content_format: ChatTemplateContentFormatOption


# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()


class OpenAIServing:
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """

    def __init__(
        self,
@@ -100,6 +198,167 @@ class OpenAIServing:
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        async for response in generation:
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
                                         None)

        if truncate_prompt_tokens is not None:
            if truncate_prompt_tokens <= self.max_model_len:
                ctx.truncate_prompt_tokens = truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")
        return None

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Schedule the request and get the result generator."""
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            if not hasattr(ctx.request, "to_pooling_params"):
                return self.create_error_response(
                    "Request type does not support pooling parameters")

            pooling_params = ctx.request.to_pooling_params()

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                if ctx.request_prompts is None:
                    return self.create_error_response(
                        "Request prompts not available")

                self._log_inputs(
                    request_id_item,
                    ctx.request_prompts[i],
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                    prompt_adapter_request=ctx.prompt_adapter_request)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[Optional[Union[RequestOutput,
                                                 PoolingRequestOutput]]]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response(
                    "Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [
                res for res in final_res_batch if res is not None
            ]

            return None

        except Exception as e:
            return self.create_error_response(str(e))

    def create_error_response(
        self,
        message: str,
@@ -183,6 +442,12 @@ class OpenAIServing:

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=self.max_model_len)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
@@ -204,6 +469,8 @@ class OpenAIServing:
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len:]
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

@@ -219,13 +486,16 @@ class OpenAIServing:
    ) -> TextTokensPrompt:
        token_num = len(input_ids)

        # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):
                       ScoreRequest, RerankRequest, ClassificationRequest)):
            operation = {
                ScoreRequest: "score",
                ClassificationRequest: "classification"
            }.get(type(request), "embedding generation")

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
@@ -247,7 +517,7 @@ class OpenAIServing:
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        max_tokens = getattr(request, "max_tokens", None)
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(