[Frontend] Added chat-style multimodal support to /classify. (#27516)

Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Signed-off-by: vnadathur <glvikramn@gmail.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: vnadathur <glvikramn@gmail.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
Srreyansh Sethi 2025-11-14 03:03:55 -08:00 committed by GitHub
parent ecf8230d4d
commit 360bd8762f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 318 additions and 27 deletions

View File

@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
assert hasattr(output.data[0], "probs")
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
response = requests.post(
server.url_for("classify"),
json={"model": model_name, "input": "hello", "add_special_tokens": False},
)
response.raise_for_status()
ClassificationResponse.model_validate(response.json())
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
input_texts = [

View File

@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
import requests
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ClassificationResponse
VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
MAXIMUM_VIDEOS = 1
TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"
HF_OVERRIDES = {
"text_config": {
"architectures": ["Qwen2_5_VLForSequenceClassification"],
},
}
@pytest.fixture(scope="module")
def server_vlm_classify():
args = [
"--runner",
"pooling",
"--max-model-len",
"5000",
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"video": MAXIMUM_VIDEOS}),
]
with RemoteOpenAIServer(
VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES
) as remote_server:
yield remote_server
@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
def test_classify_accepts_chat_text_only(
server_vlm_classify: RemoteOpenAIServer, model_name: str
) -> None:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please classify this text request."},
],
}
]
response = requests.post(
server_vlm_classify.url_for("classify"),
json={"model": model_name, "messages": messages},
)
response.raise_for_status()
output = ClassificationResponse.model_validate(response.json())
assert output.object == "list"
assert output.model == model_name
assert len(output.data) == 1
assert len(output.data[0].probs) == 2
assert output.usage.prompt_tokens == 22
@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
def test_classify_accepts_chat_video_url(
server_vlm_classify: RemoteOpenAIServer, model_name: str
) -> None:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please classify this video."},
{"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}},
],
}
]
response = requests.post(
server_vlm_classify.url_for("classify"),
json={"model": model_name, "messages": messages},
)
response.raise_for_status()
output = ClassificationResponse.model_validate(response.json())
assert output.object == "list"
assert output.model == model_name
assert len(output.data) == 1
assert len(output.data[0].probs) == 2
assert output.usage.prompt_tokens == 4807

View File

@ -1784,6 +1784,9 @@ async def init_app_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if "classify" in supported_tasks

View File

@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel):
usage: UsageInfo
class ClassificationRequest(OpenAIBaseModel):
class ClassificationCompletionRequest(OpenAIBaseModel):
model: str | None = None
input: list[str] | str
truncate_prompt_tokens: int | None = None
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
user: str | None = None
# --8<-- [start:classification-extra-params]
@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel):
"if the served model does not use priority scheduling."
),
)
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
softmax: bool | None = Field(
default=None,
description="softmax will be deprecated, please use use_activation instead.",
@ -2040,6 +2054,102 @@ class ClassificationRequest(OpenAIBaseModel):
)
class ClassificationChatRequest(OpenAIBaseModel):
model: str | None = None
messages: list[ChatCompletionMessageParam]
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
user: str | None = None
# --8<-- [start:chat-classification-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
softmax: bool | None = Field(
default=None,
description="softmax will be deprecated, please use use_activation instead.",
)
activation: bool | None = Field(
default=None,
description="activation will be deprecated, please use use_activation instead.",
)
use_activation: bool | None = Field(
default=None,
description="Whether to use activation for classification outputs. "
"Default is True.",
)
# --8<-- [end:chat-classification-extra-params]
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=get_use_activation(self),
)
ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest
)
class ClassificationData(OpenAIBaseModel):
index: int
label: str | None

View File

@ -4,13 +4,17 @@
from http import HTTPStatus
from typing import cast
import jinja2
import numpy as np
from fastapi import Request
from typing_extensions import override
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationData,
ClassificationRequest,
ClassificationResponse,
@ -32,7 +36,10 @@ logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing):
@override
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool
async def _preprocess(
self,
ctx: ServeContext,
@ -42,31 +49,79 @@ class ClassificationMixin(OpenAIServing):
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
if isinstance(ctx.request.input, str) and not ctx.request.input:
return self.create_error_response(
"Input cannot be empty for classification",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
return None
try:
ctx.tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
)
request_obj = ctx.request
if isinstance(request_obj, ClassificationChatRequest):
chat_request = request_obj
messages = chat_request.messages
trust_request_chat_template = getattr(
self,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
)
if ret:
return ret
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request),
ctx.tokenizer,
messages,
chat_template=(
chat_request.chat_template
or getattr(self, "chat_template", None)
),
chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=False,
continue_final_message=False,
add_special_tokens=chat_request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts
elif isinstance(request_obj, ClassificationCompletionRequest):
completion_request = request_obj
input_data = completion_request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(input_data, list) and not input_data:
ctx.engine_prompts = []
return None
renderer = self._get_renderer(ctx.tokenizer)
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(completion_request),
)
else:
return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None
except (ValueError, TypeError) as e:
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
@override
def _build_response(
self,
ctx: ServeContext,
@ -118,6 +173,7 @@ class ClassificationMixin(OpenAIServing):
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
@ -130,6 +186,9 @@ class ServingClassification(ClassificationMixin):
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
@ -139,6 +198,10 @@ class ServingClassification(ClassificationMixin):
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify(
self,
request: ClassificationRequest,
@ -156,7 +219,6 @@ class ServingClassification(ClassificationMixin):
return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params(
self,
ctx: ClassificationServeContext,

View File

@ -43,6 +43,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
CompletionRequest,
@ -114,13 +116,16 @@ CompletionLikeRequest: TypeAlias = (
| DetokenizeRequest
| EmbeddingCompletionRequest
| RerankRequest
| ClassificationRequest
| ClassificationCompletionRequest
| ScoreRequest
| TokenizeCompletionRequest
)
ChatLikeRequest: TypeAlias = (
ChatCompletionRequest | EmbeddingChatRequest | TokenizeChatRequest
ChatCompletionRequest
| EmbeddingChatRequest
| TokenizeChatRequest
| ClassificationChatRequest
)
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
AnyRequest: TypeAlias = (
@ -814,7 +819,11 @@ class OpenAIServing:
if not hasattr(request, "messages"):
return message_types
for message in request.messages:
messages = request.messages
if messages is None or isinstance(messages, (str, bytes)):
return message_types
for message in messages:
if (
isinstance(message, dict)
and "content" in message
@ -907,7 +916,8 @@ class OpenAIServing:
EmbeddingCompletionRequest,
ScoreRequest,
RerankRequest,
ClassificationRequest,
ClassificationCompletionRequest,
ClassificationChatRequest,
),
):
# Note: input length can be up to the entire model context length
@ -915,7 +925,8 @@ class OpenAIServing:
if token_num > self.max_model_len:
operations: dict[type[AnyRequest], str] = {
ScoreRequest: "score",
ClassificationRequest: "classification",
ClassificationCompletionRequest: "classification",
ClassificationChatRequest: "classification",
}
operation = operations.get(type(request), "embedding generation")
raise ValueError(