[Frontend] Added chat-style multimodal support to /classify. (#27516)

Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Signed-off-by: vnadathur <glvikramn@gmail.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: vnadathur <glvikramn@gmail.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Srreyansh Sethi 2025-11-14 03:03:55 -08:00 committed by GitHub
parent ecf8230d4d
commit 360bd8762f
6 changed files with 318 additions and 27 deletions
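
In practice this means /classify now accepts an OpenAI-style "messages" payload (including multimodal parts such as video_url) in addition to the existing "input" form. A minimal client-side sketch of the new request shape, mirroring the tests added below (the base URL is an assumption; the model name and video URL are the ones used in the new test file):

# Sketch only: assumes a vLLM OpenAI-compatible server is already running
# locally with a classification ("pooling") model loaded.
import requests

BASE_URL = "http://localhost:8000"  # assumed server address
MODEL = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"  # model used by the new tests

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please classify this video."},
            {
                "type": "video_url",
                "video_url": {
                    "url": "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"
                },
            },
        ],
    }
]

# The pre-existing completion-style body ({"input": "..."}) keeps working;
# "messages" is the new chat-style alternative added by this PR.
response = requests.post(
    f"{BASE_URL}/classify",
    json={"model": MODEL, "messages": messages},
)
response.raise_for_status()
print(response.json()["data"][0]["probs"])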

View File

@@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str)
     assert hasattr(output.data[0], "probs")
 
 
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
+    response = requests.post(
+        server.url_for("classify"),
+        json={"model": model_name, "input": "hello", "add_special_tokens": False},
+    )
+    response.raise_for_status()
+    ClassificationResponse.model_validate(response.json())
+
+
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
     input_texts = [

View File

@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import pytest
import requests

from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ClassificationResponse

VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
MAXIMUM_VIDEOS = 1
TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"

HF_OVERRIDES = {
    "text_config": {
        "architectures": ["Qwen2_5_VLForSequenceClassification"],
    },
}


@pytest.fixture(scope="module")
def server_vlm_classify():
    args = [
        "--runner",
        "pooling",
        "--max-model-len",
        "5000",
        "--enforce-eager",
        "--limit-mm-per-prompt",
        json.dumps({"video": MAXIMUM_VIDEOS}),
    ]

    with RemoteOpenAIServer(
        VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES
    ) as remote_server:
        yield remote_server


@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
def test_classify_accepts_chat_text_only(
    server_vlm_classify: RemoteOpenAIServer, model_name: str
) -> None:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Please classify this text request."},
            ],
        }
    ]

    response = requests.post(
        server_vlm_classify.url_for("classify"),
        json={"model": model_name, "messages": messages},
    )
    response.raise_for_status()

    output = ClassificationResponse.model_validate(response.json())

    assert output.object == "list"
    assert output.model == model_name
    assert len(output.data) == 1
    assert len(output.data[0].probs) == 2
    assert output.usage.prompt_tokens == 22


@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
def test_classify_accepts_chat_video_url(
    server_vlm_classify: RemoteOpenAIServer, model_name: str
) -> None:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Please classify this video."},
                {"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}},
            ],
        }
    ]

    response = requests.post(
        server_vlm_classify.url_for("classify"),
        json={"model": model_name, "messages": messages},
    )
    response.raise_for_status()

    output = ClassificationResponse.model_validate(response.json())

    assert output.object == "list"
    assert output.model == model_name
    assert len(output.data) == 1
    assert len(output.data[0].probs) == 2
    assert output.usage.prompt_tokens == 4807

View File

@@ -1784,6 +1784,9 @@ async def init_app_state(
             engine_client,
             state.openai_serving_models,
             request_logger=request_logger,
+            chat_template=resolved_chat_template,
+            chat_template_content_format=args.chat_template_content_format,
+            trust_request_chat_template=args.trust_request_chat_template,
             log_error_stack=args.log_error_stack,
         )
         if "classify" in supported_tasks

View File

@@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel):
     usage: UsageInfo
 
 
-class ClassificationRequest(OpenAIBaseModel):
+class ClassificationCompletionRequest(OpenAIBaseModel):
     model: str | None = None
     input: list[str] | str
-    truncate_prompt_tokens: int | None = None
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
     user: str | None = None
 
     # --8<-- [start:classification-extra-params]
@@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel):
             "if the served model does not use priority scheduling."
         ),
     )
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."
+        ),
+    )
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "throughout the inference process and returned in the response."
+        ),
+    )
     softmax: bool | None = Field(
         default=None,
         description="softmax will be deprecated, please use use_activation instead.",
@@ -2040,6 +2054,102 @@
         )
 
 
+class ClassificationChatRequest(OpenAIBaseModel):
+    model: str | None = None
+    messages: list[ChatCompletionMessageParam]
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
+    user: str | None = None
+
+    # --8<-- [start:chat-classification-extra-params]
+    add_generation_prompt: bool = Field(
+        default=False,
+        description=(
+            "If true, the generation prompt will be added to the chat template. "
+            "This is a parameter used by chat template in tokenizer config of the "
+            "model."
+        ),
+    )
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."
+        ),
+    )
+    chat_template: str | None = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."
+        ),
+    )
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the chat template."
+        ),
+    )
+    mm_processor_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=("Additional kwargs to pass to the HF processor."),
+    )
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "throughout the inference process and returned in the response."
+        ),
+    )
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
+    # --8<-- [end:chat-classification-extra-params]
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=get_use_activation(self),
+        )
+
+
+ClassificationRequest: TypeAlias = (
+    ClassificationCompletionRequest | ClassificationChatRequest
+)
+
+
 class ClassificationData(OpenAIBaseModel):
     index: int
     label: str | None
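
Since ClassificationRequest is now a union, the same endpoint-level type accepts either body shape. A small validation sketch of the two variants (class names come from the diff above; the model name and message text are illustrative):

# Illustrative only: parse the two request shapes the /classify endpoint now accepts.
from vllm.entrypoints.openai.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
)

# Completion-style body: plain string/list input (pre-existing behaviour).
completion_req = ClassificationCompletionRequest.model_validate(
    {"model": "my-classifier", "input": "hello"}
)

# Chat-style body: OpenAI-style messages, which is what this PR adds.
chat_req = ClassificationChatRequest.model_validate(
    {
        "model": "my-classifier",
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": "classify me"}]}
        ],
    }
)

print(type(completion_req).__name__, type(chat_req).__name__)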

View File

@@ -4,13 +4,17 @@
 from http import HTTPStatus
 from typing import cast
 
+import jinja2
 import numpy as np
 from fastapi import Request
-from typing_extensions import override
 
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationData,
     ClassificationRequest,
     ClassificationResponse,
@@ -32,7 +36,10 @@ logger = init_logger(__name__)
 
 
 class ClassificationMixin(OpenAIServing):
-    @override
+    chat_template: str | None
+    chat_template_content_format: ChatTemplateContentFormatOption
+    trust_request_chat_template: bool
+
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -42,31 +49,79 @@ class ClassificationMixin(OpenAIServing):
         and prepare model-specific inputs.
         """
         ctx = cast(ClassificationServeContext, ctx)
-        if isinstance(ctx.request.input, str) and not ctx.request.input:
-            return self.create_error_response(
-                "Input cannot be empty for classification",
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
-        if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
-            return None
 
         try:
             ctx.tokenizer = await self.engine_client.get_tokenizer()
-            renderer = self._get_renderer(ctx.tokenizer)
-            ctx.engine_prompts = await renderer.render_prompt(
-                prompt_or_prompts=ctx.request.input,
-                config=self._build_render_config(ctx.request),
-            )
+            request_obj = ctx.request
+
+            if isinstance(request_obj, ClassificationChatRequest):
+                chat_request = request_obj
+                messages = chat_request.messages
+                trust_request_chat_template = getattr(
+                    self,
+                    "trust_request_chat_template",
+                    False,
+                )
+                ret = self._validate_chat_template(
+                    request_chat_template=chat_request.chat_template,
+                    chat_template_kwargs=chat_request.chat_template_kwargs,
+                    trust_request_chat_template=trust_request_chat_template,
+                )
+                if ret:
+                    return ret
+
+                (
+                    _,
+                    _,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    cast(ChatCompletionRequest, chat_request),
+                    ctx.tokenizer,
+                    messages,
+                    chat_template=(
+                        chat_request.chat_template
+                        or getattr(self, "chat_template", None)
+                    ),
+                    chat_template_content_format=cast(
+                        ChatTemplateContentFormatOption,
+                        getattr(self, "chat_template_content_format", "auto"),
+                    ),
+                    add_generation_prompt=False,
+                    continue_final_message=False,
+                    add_special_tokens=chat_request.add_special_tokens,
+                )
+                ctx.engine_prompts = engine_prompts
+            elif isinstance(request_obj, ClassificationCompletionRequest):
+                completion_request = request_obj
+                input_data = completion_request.input
+
+                if input_data in (None, ""):
+                    return self.create_error_response(
+                        "Input or messages must be provided",
+                        status_code=HTTPStatus.BAD_REQUEST,
+                    )
+
+                if isinstance(input_data, list) and not input_data:
+                    ctx.engine_prompts = []
+                    return None
+
+                renderer = self._get_renderer(ctx.tokenizer)
+                prompt_input = cast(str | list[str], input_data)
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=prompt_input,
+                    config=self._build_render_config(completion_request),
+                )
+            else:
+                return self.create_error_response(
+                    "Invalid classification request type",
+                    status_code=HTTPStatus.BAD_REQUEST,
+                )
 
             return None
-        except (ValueError, TypeError) as e:
+        except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-    @override
     def _build_response(
         self,
         ctx: ServeContext,
@@ -118,6 +173,7 @@ class ClassificationMixin(OpenAIServing):
         return RenderConfig(
             max_length=self.max_model_len,
             truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens,
         )
@@ -130,6 +186,9 @@ class ServingClassification(ClassificationMixin):
         models: OpenAIServingModels,
         *,
         request_logger: RequestLogger | None,
+        chat_template: str | None = None,
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(
@@ -139,6 +198,10 @@ class ServingClassification(ClassificationMixin):
             log_error_stack=log_error_stack,
         )
 
+        self.chat_template = chat_template
+        self.chat_template_content_format = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
+
     async def create_classify(
         self,
         request: ClassificationRequest,
@@ -156,7 +219,6 @@ class ServingClassification(ClassificationMixin):
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
     def _create_pooling_params(
         self,
         ctx: ClassificationServeContext,
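
Note that the reworked _preprocess above also moves empty-input validation onto the completion path: an empty string is rejected with "Input or messages must be provided", while an empty list simply yields an empty prompt set. A rough client-side probe of that behaviour (the server address is an assumption, and the exact status/body shown in the comment is inferred from the diff rather than tested):

# Sketch: probe the empty-input validation added in _preprocess above.
# Assumes a vLLM server is already serving a classification model locally.
import requests

BASE_URL = "http://localhost:8000"  # assumed address

resp = requests.post(f"{BASE_URL}/classify", json={"input": ""})
# Per the diff, an empty string should come back as a bad request (HTTP 400)
# with "Input or messages must be provided" in the error body.
print(resp.status_code, resp.text)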

View File

@@ -43,6 +43,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationRequest,
     ClassificationResponse,
     CompletionRequest,
@@ -114,13 +116,16 @@ CompletionLikeRequest: TypeAlias = (
     | DetokenizeRequest
     | EmbeddingCompletionRequest
     | RerankRequest
-    | ClassificationRequest
+    | ClassificationCompletionRequest
     | ScoreRequest
     | TokenizeCompletionRequest
 )
 
 ChatLikeRequest: TypeAlias = (
-    ChatCompletionRequest | EmbeddingChatRequest | TokenizeChatRequest
+    ChatCompletionRequest
+    | EmbeddingChatRequest
+    | TokenizeChatRequest
+    | ClassificationChatRequest
 )
 SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
 AnyRequest: TypeAlias = (
@@ -814,7 +819,11 @@ class OpenAIServing:
         if not hasattr(request, "messages"):
             return message_types
 
-        for message in request.messages:
+        messages = request.messages
+        if messages is None or isinstance(messages, (str, bytes)):
+            return message_types
+
+        for message in messages:
             if (
                 isinstance(message, dict)
                 and "content" in message
@@ -907,7 +916,8 @@ class OpenAIServing:
                 EmbeddingCompletionRequest,
                 ScoreRequest,
                 RerankRequest,
-                ClassificationRequest,
+                ClassificationCompletionRequest,
+                ClassificationChatRequest,
             ),
         ):
             # Note: input length can be up to the entire model context length
@@ -915,7 +925,8 @@ class OpenAIServing:
             if token_num > self.max_model_len:
                 operations: dict[type[AnyRequest], str] = {
                     ScoreRequest: "score",
-                    ClassificationRequest: "classification",
+                    ClassificationCompletionRequest: "classification",
+                    ClassificationChatRequest: "classification",
                 }
                 operation = operations.get(type(request), "embedding generation")
                 raise ValueError(