[Frontend] Added chat-style multimodal support to /classify. (#27516)
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
Signed-off-by: vnadathur <glvikramn@gmail.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: vnadathur <glvikramn@gmail.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
parent ecf8230d4d
commit 360bd8762f
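For orientation, here is a minimal client-side sketch (not part of the commit) of calling the extended /classify endpoint with chat-style multimodal input. The payload shape mirrors the new tests in this diff; the server address and video URL below are placeholder assumptions, not values from the commit.

import requests

BASE_URL = "http://localhost:8000"  # assumption: a local vLLM server running a pooling/classification model

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please classify this video."},
            # hypothetical video URL, for illustration only
            {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
        ],
    }
]

response = requests.post(
    f"{BASE_URL}/classify",
    json={"model": "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls", "messages": messages},
)
response.raise_for_status()
result = response.json()
# Each entry in result["data"] carries the per-class probabilities ("probs"),
# as asserted by the new tests below.
print(result["data"][0]["probs"])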
@@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
     assert hasattr(output.data[0], "probs")
 
 
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
+    response = requests.post(
+        server.url_for("classify"),
+        json={"model": model_name, "input": "hello", "add_special_tokens": False},
+    )
+    response.raise_for_status()
+    ClassificationResponse.model_validate(response.json())
+
+
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
     input_texts = [
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+from vllm.entrypoints.openai.protocol import ClassificationResponse
+
+VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
+MAXIMUM_VIDEOS = 1
+TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"
+
+HF_OVERRIDES = {
+    "text_config": {
+        "architectures": ["Qwen2_5_VLForSequenceClassification"],
+    },
+}
+
+
+@pytest.fixture(scope="module")
+def server_vlm_classify():
+    args = [
+        "--runner",
+        "pooling",
+        "--max-model-len",
+        "5000",
+        "--enforce-eager",
+        "--limit-mm-per-prompt",
+        json.dumps({"video": MAXIMUM_VIDEOS}),
+    ]
+
+    with RemoteOpenAIServer(
+        VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES
+    ) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_text_only(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this text request."},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 22
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_video_url(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this video."},
+                {"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 4807
@@ -1784,6 +1784,9 @@ async def init_app_state(
             engine_client,
             state.openai_serving_models,
             request_logger=request_logger,
+            chat_template=resolved_chat_template,
+            chat_template_content_format=args.chat_template_content_format,
+            trust_request_chat_template=args.trust_request_chat_template,
             log_error_stack=args.log_error_stack,
         )
         if "classify" in supported_tasks
@@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel):
     usage: UsageInfo
 
 
-class ClassificationRequest(OpenAIBaseModel):
+class ClassificationCompletionRequest(OpenAIBaseModel):
     model: str | None = None
     input: list[str] | str
-    truncate_prompt_tokens: int | None = None
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    user: str | None = None
 
     # --8<-- [start:classification-extra-params]
@@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel):
             "if the served model does not use priority scheduling."
         ),
     )
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."
+        ),
+    )
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
 
     softmax: bool | None = Field(
         default=None,
         description="softmax will be deprecated, please use use_activation instead.",
@@ -2040,6 +2054,102 @@ class ClassificationRequest(OpenAIBaseModel):
     )
 
 
+class ClassificationChatRequest(OpenAIBaseModel):
+    model: str | None = None
+    messages: list[ChatCompletionMessageParam]
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
+    user: str | None = None
+
+    # --8<-- [start:chat-classification-extra-params]
+    add_generation_prompt: bool = Field(
+        default=False,
+        description=(
+            "If true, the generation prompt will be added to the chat template. "
+            "This is a parameter used by chat template in tokenizer config of the "
+            "model."
+        ),
+    )
+
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."
+        ),
+    )
+
+    chat_template: str | None = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."
+        ),
+    )
+
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the chat template."
+        ),
+    )
+
+    mm_processor_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=("Additional kwargs to pass to the HF processor."),
+    )
+
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
+
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
+    # --8<-- [end:chat-classification-extra-params]
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=get_use_activation(self),
+        )
+
+
+ClassificationRequest: TypeAlias = (
+    ClassificationCompletionRequest | ClassificationChatRequest
+)
+
+
 class ClassificationData(OpenAIBaseModel):
     index: int
     label: str | None
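As a sketch of how this union behaves (not part of the diff, and assuming pydantic v2's TypeAdapter, which vLLM's OpenAIBaseModel classes build on): a body carrying "input" validates as ClassificationCompletionRequest, while a body carrying "messages" validates as ClassificationChatRequest. The model name "my-classifier" below is a placeholder.

from pydantic import TypeAdapter

from vllm.entrypoints.openai.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
)

# The alias introduced above is a plain union of the two request models.
adapter = TypeAdapter(ClassificationCompletionRequest | ClassificationChatRequest)

# Completion-style body: selected because it carries "input".
text_req = adapter.validate_python({"model": "my-classifier", "input": "hello"})
assert isinstance(text_req, ClassificationCompletionRequest)

# Chat-style body: selected because it carries "messages".
chat_req = adapter.validate_python(
    {"model": "my-classifier", "messages": [{"role": "user", "content": "hello"}]}
)
assert isinstance(chat_req, ClassificationChatRequest)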
@@ -4,13 +4,17 @@
 from http import HTTPStatus
 from typing import cast
 
+import jinja2
 import numpy as np
 from fastapi import Request
-from typing_extensions import override
 
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationData,
     ClassificationRequest,
     ClassificationResponse,
@@ -32,7 +36,10 @@ logger = init_logger(__name__)
 
 
 class ClassificationMixin(OpenAIServing):
-    @override
+    chat_template: str | None
+    chat_template_content_format: ChatTemplateContentFormatOption
+    trust_request_chat_template: bool
+
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -42,31 +49,79 @@ class ClassificationMixin(OpenAIServing):
         and prepare model-specific inputs.
         """
         ctx = cast(ClassificationServeContext, ctx)
-        if isinstance(ctx.request.input, str) and not ctx.request.input:
-            return self.create_error_response(
-                "Input cannot be empty for classification",
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
-
-        if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
-            return None
 
         try:
             ctx.tokenizer = await self.engine_client.get_tokenizer()
 
-            renderer = self._get_renderer(ctx.tokenizer)
-            ctx.engine_prompts = await renderer.render_prompt(
-                prompt_or_prompts=ctx.request.input,
-                config=self._build_render_config(ctx.request),
-            )
+            request_obj = ctx.request
+
+            if isinstance(request_obj, ClassificationChatRequest):
+                chat_request = request_obj
+                messages = chat_request.messages
+                trust_request_chat_template = getattr(
+                    self,
+                    "trust_request_chat_template",
+                    False,
+                )
+                ret = self._validate_chat_template(
+                    request_chat_template=chat_request.chat_template,
+                    chat_template_kwargs=chat_request.chat_template_kwargs,
+                    trust_request_chat_template=trust_request_chat_template,
+                )
+                if ret:
+                    return ret
+
+                (
+                    _,
+                    _,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    cast(ChatCompletionRequest, chat_request),
+                    ctx.tokenizer,
+                    messages,
+                    chat_template=(
+                        chat_request.chat_template
+                        or getattr(self, "chat_template", None)
+                    ),
+                    chat_template_content_format=cast(
+                        ChatTemplateContentFormatOption,
+                        getattr(self, "chat_template_content_format", "auto"),
+                    ),
+                    add_generation_prompt=False,
+                    continue_final_message=False,
+                    add_special_tokens=chat_request.add_special_tokens,
+                )
+                ctx.engine_prompts = engine_prompts
+
+            elif isinstance(request_obj, ClassificationCompletionRequest):
+                completion_request = request_obj
+                input_data = completion_request.input
+                if input_data in (None, ""):
+                    return self.create_error_response(
+                        "Input or messages must be provided",
+                        status_code=HTTPStatus.BAD_REQUEST,
+                    )
+                if isinstance(input_data, list) and not input_data:
+                    ctx.engine_prompts = []
+                    return None
+
+                renderer = self._get_renderer(ctx.tokenizer)
+                prompt_input = cast(str | list[str], input_data)
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=prompt_input,
+                    config=self._build_render_config(completion_request),
+                )
+            else:
+                return self.create_error_response(
+                    "Invalid classification request type",
+                    status_code=HTTPStatus.BAD_REQUEST,
+                )
 
             return None
 
-        except (ValueError, TypeError) as e:
+        except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-    @override
     def _build_response(
         self,
         ctx: ServeContext,
@@ -118,6 +173,7 @@ class ClassificationMixin(OpenAIServing):
         return RenderConfig(
             max_length=self.max_model_len,
             truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens,
         )
 
 
@@ -130,6 +186,9 @@ class ServingClassification(ClassificationMixin):
         models: OpenAIServingModels,
         *,
         request_logger: RequestLogger | None,
+        chat_template: str | None = None,
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(
@@ -139,6 +198,10 @@ class ServingClassification(ClassificationMixin):
             log_error_stack=log_error_stack,
         )
 
+        self.chat_template = chat_template
+        self.chat_template_content_format = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
+
     async def create_classify(
         self,
         request: ClassificationRequest,
@@ -156,7 +219,6 @@ class ServingClassification(ClassificationMixin):
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
     def _create_pooling_params(
         self,
         ctx: ClassificationServeContext,
@@ -43,6 +43,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationRequest,
     ClassificationResponse,
     CompletionRequest,
@@ -114,13 +116,16 @@ CompletionLikeRequest: TypeAlias = (
     | DetokenizeRequest
     | EmbeddingCompletionRequest
     | RerankRequest
-    | ClassificationRequest
+    | ClassificationCompletionRequest
     | ScoreRequest
     | TokenizeCompletionRequest
 )
 
 ChatLikeRequest: TypeAlias = (
-    ChatCompletionRequest | EmbeddingChatRequest | TokenizeChatRequest
+    ChatCompletionRequest
+    | EmbeddingChatRequest
+    | TokenizeChatRequest
+    | ClassificationChatRequest
 )
 SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
 AnyRequest: TypeAlias = (
@@ -814,7 +819,11 @@ class OpenAIServing:
         if not hasattr(request, "messages"):
             return message_types
 
-        for message in request.messages:
+        messages = request.messages
+        if messages is None or isinstance(messages, (str, bytes)):
+            return message_types
+
+        for message in messages:
             if (
                 isinstance(message, dict)
                 and "content" in message
@@ -907,7 +916,8 @@ class OpenAIServing:
                 EmbeddingCompletionRequest,
                 ScoreRequest,
                 RerankRequest,
-                ClassificationRequest,
+                ClassificationCompletionRequest,
+                ClassificationChatRequest,
             ),
         ):
             # Note: input length can be up to the entire model context length
@@ -915,7 +925,8 @@ class OpenAIServing:
             if token_num > self.max_model_len:
                 operations: dict[type[AnyRequest], str] = {
                     ScoreRequest: "score",
-                    ClassificationRequest: "classification",
+                    ClassificationCompletionRequest: "classification",
+                    ClassificationChatRequest: "classification",
                 }
                 operation = operations.get(type(request), "embedding generation")
                 raise ValueError(