[Frontend] Added chat-style multimodal support to /classify. (#27516)
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Signed-off-by: vnadathur <glvikramn@gmail.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: vnadathur <glvikramn@gmail.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
parent
ecf8230d4d
commit
360bd8762f
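
This change lets /classify accept OpenAI-style chat messages (including multimodal content parts) in addition to the existing plain-text "input" field. As orientation before the diff, a minimal client-side sketch of the new request shape (the server address and video URL are placeholders; the payload mirrors the new tests below):

import requests

# Chat-style classification request with a multimodal content part.
payload = {
    "model": "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Please classify this video."},
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
            ],
        }
    ],
}

response = requests.post("http://localhost:8000/classify", json=payload)
response.raise_for_status()
print(response.json()["data"][0]["probs"])  # per-label probabilities

Plain-text requests ({"model": ..., "input": ...}) keep working unchanged.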
@@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
     assert hasattr(output.data[0], "probs")


+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
+    response = requests.post(
+        server.url_for("classify"),
+        json={"model": model_name, "input": "hello", "add_special_tokens": False},
+    )
+    response.raise_for_status()
+    ClassificationResponse.model_validate(response.json())
+
+
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
     input_texts = [
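
The new test_add_special_tokens_false exercises the add_special_tokens knob end to end. What the flag controls is ordinary HF tokenizer behavior; an illustrative sketch (bert-base-uncased is just a stand-in for any classifier tokenizer):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

# Default: input is wrapped in special tokens ([CLS] ... [SEP]).
print(tok("hello").input_ids)                            # [101, 7592, 102]
# add_special_tokens=False: raw content tokens only.
print(tok("hello", add_special_tokens=False).input_ids)  # [7592]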
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+from vllm.entrypoints.openai.protocol import ClassificationResponse
+
+VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
+MAXIMUM_VIDEOS = 1
+TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"
+
+HF_OVERRIDES = {
+    "text_config": {
+        "architectures": ["Qwen2_5_VLForSequenceClassification"],
+    },
+}
+
+
+@pytest.fixture(scope="module")
+def server_vlm_classify():
+    args = [
+        "--runner",
+        "pooling",
+        "--max-model-len",
+        "5000",
+        "--enforce-eager",
+        "--limit-mm-per-prompt",
+        json.dumps({"video": MAXIMUM_VIDEOS}),
+    ]
+
+    with RemoteOpenAIServer(
+        VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES
+    ) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_text_only(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this text request."},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 22
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_video_url(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this video."},
+                {"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 4807
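
For manual experimentation outside the pytest harness, the fixture above corresponds to roughly this server launch (a sketch: it assumes RemoteOpenAIServer forwards override_hf_configs via the --hf-overrides CLI flag):

vllm serve muziyongshixin/Qwen2.5-VL-7B-for-VideoCls \
    --runner pooling \
    --max-model-len 5000 \
    --enforce-eager \
    --limit-mm-per-prompt '{"video": 1}' \
    --hf-overrides '{"text_config": {"architectures": ["Qwen2_5_VLForSequenceClassification"]}}'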
@@ -1784,6 +1784,9 @@ async def init_app_state(
         engine_client,
         state.openai_serving_models,
         request_logger=request_logger,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
+        trust_request_chat_template=args.trust_request_chat_template,
         log_error_stack=args.log_error_stack,
     )
     if "classify" in supported_tasks
@@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel):
     usage: UsageInfo


-class ClassificationRequest(OpenAIBaseModel):
+class ClassificationCompletionRequest(OpenAIBaseModel):
     model: str | None = None
     input: list[str] | str
-    truncate_prompt_tokens: int | None = None
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
     user: str | None = None

     # --8<-- [start:classification-extra-params]
@@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel):
             "if the served model does not use priority scheduling."
         ),
     )
-
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."
+        ),
+    )
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
     softmax: bool | None = Field(
         default=None,
         description="softmax will be deprecated, please use use_activation instead.",
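
Note the request_id field above: per its description it is threaded through the whole inference pass and surfaced in the response, which makes correlating client calls with server logs straightforward. A hedged sketch (placeholder URL and model name; the response field carrying the id follows ClassificationResponse):

import requests

resp = requests.post(
    "http://localhost:8000/classify",        # placeholder address
    json={
        "model": "my-classifier",            # placeholder model name
        "input": "hello",
        "request_id": "classify-debug-001",  # custom id instead of the random uuid
    },
)
resp.raise_for_status()
print(resp.json()["id"])  # expected to reflect the supplied request_id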
@@ -2040,6 +2054,102 @@ class ClassificationRequest(OpenAIBaseModel):
     )


+class ClassificationChatRequest(OpenAIBaseModel):
+    model: str | None = None
+    messages: list[ChatCompletionMessageParam]
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
+    user: str | None = None
+
+    # --8<-- [start:chat-classification-extra-params]
+    add_generation_prompt: bool = Field(
+        default=False,
+        description=(
+            "If true, the generation prompt will be added to the chat template. "
+            "This is a parameter used by chat template in tokenizer config of the "
+            "model."
+        ),
+    )
+
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."
+        ),
+    )
+
+    chat_template: str | None = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."
+        ),
+    )
+
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the chat template."
+        ),
+    )
+
+    mm_processor_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=("Additional kwargs to pass to the HF processor."),
+    )
+
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
+    # --8<-- [end:chat-classification-extra-params]
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=get_use_activation(self),
+        )
+
+
+ClassificationRequest: TypeAlias = (
+    ClassificationCompletionRequest | ClassificationChatRequest
+)
+
+
 class ClassificationData(OpenAIBaseModel):
     index: int
     label: str | None
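
Since ClassificationRequest is now a plain (untagged) union, pydantic picks the variant from the payload shape: an "input" key validates as ClassificationCompletionRequest, a "messages" key as ClassificationChatRequest. A small sketch of that dispatch (imports assumed as in this diff):

from vllm.entrypoints.openai.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
)

# "input" -> completion-style variant
completion_req = ClassificationCompletionRequest.model_validate(
    {"model": "m", "input": "hello"}
)

# "messages" -> chat-style variant
chat_req = ClassificationChatRequest.model_validate(
    {"model": "m", "messages": [{"role": "user", "content": "hi"}]}
)

# Both variants reduce to the same pooling configuration.
print(chat_req.to_pooling_params())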
@@ -4,13 +4,17 @@
 from http import HTTPStatus
 from typing import cast

+import jinja2
 import numpy as np
 from fastapi import Request
 from typing_extensions import override

 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationData,
     ClassificationRequest,
     ClassificationResponse,
@@ -32,7 +36,10 @@ logger = init_logger(__name__)


 class ClassificationMixin(OpenAIServing):
-    @override
+    chat_template: str | None
+    chat_template_content_format: ChatTemplateContentFormatOption
+    trust_request_chat_template: bool
+
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -42,31 +49,79 @@ class ClassificationMixin(OpenAIServing):
         and prepare model-specific inputs.
         """
         ctx = cast(ClassificationServeContext, ctx)
-        if isinstance(ctx.request.input, str) and not ctx.request.input:
-            return self.create_error_response(
-                "Input cannot be empty for classification",
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
-
-        if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
-            return None
-
         try:
             ctx.tokenizer = await self.engine_client.get_tokenizer()
-
-            renderer = self._get_renderer(ctx.tokenizer)
-            ctx.engine_prompts = await renderer.render_prompt(
-                prompt_or_prompts=ctx.request.input,
-                config=self._build_render_config(ctx.request),
-            )
+            request_obj = ctx.request
+
+            if isinstance(request_obj, ClassificationChatRequest):
+                chat_request = request_obj
+                messages = chat_request.messages
+                trust_request_chat_template = getattr(
+                    self,
+                    "trust_request_chat_template",
+                    False,
+                )
+                ret = self._validate_chat_template(
+                    request_chat_template=chat_request.chat_template,
+                    chat_template_kwargs=chat_request.chat_template_kwargs,
+                    trust_request_chat_template=trust_request_chat_template,
+                )
+                if ret:
+                    return ret
+
+                (
+                    _,
+                    _,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    cast(ChatCompletionRequest, chat_request),
+                    ctx.tokenizer,
+                    messages,
+                    chat_template=(
+                        chat_request.chat_template
+                        or getattr(self, "chat_template", None)
+                    ),
+                    chat_template_content_format=cast(
+                        ChatTemplateContentFormatOption,
+                        getattr(self, "chat_template_content_format", "auto"),
+                    ),
+                    add_generation_prompt=False,
+                    continue_final_message=False,
+                    add_special_tokens=chat_request.add_special_tokens,
+                )
+                ctx.engine_prompts = engine_prompts
+
+            elif isinstance(request_obj, ClassificationCompletionRequest):
+                completion_request = request_obj
+                input_data = completion_request.input
+                if input_data in (None, ""):
+                    return self.create_error_response(
+                        "Input or messages must be provided",
+                        status_code=HTTPStatus.BAD_REQUEST,
+                    )
+                if isinstance(input_data, list) and not input_data:
+                    ctx.engine_prompts = []
+                    return None
+
+                renderer = self._get_renderer(ctx.tokenizer)
+                prompt_input = cast(str | list[str], input_data)
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=prompt_input,
+                    config=self._build_render_config(completion_request),
+                )
+            else:
+                return self.create_error_response(
+                    "Invalid classification request type",
+                    status_code=HTTPStatus.BAD_REQUEST,
+                )

             return None

-        except (ValueError, TypeError) as e:
+        except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

     @override
     def _build_response(
         self,
         ctx: ServeContext,
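
Note that the chat branch calls _preprocess_chat with add_generation_prompt=False: classification pools over the rendered conversation itself, so no assistant-reply stub should be appended. At the template level this is the familiar HF behavior (illustrative; the tokenizer name is a placeholder for whatever the served model uses):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
messages = [{"role": "user", "content": "Please classify this text request."}]

# With add_generation_prompt=False the rendered string ends at the user turn;
# no empty assistant header is added for the model to "answer" into.
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=False))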
@@ -118,6 +173,7 @@ class ClassificationMixin(OpenAIServing):
         return RenderConfig(
             max_length=self.max_model_len,
             truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens,
         )


@@ -130,6 +186,9 @@ class ServingClassification(ClassificationMixin):
         models: OpenAIServingModels,
         *,
         request_logger: RequestLogger | None,
+        chat_template: str | None = None,
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(
@@ -139,6 +198,10 @@ class ServingClassification(ClassificationMixin):
             log_error_stack=log_error_stack,
         )

+        self.chat_template = chat_template
+        self.chat_template_content_format = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
+
     async def create_classify(
         self,
         request: ClassificationRequest,
@@ -156,7 +219,6 @@ class ServingClassification(ClassificationMixin):

         return await super().handle(ctx)  # type: ignore

-    @override
     def _create_pooling_params(
         self,
         ctx: ClassificationServeContext,
@@ -43,6 +43,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationRequest,
     ClassificationResponse,
     CompletionRequest,
@@ -114,13 +116,16 @@ CompletionLikeRequest: TypeAlias = (
     | DetokenizeRequest
     | EmbeddingCompletionRequest
     | RerankRequest
-    | ClassificationRequest
+    | ClassificationCompletionRequest
     | ScoreRequest
     | TokenizeCompletionRequest
 )

 ChatLikeRequest: TypeAlias = (
-    ChatCompletionRequest | EmbeddingChatRequest | TokenizeChatRequest
+    ChatCompletionRequest
+    | EmbeddingChatRequest
+    | TokenizeChatRequest
+    | ClassificationChatRequest
 )
 SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
 AnyRequest: TypeAlias = (
@@ -814,7 +819,11 @@ class OpenAIServing:
         if not hasattr(request, "messages"):
             return message_types

-        for message in request.messages:
+        messages = request.messages
+        if messages is None or isinstance(messages, (str, bytes)):
+            return message_types
+
+        for message in messages:
             if (
                 isinstance(message, dict)
                 and "content" in message
@@ -907,7 +916,8 @@ class OpenAIServing:
                 EmbeddingCompletionRequest,
                 ScoreRequest,
                 RerankRequest,
-                ClassificationRequest,
+                ClassificationCompletionRequest,
+                ClassificationChatRequest,
             ),
         ):
             # Note: input length can be up to the entire model context length
@@ -915,7 +925,8 @@
         if token_num > self.max_model_len:
             operations: dict[type[AnyRequest], str] = {
                 ScoreRequest: "score",
-                ClassificationRequest: "classification",
+                ClassificationCompletionRequest: "classification",
+                ClassificationChatRequest: "classification",
             }
             operation = operations.get(type(request), "embedding generation")
             raise ValueError(