[Chore] Remove redundant RequestPrompt (#30612)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-12-14 17:22:37 +08:00 (committed by GitHub)
parent f569c654e1
commit dcb31196da
GPG Key ID: B5690EEEBB952194
12 changed files with 188 additions and 253 deletions
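
In short: chat preprocessing used to hand back three values (conversation, request_prompts, engine_prompts); after this commit it returns only (conversation, engine_prompts), and callers read prompt text and token IDs off the engine prompt itself. A minimal sketch of the new calling convention, using a simplified stand-in for `vllm.inputs.data.TokensPrompt` rather than the real class:

from typing import TypedDict


class TokensPrompt(TypedDict, total=False):
    """Simplified stand-in for vllm.inputs.data.TokensPrompt."""

    prompt_token_ids: list[int]
    prompt: str
    cache_salt: str


def preprocess_chat_sketch(messages: list[dict]) -> tuple[list[dict], list[TokensPrompt]]:
    # Tokenization is faked here; the real method renders the chat template and tokenizes it.
    conversation = list(messages)
    engine_prompt = TokensPrompt(prompt="Test", prompt_token_ids=[1, 2, 3])
    return conversation, [engine_prompt]


conversation, engine_prompts = preprocess_chat_sketch([{"role": "user", "content": "Test"}])
assert engine_prompts[0]["prompt_token_ids"] == [1, 2, 3]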

View File

@@ -80,10 +80,9 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         return dict(engine_prompt), {}
     async def _fake_preprocess_chat(*args, **kwargs):
-        # return conversation, request_prompts, engine_prompts
+        # return conversation, engine_prompts
         return (
             [{"role": "user", "content": "Test"}],
-            [[1, 2, 3]],
             [{"prompt_token_ids": [1, 2, 3]}],
         )

View File

@@ -877,7 +877,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the first turn's input
         req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
+        input_messages, _ = serving_chat._make_request_with_harmony(req)
         verify_harmony_messages(
             input_messages,
             [
@@ -905,7 +905,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the second turn's input
         req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
+        input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
         verify_harmony_messages(
             input_messages_2,
             [
@@ -927,7 +927,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the first turn's input
         req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
-        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
+        input_messages, _ = serving_chat._make_request_with_harmony(req)
         verify_harmony_messages(
             input_messages,
             [
@@ -971,7 +971,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the second turn's input
         req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
+        input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
         verify_harmony_messages(
             input_messages_2,
             [
@@ -1008,7 +1008,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the first turn's input
         req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
-        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
+        input_messages, _ = serving_chat._make_request_with_harmony(req)
         verify_harmony_messages(
             input_messages,
             [
@@ -1052,7 +1052,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the second turn's input
         req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
+        input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
         verify_harmony_messages(
             input_messages_2,
             [
@@ -1089,7 +1089,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the first turn's input
         req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
-        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
+        input_messages, _ = serving_chat._make_request_with_harmony(req)
         verify_harmony_messages(
             input_messages,
             [
@@ -1133,7 +1133,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the second turn's input
         req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
+        input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
         verify_harmony_messages(
             input_messages_2,
             [
@@ -1183,7 +1183,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the third turn's input
         req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages_3, _, _ = serving_chat._make_request_with_harmony(req_3)
+        input_messages_3, _ = serving_chat._make_request_with_harmony(req_3)
         verify_harmony_messages(
             input_messages_3,
             [
@@ -1246,7 +1246,7 @@ class TestServingChatWithHarmony:
         # Test the Harmony messages for the fourth turn's input
         req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages_4, _, _ = serving_chat._make_request_with_harmony(req_4)
+        input_messages_4, _ = serving_chat._make_request_with_harmony(req_4)
        verify_harmony_messages(
            input_messages_4,
            [
@@ -1295,7 +1295,7 @@ class TestServingChatWithHarmony:
             },
         ]
         req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
+        input_messages, _ = serving_chat._make_request_with_harmony(req)
         verify_harmony_messages(
             input_messages,
@@ -1327,7 +1327,7 @@ class TestServingChatWithHarmony:
             },
         ]
         req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
+        input_messages, _ = serving_chat._make_request_with_harmony(req)
         verify_harmony_messages(
             input_messages,
@@ -1357,7 +1357,7 @@ class TestServingChatWithHarmony:
             },
         ]
         req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
-        input_messages, _, _ = serving_chat._make_request_with_harmony(req)
+        input_messages, _ = serving_chat._make_request_with_harmony(req)
         verify_harmony_messages(
             input_messages,

View File

@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.serving_responses import (
     extract_tool_types,
 )
 from vllm.entrypoints.tool_server import ToolServer
-from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.data import TokensPrompt
 class MockConversationContext(ConversationContext):
@@ -237,7 +237,7 @@ class TestValidateGeneratorInput:
         """Test _validate_generator_input with valid prompt length"""
         # Create an engine prompt with valid length (less than max_model_len)
         valid_prompt_token_ids = list(range(5))  # 5 tokens < 100 max_model_len
-        engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids)
+        engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids)
         # Call the method
         result = serving_responses_instance._validate_generator_input(engine_prompt)
@@ -247,7 +247,7 @@ class TestValidateGeneratorInput:
         # create an invalid engine prompt
         invalid_prompt_token_ids = list(range(200))  # 100 tokens >= 100 max_model_len
-        engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
+        engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
         # Call the method
         result = serving_responses_instance._validate_generator_input(engine_prompt)
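
The `EngineTokensPrompt` alias existed only to avoid clashing with the now-removed request-side prompt types; with those gone, the engine type is imported under its own name. For reference, the length check these tests exercise is essentially the following (a simplified sketch, not the exact vLLM implementation):

def validate_generator_input(prompt_token_ids: list[int], max_model_len: int) -> str | None:
    """Return an error message when the prompt already fills the context window."""
    if max_model_len <= len(prompt_token_ids):
        return (
            f"Prompt length {len(prompt_token_ids)} is not smaller than the "
            f"model's maximum context length of {max_model_len} tokens."
        )
    return None


assert validate_generator_input(list(range(5)), max_model_len=100) is None
assert validate_generator_input(list(range(200)), max_model_len=100) is not None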

View File

@@ -61,7 +61,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
 from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
-from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -234,11 +234,7 @@ class OpenAIServingChat(OpenAIServing):
                 )
                 if error_check_ret is not None:
                     return error_check_ret
-                (
-                    conversation,
-                    request_prompts,
-                    engine_prompts,
-                ) = await self._preprocess_chat(
+                conversation, engine_prompts = await self._preprocess_chat(
                     request,
                     tokenizer,
                     request.messages,
@@ -254,11 +250,7 @@ class OpenAIServingChat(OpenAIServing):
                 )
             else:
                 # For GPT-OSS.
-                (
-                    conversation,
-                    request_prompts,
-                    engine_prompts,
-                ) = self._make_request_with_harmony(request)
+                conversation, engine_prompts = self._make_request_with_harmony(request)
         except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(f"{e} {e.__cause__}")
@@ -278,7 +270,7 @@ class OpenAIServingChat(OpenAIServing):
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                prompt_text, _, _ = self._get_prompt_components(request_prompts[i])
+                prompt_text, _, _ = self._get_prompt_components(engine_prompt)
                 # If we are creating sub requests for multiple prompts, ensure that they
                 # have unique request ids.
                 sub_request_id = (
@@ -313,7 +305,7 @@ class OpenAIServingChat(OpenAIServing):
                 self._log_inputs(
                     sub_request_id,
-                    request_prompts[i],
+                    engine_prompt,
                     params=sampling_params,
                     lora_request=lora_request,
                 )
@@ -537,7 +529,7 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         model_name: str,
         conversation: list[ConversationMessage],
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
@@ -591,6 +583,11 @@ class OpenAIServingChat(OpenAIServing):
         try:
             if self.reasoning_parser:
+                if tokenizer is None:
+                    raise ValueError(
+                        "Tokenizer not available when `skip_tokenizer_init=True`"
+                    )
                 reasoning_parser = self.reasoning_parser(
                     tokenizer,
                     chat_template_kwargs=request.chat_template_kwargs,  # type: ignore
@@ -604,6 +601,11 @@ class OpenAIServingChat(OpenAIServing):
         # Prepare the tool parser if it's needed
         try:
             if tool_choice_auto and self.tool_parser:
+                if tokenizer is None:
+                    raise ValueError(
+                        "Tokenizer not available when `skip_tokenizer_init=True`"
+                    )
                 tool_parsers: list[ToolParser | None] = [
                     self.tool_parser(tokenizer)
                 ] * num_choices
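
The `tokenizer` parameter is now optional (it is `None` when the server runs with `skip_tokenizer_init=True`), so every path that actually needs a tokenizer guards for `None` before constructing a parser. The pattern, reduced to a standalone sketch (`require_tokenizer` and `build_tool_parser` are illustrative helpers, not part of this commit):

from typing import TypeVar

T = TypeVar("T")


def require_tokenizer(tokenizer: T | None) -> T:
    """Fail fast when the server was started with skip_tokenizer_init=True."""
    if tokenizer is None:
        raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
    return tokenizer


def build_tool_parser(tool_parser_cls, tokenizer):
    # Parsers decode model output back to text, so they cannot work without a tokenizer.
    return tool_parser_cls(require_tokenizer(tokenizer))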
@@ -1317,7 +1319,7 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         model_name: str,
         conversation: list[ConversationMessage],
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         request_metadata: RequestResponseMetadata,
     ) -> ErrorResponse | ChatCompletionResponse:
         created_time = int(time.time())
@@ -1367,6 +1369,11 @@ class OpenAIServingChat(OpenAIServing):
             reasoning = None
             if self.tool_parser is not None:
+                if tokenizer is None:
+                    raise ValueError(
+                        "Tokenizer not available when `skip_tokenizer_init=True`"
+                    )
                 tool_parser = self.tool_parser(tokenizer)
                 # NOTE: We use token_ids for openai tool parser
                 tool_call_info = tool_parser.extract_tool_calls(
@@ -1409,6 +1416,11 @@ class OpenAIServingChat(OpenAIServing):
         if self.reasoning_parser:
             try:
+                if tokenizer is None:
+                    raise ValueError(
+                        "Tokenizer not available when `skip_tokenizer_init=True`"
+                    )
                 reasoning_parser = self.reasoning_parser(
                     tokenizer,
                     chat_template_kwargs=request.chat_template_kwargs,  # type: ignore
@@ -1648,7 +1660,7 @@ class OpenAIServingChat(OpenAIServing):
         self,
         logprobs: dict[int, Logprob],
         top_logprobs: int | None,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         should_return_as_token_id: bool,
     ) -> list[ChatCompletionLogProb]:
         return [
@@ -1672,7 +1684,7 @@ class OpenAIServingChat(OpenAIServing):
         self,
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[dict[int, Logprob] | None],
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         num_output_top_logprobs: int | None = None,
         return_as_token_id: bool | None = None,
     ) -> ChatCompletionLogProbs:
@@ -1690,6 +1702,11 @@ class OpenAIServingChat(OpenAIServing):
             if should_return_as_token_id:
                 token = f"token_id:{token_id}"
             else:
+                if tokenizer is None:
+                    raise ValueError(
+                        "Tokenizer not available when `skip_tokenizer_init=True`"
+                    )
                 token = tokenizer.decode(token_id)
             logprobs_content.append(
@@ -1800,10 +1817,10 @@ class OpenAIServingChat(OpenAIServing):
         # Render prompt token ids.
         prompt_token_ids = render_for_completion(messages)
-        engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+        engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
         # Add cache_salt if provided in the request
         if request.cache_salt is not None:
             engine_prompt["cache_salt"] = request.cache_salt
-        return messages, [prompt_token_ids], [engine_prompt]
+        return messages, [engine_prompt]
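
`_make_request_with_harmony` no longer returns the bare token-ID list alongside the engine prompt, since the prompt already carries it. In isolation, the construction looks roughly like this (the TypedDict is a simplified stand-in for `vllm.inputs.data.TokensPrompt`, and `render_for_completion` is faked):

from typing import TypedDict


class TokensPrompt(TypedDict, total=False):
    # Simplified stand-in; the real TypedDict has more optional fields.
    prompt_token_ids: list[int]
    cache_salt: str


def make_request_sketch(messages: list, cache_salt: str | None) -> tuple[list, list[TokensPrompt]]:
    prompt_token_ids = [101, 2023, 2003]  # stand-in for render_for_completion(messages)
    engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
    if cache_salt is not None:
        engine_prompt["cache_salt"] = cache_salt
    return messages, [engine_prompt]


_, prompts = make_request_sketch([], cache_salt="abc")
assert prompts[0]["cache_salt"] == "abc"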

View File

@@ -5,29 +5,61 @@ import json
 import sys
 import time
 import traceback
-from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
+from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from http import HTTPStatus
 from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
 import numpy as np
-import torch
 from fastapi import Request
+from openai.types.responses import (
+    ToolChoiceFunction,
+)
 from pydantic import ConfigDict, TypeAdapter
 from starlette.datastructures import Headers
-from typing_extensions import TypeIs
+import vllm.envs as envs
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import (
+    ChatCompletionMessageParam,
+    ChatTemplateContentFormatOption,
+    ConversationMessage,
+    apply_hf_chat_template,
+    apply_mistral_chat_template,
+    parse_chat_messages_futures,
+    resolve_chat_template_content_format,
+)
 from vllm.entrypoints.context import (
+    ConversationContext,
     HarmonyContext,
     ParsableContext,
     StreamingHarmonyContext,
 )
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    CompletionResponse,
+    DetokenizeRequest,
+    ErrorInfo,
+    ErrorResponse,
     FunctionCall,
+    FunctionDefinition,
     ResponseInputOutputItem,
     ResponsesRequest,
+    TokenizeChatRequest,
+    TokenizeCompletionRequest,
+    TokenizeResponse,
+    TranscriptionRequest,
+    TranscriptionResponse,
+    TranslationRequest,
 )
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.pooling.classify.protocol import (
     ClassificationChatRequest,
     ClassificationCompletionRequest,
@@ -49,58 +81,13 @@ from vllm.entrypoints.pooling.score.protocol import (
     ScoreRequest,
     ScoreResponse,
 )
-from vllm.transformers_utils.tokenizer import AnyTokenizer
-if sys.version_info >= (3, 12):
-    from typing import TypedDict
-else:
-    from typing_extensions import TypedDict
-from openai.types.responses import (
-    ToolChoiceFunction,
-)
-import vllm.envs as envs
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
-from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import (
-    ChatCompletionMessageParam,
-    ChatTemplateContentFormatOption,
-    ConversationMessage,
-    apply_hf_chat_template,
-    apply_mistral_chat_template,
-    parse_chat_messages_futures,
-    resolve_chat_template_content_format,
-)
-from vllm.entrypoints.context import ConversationContext
-from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.protocol import (
-    ChatCompletionNamedToolChoiceParam,
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    CompletionRequest,
-    CompletionResponse,
-    DetokenizeRequest,
-    ErrorInfo,
-    ErrorResponse,
-    FunctionDefinition,
-    TokenizeChatRequest,
-    TokenizeCompletionRequest,
-    TokenizeResponse,
-    TranscriptionRequest,
-    TranscriptionResponse,
-    TranslationRequest,
-)
-from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
 from vllm.entrypoints.responses_utils import (
     construct_input_messages,
 )
 from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
 from vllm.entrypoints.utils import _validate_truncation_size
-from vllm.inputs.data import PromptType
-from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.inputs.parse import (
     PromptComponents,
     get_prompt_components,
@@ -109,10 +96,7 @@ from vllm.inputs.parse import (
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
-from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
-    MultiModalDataDict,
-    MultiModalUUIDDict,
-)
+from vllm.multimodal import MultiModalDataDict
 from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -185,34 +169,6 @@ AnyResponse: TypeAlias = (
 )
-class TextTokensPrompt(TypedDict):
-    prompt: str
-    prompt_token_ids: list[int]
-class EmbedsPrompt(TypedDict):
-    prompt_embeds: torch.Tensor
-RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
-def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
-    return (
-        isinstance(prompt, dict)
-        and "prompt_token_ids" in prompt
-        and "prompt_embeds" not in prompt
-    )
-def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
-    return (
-        isinstance(prompt, dict)
-        and "prompt_token_ids" not in prompt
-        and "prompt_embeds" in prompt
-    )
 RequestT = TypeVar("RequestT", bound=AnyRequest)
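
This is the core removal: `RequestPrompt` and its two TypedDicts duplicated information already representable by the engine-side prompt types in `vllm.inputs.data`, which are TypedDicts too, so callers can branch on key membership directly instead of going through the `TypeIs` helpers. A rough sketch of the replacement idiom (the TypedDicts below are simplified stand-ins, not vLLM's definitions):

from typing import TypedDict


class TokensPrompt(TypedDict, total=False):
    # Simplified stand-in for vllm.inputs.data.TokensPrompt.
    prompt_token_ids: list[int]
    prompt: str


class EmbedsPrompt(TypedDict, total=False):
    # Simplified stand-in for vllm.inputs.data.EmbedsPrompt (the real field is a torch.Tensor).
    prompt_embeds: object


def describe(prompt: TokensPrompt | EmbedsPrompt) -> str:
    # Key-membership checks replace is_text_tokens_prompt / is_embeds_prompt.
    if "prompt_token_ids" in prompt:
        return f"{len(prompt['prompt_token_ids'])} tokens"
    if "prompt_embeds" in prompt:
        return "embedding prompt"
    return "empty prompt"


assert describe(TokensPrompt(prompt_token_ids=[1, 2, 3])) == "3 tokens"
assert describe(EmbedsPrompt(prompt_embeds=object())) == "embedding prompt"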
@@ -223,8 +179,7 @@ class RequestProcessingMixin:
     handling prompt preparation and engine input.
     """
-    request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
-    engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
+    engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
 @dataclass(kw_only=True)
@@ -425,7 +380,7 @@ class OpenAIServing:
             prompts_batch, lora_req_batch = zip(
                 *[
                     (
-                        EngineTokensPrompt(
+                        TokensPrompt(
                             prompt_token_ids=beam.tokens,
                             multi_modal_data=beam.multi_modal_data,
                             mm_processor_kwargs=beam.mm_processor_kwargs,
@@ -947,7 +902,7 @@ class OpenAIServing:
         prompt: str,
         tokenizer: TokenizerLike,
         add_special_tokens: bool,
-    ) -> TextTokensPrompt:
+    ) -> TokensPrompt:
         async_tokenizer = self._get_async_tokenizer(tokenizer)
         if (
@@ -988,7 +943,7 @@ class OpenAIServing:
         request: AnyRequest,
         prompt_ids: list[int],
         tokenizer: TokenizerLike | None,
-    ) -> TextTokensPrompt:
+    ) -> TokensPrompt:
         truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
         if truncate_prompt_tokens is None:
@@ -1011,7 +966,7 @@ class OpenAIServing:
         request: AnyRequest,
         input_ids: list[int],
         input_text: str,
-    ) -> TextTokensPrompt:
+    ) -> TokensPrompt:
         token_num = len(input_ids)
         # Note: EmbeddingRequest, ClassificationRequest,
@@ -1042,7 +997,7 @@ class OpenAIServing:
                     f"{token_num} tokens in the input for {operation}. "
                     f"Please reduce the length of the input."
                 )
-            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
+            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
         # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
         # and does not require model context length validation
@@ -1050,7 +1005,7 @@ class OpenAIServing:
             request,
             (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
         ):
-            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
+            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
         # chat completion endpoint supports max_completion_tokens
         if isinstance(request, ChatCompletionRequest):
@@ -1078,7 +1033,7 @@ class OpenAIServing:
                 f" - {token_num})."
             )
-        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
+        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
     async def _tokenize_prompt_input_async(
         self,
@@ -1086,7 +1041,7 @@ class OpenAIServing:
         tokenizer: TokenizerLike,
         prompt_input: str | list[int],
         add_special_tokens: bool = True,
-    ) -> TextTokensPrompt:
+    ) -> TokensPrompt:
         """
         A simpler implementation that tokenizes a single prompt input.
         """
@@ -1105,7 +1060,7 @@ class OpenAIServing:
         tokenizer: TokenizerLike,
         prompt_inputs: Iterable[str | list[int]],
         add_special_tokens: bool = True,
-    ) -> AsyncGenerator[TextTokensPrompt, None]:
+    ) -> AsyncGenerator[TokensPrompt, None]:
         """
         A simpler implementation that tokenizes multiple prompt inputs.
         """
@@ -1158,11 +1113,7 @@ class OpenAIServing:
         chat_template_kwargs: dict[str, Any] | None = None,
         tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
         add_special_tokens: bool = False,
-    ) -> tuple[
-        list[ConversationMessage],
-        Sequence[RequestPrompt],
-        list[EngineTokensPrompt],
-    ]:
+    ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
         model_config = self.model_config
         resolved_content_format = resolve_chat_template_content_format(
@@ -1235,9 +1186,7 @@ class OpenAIServing:
                     "Prompt has to be a string",
                    "when the tokenizer is not initialised",
                 )
-                prompt_inputs = TextTokensPrompt(
-                    prompt=request_prompt, prompt_token_ids=[1]
-                )
+                prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
             elif isinstance(request_prompt, str):
                 prompt_inputs = await self._tokenize_prompt_input_async(
                     request,
@@ -1250,14 +1199,15 @@ class OpenAIServing:
                 assert is_list_of(request_prompt, int), (
                     "Prompt has to be either a string or a list of token ids"
                 )
-                prompt_inputs = TextTokensPrompt(
+                prompt_inputs = TokensPrompt(
                     prompt=tokenizer.decode(request_prompt),
                     prompt_token_ids=request_prompt,
                 )
-            engine_prompt = EngineTokensPrompt(
-                prompt_token_ids=prompt_inputs["prompt_token_ids"]
-            )
+            engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
+            if "prompt" in prompt_inputs:
+                engine_prompt["prompt"] = prompt_inputs["prompt"]
             if mm_data is not None:
                 engine_prompt["multi_modal_data"] = mm_data
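
Because the separate request-side prompt is gone, the decoded text (when there is one) now rides along on the engine prompt itself, and downstream logging reads it from there. A small sketch of that propagation, again with a simplified TokensPrompt stand-in:

from typing import TypedDict


class TokensPrompt(TypedDict, total=False):
    # Simplified stand-in for vllm.inputs.data.TokensPrompt.
    prompt_token_ids: list[int]
    prompt: str
    multi_modal_data: dict


def to_engine_prompt(prompt_inputs: TokensPrompt, mm_data: dict | None) -> TokensPrompt:
    engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
    if "prompt" in prompt_inputs:
        # Keep the human-readable text so request logging still has it.
        engine_prompt["prompt"] = prompt_inputs["prompt"]
    if mm_data is not None:
        engine_prompt["multi_modal_data"] = mm_data
    return engine_prompt


ep = to_engine_prompt(TokensPrompt(prompt="Hi", prompt_token_ids=[1, 2]), mm_data=None)
assert ep.get("prompt") == "Hi"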
@@ -1270,7 +1220,7 @@ class OpenAIServing:
             if hasattr(request, "cache_salt") and request.cache_salt is not None:
                 engine_prompt["cache_salt"] = request.cache_salt
-        return conversation, [request_prompt], [engine_prompt]
+        return conversation, [engine_prompt]
     async def _process_inputs(
         self,
@@ -1302,7 +1252,7 @@ class OpenAIServing:
     async def _render_next_turn(
         self,
         request: ResponsesRequest,
-        tokenizer: AnyTokenizer,
+        tokenizer: TokenizerLike | None,
         messages: list[ResponseInputOutputItem],
         tool_dicts: list[dict[str, Any]] | None,
         tool_parser,
@@ -1313,7 +1263,7 @@ class OpenAIServing:
             request_input=messages,
         )
-        _, request_prompts, engine_prompts = await self._preprocess_chat(
+        _, engine_prompts = await self._preprocess_chat(
            request,
            tokenizer,
            new_messages,
@@ -1322,20 +1272,20 @@ class OpenAIServing:
             chat_template=chat_template,
             chat_template_content_format=chat_template_content_format,
         )
-        return request_prompts, engine_prompts
+        return engine_prompts
     async def _generate_with_builtin_tools(
         self,
         request_id: str,
-        request_prompt: RequestPrompt,
-        engine_prompt: EngineTokensPrompt,
+        engine_prompt: TokensPrompt,
         sampling_params: SamplingParams,
         context: ConversationContext,
         lora_request: LoRARequest | None = None,
         priority: int = 0,
         **kwargs,
     ):
-        prompt_text, _, _ = self._get_prompt_components(request_prompt)
+        prompt_text, _, _ = self._get_prompt_components(engine_prompt)
         orig_priority = priority
         sub_request = 0
         while True:
@@ -1343,7 +1293,7 @@ class OpenAIServing:
             sub_request_id = f"{request_id}_{sub_request}"
             self._log_inputs(
                 sub_request_id,
-                request_prompt,
+                engine_prompt,
                 params=sampling_params,
                 lora_request=lora_request,
             )
@@ -1388,10 +1338,9 @@ class OpenAIServing:
             # Render the next prompt token ids.
             if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                 prompt_token_ids = context.render_for_completion()
-                engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
-                request_prompt = prompt_token_ids
+                engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
             elif isinstance(context, ParsableContext):
-                request_prompts, engine_prompts = await self._render_next_turn(
+                engine_prompts = await self._render_next_turn(
                     context.request,
                     context.tokenizer,
                     context.parser.response_messages,
@@ -1401,8 +1350,7 @@ class OpenAIServing:
                     context.chat_template_content_format,
                 )
                 engine_prompt = engine_prompts[0]
-                request_prompt = request_prompts[0]
-            prompt_text, _, _ = self._get_prompt_components(request_prompt)
+            prompt_text, _, _ = self._get_prompt_components(engine_prompt)
             # Update the sampling params.
             sampling_params.max_tokens = self.max_model_len - len(
@@ -1412,19 +1360,13 @@ class OpenAIServing:
             priority = orig_priority - 1
             sub_request += 1
-    def _get_prompt_components(
-        self,
-        prompt: RequestPrompt | PromptType,
-    ) -> PromptComponents:
-        if isinstance(prompt, list):
-            return PromptComponents(token_ids=prompt)
-        return get_prompt_components(prompt)  # type: ignore[arg-type]
+    def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
+        return get_prompt_components(prompt)
     def _log_inputs(
         self,
         request_id: str,
-        inputs: RequestPrompt | PromptType,
+        inputs: PromptType,
         params: SamplingParams | PoolingParams | BeamSearchParams | None,
         lora_request: LoRARequest | None,
     ) -> None:
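
With bare token-ID lists no longer flowing through the serving layer, `_get_prompt_components` drops its `isinstance(prompt, list)` special case and simply forwards to `get_prompt_components`. The idea as a standalone sketch; the `PromptComponents` shape below is assumed to be a (text, token_ids, embeds) triple, mirroring how callers unpack it above:

from typing import NamedTuple


class PromptComponents(NamedTuple):
    # Assumed, simplified shape: callers unpack it as (prompt_text, token_ids, embeds).
    text: str | None = None
    token_ids: list[int] | None = None
    embeds: object | None = None


def get_prompt_components_sketch(prompt: dict) -> PromptComponents:
    # Engine prompts are TypedDicts, so the fields can be read off directly.
    return PromptComponents(
        text=prompt.get("prompt"),
        token_ids=prompt.get("prompt_token_ids"),
        embeds=prompt.get("prompt_embeds"),
    )


prompt_text, _, _ = get_prompt_components_sketch({"prompt": "Hi", "prompt_token_ids": [1, 2]})
assert prompt_text == "Hi"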
@@ -1486,7 +1428,7 @@ class OpenAIServing:
     @staticmethod
     def _parse_tool_calls_from_content(
         request: ResponsesRequest | ChatCompletionRequest,
-        tokenizer: TokenizerLike,
+        tokenizer: TokenizerLike | None,
         enable_auto_tools: bool,
         tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
         content: str | None = None,
@@ -1526,6 +1468,11 @@ class OpenAIServing:
             and enable_auto_tools
             and (request.tool_choice == "auto" or request.tool_choice is None)
         ):
+            if tokenizer is None:
+                raise ValueError(
+                    "Tokenizer not available when `skip_tokenizer_init=True`"
+                )
             # Automatic Tool Call Parsing
             try:
                 tool_parser = tool_parser_cls(tokenizer)

View File

@@ -107,7 +107,7 @@ from vllm.entrypoints.responses_utils import (
     make_response_output_items_from_parsable_context,
 )
 from vllm.entrypoints.tool_server import ToolServer
-from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
 from vllm.logprobs import SampleLogprobs
@@ -258,7 +258,7 @@ class OpenAIServingResponses(OpenAIServing):
         self.tool_server = tool_server
     def _validate_generator_input(
-        self, engine_prompt: EngineTokensPrompt
+        self, engine_prompt: TokensPrompt
     ) -> ErrorResponse | None:
         """Add validations to the input to the generator here."""
         if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
@@ -353,11 +353,11 @@ class OpenAIServingResponses(OpenAIServing):
         tokenizer = await self.engine_client.get_tokenizer()
         if self.use_harmony:
-            messages, request_prompts, engine_prompts = (
-                self._make_request_with_harmony(request, prev_response)
+            messages, engine_prompts = self._make_request_with_harmony(
+                request, prev_response
             )
         else:
-            messages, request_prompts, engine_prompts = await self._make_request(
+            messages, engine_prompts = await self._make_request(
                 request, prev_response, tokenizer
             )
@@ -393,7 +393,7 @@ class OpenAIServingResponses(OpenAIServing):
         assert len(builtin_tool_list) == 0
         available_tools = []
         try:
-            for i, engine_prompt in enumerate(engine_prompts):
+            for engine_prompt in engine_prompts:
                 maybe_error = self._validate_generator_input(engine_prompt)
                 if maybe_error is not None:
                     return maybe_error
@@ -449,7 +449,6 @@ class OpenAIServingResponses(OpenAIServing):
                 )
                 generator = self._generate_with_builtin_tools(
                     request_id=request.request_id,
-                    request_prompt=request_prompts[i],
                     engine_prompt=engine_prompt,
                     sampling_params=sampling_params,
                     context=context,
@@ -564,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
             prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
             prev_response_output=prev_response.output if prev_response else None,
         )
-        _, request_prompts, engine_prompts = await self._preprocess_chat(
+        _, engine_prompts = await self._preprocess_chat(
             request,
             tokenizer,
             messages,
@@ -573,7 +572,7 @@ class OpenAIServingResponses(OpenAIServing):
             chat_template=self.chat_template,
             chat_template_content_format=self.chat_template_content_format,
         )
-        return messages, request_prompts, engine_prompts
+        return messages, engine_prompts
     def _make_request_with_harmony(
         self,
@@ -586,13 +585,13 @@ class OpenAIServingResponses(OpenAIServing):
         )
         messages = self._construct_input_messages_with_harmony(request, prev_response)
         prompt_token_ids = render_for_completion(messages)
-        engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+        engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
         # Add cache_salt if provided in the request
         if request.cache_salt is not None:
             engine_prompt["cache_salt"] = request.cache_salt
-        return messages, [prompt_token_ids], [engine_prompt]
+        return messages, [engine_prompt]
     async def _initialize_tool_sessions(
         self,

View File

@@ -72,11 +72,7 @@ class ClassificationMixin(OpenAIServing):
             if ret:
                 return ret
-            (
-                _,
-                _,
-                engine_prompts,
-            ) = await self._preprocess_chat(
+            _, engine_prompts = await self._preprocess_chat(
                 cast(ChatCompletionRequest, chat_request),
                 ctx.tokenizer,
                 messages,

View File

@@ -20,7 +20,6 @@ from vllm.entrypoints.openai.serving_engine import (
     EmbeddingServeContext,
     OpenAIServing,
     ServeContext,
-    TextTokensPrompt,
 )
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.pooling.embed.protocol import (
@@ -32,7 +31,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
     EmbeddingResponseData,
 )
 from vllm.entrypoints.renderer import RenderConfig
-from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import (
     EmbeddingRequestOutput,
@@ -83,11 +82,7 @@ class EmbeddingMixin(OpenAIServing):
             renderer = self._get_renderer(tokenizer)
             if isinstance(ctx.request, EmbeddingChatRequest):
-                (
-                    _,
-                    _,
-                    ctx.engine_prompts,
-                ) = await self._preprocess_chat(
+                _, ctx.engine_prompts = await self._preprocess_chat(
                     ctx.request,
                     tokenizer,
                     ctx.request.messages,
@@ -209,14 +204,13 @@ class EmbeddingMixin(OpenAIServing):
     async def _process_chunked_request(
         self,
         ctx: EmbeddingServeContext,
-        original_prompt: TextTokensPrompt,
+        token_ids: list[int],
         pooling_params,
         trace_headers,
         prompt_idx: int,
     ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
         """Process a single prompt using chunked processing."""
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
-        token_ids = original_prompt["prompt_token_ids"]
         # Split into chunks using max_position_embeddings
         max_pos_embeddings = self._get_max_position_embeddings()
@@ -228,18 +222,12 @@ class EmbeddingMixin(OpenAIServing):
             chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
             # Create engine prompt for this chunk
-            chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)
-            # Create chunk request prompt for logging
-            chunk_text = ""
-            chunk_request_prompt = TextTokensPrompt(
-                prompt=chunk_text, prompt_token_ids=chunk_tokens
-            )
+            chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens)
             # Log the chunk
             self._log_inputs(
                 chunk_request_id,
-                chunk_request_prompt,
+                chunk_engine_prompt,
                 params=pooling_params,
                 lora_request=ctx.lora_request,
             )
@@ -263,7 +251,7 @@ class EmbeddingMixin(OpenAIServing):
         request,
         input_ids: list[int],
         input_text: str,
-    ) -> TextTokensPrompt:
+    ) -> TokensPrompt:
         """Override to support chunked processing for embedding requests."""
         token_num = len(input_ids)
@@ -328,23 +316,15 @@ class EmbeddingMixin(OpenAIServing):
                 )
             )
-            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
+            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
         # For other request types, use the parent's implementation
         return super()._validate_input(request, input_ids, input_text)
-    def _is_text_tokens_prompt(self, prompt) -> bool:
-        """Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
-        return (
-            isinstance(prompt, dict)
-            and "prompt_token_ids" in prompt
-            and "prompt_embeds" not in prompt
-        )
     async def _create_single_prompt_generator(
         self,
         ctx: EmbeddingServeContext,
-        engine_prompt: EngineTokensPrompt,
+        engine_prompt: TokensPrompt,
         pooling_params: PoolingParams,
         trace_headers: Mapping[str, str] | None,
         prompt_index: int,
@@ -413,14 +393,16 @@ class EmbeddingMixin(OpenAIServing):
         for i, engine_prompt in enumerate(ctx.engine_prompts):
             # Check if this specific prompt needs chunked processing
-            if self._is_text_tokens_prompt(engine_prompt):
-                # Cast to TextTokensPrompt since we've verified
-                # prompt_token_ids
-                text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
-                if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
+            if "prompt_token_ids" in engine_prompt:
+                prompt_token_ids = engine_prompt["prompt_token_ids"]
+                if len(prompt_token_ids) > max_pos_embeddings:
                     # Use chunked processing for this prompt
                     chunk_generators = await self._process_chunked_request(
-                        ctx, text_tokens_prompt, pooling_params, trace_headers, i
+                        ctx,
+                        prompt_token_ids,
+                        pooling_params,
+                        trace_headers,
+                        i,
                     )
                     generators.extend(chunk_generators)
                     continue
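
`_process_chunked_request` now takes the raw token-ID list instead of a `TextTokensPrompt`, since the token IDs were the only field it used. The chunking itself, reduced to a sketch (the helper name and the `max_pos_embeddings` value are illustrative, not from the PR):

def chunk_token_ids(token_ids: list[int], max_pos_embeddings: int) -> list[list[int]]:
    """Split an over-long prompt into model-sized windows for pooling."""
    return [
        token_ids[start : start + max_pos_embeddings]
        for start in range(0, len(token_ids), max_pos_embeddings)
    ]


chunks = chunk_token_ids(list(range(10)), max_pos_embeddings=4)
assert [len(c) for c in chunks] == [4, 4, 2]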
@@ -578,14 +560,13 @@ class EmbeddingMixin(OpenAIServing):
             # Get original prompt token IDs for this prompt
             original_prompt = ctx.engine_prompts[prompt_idx]
-            if not self._is_text_tokens_prompt(original_prompt):
+            if "prompt_token_ids" not in original_prompt:
                 return self.create_error_response(
-                    f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
+                    f"Chunked prompt {prompt_idx} does not contain "
+                    "token IDs"
                 )
-            original_token_ids = cast(TextTokensPrompt, original_prompt)[
-                "prompt_token_ids"
-            ]
+            original_token_ids = original_prompt["prompt_token_ids"]
             pooling_request_output = PoolingRequestOutput(
                 request_id=aggregator["request_id"],

View File

@@ -137,11 +137,8 @@ class OpenAIServingPooling(OpenAIServing):
             )
             if error_check_ret is not None:
                 return error_check_ret
-            (
-                _,
-                _,
-                engine_prompts,
-            ) = await self._preprocess_chat(
+            _, engine_prompts = await self._preprocess_chat(
                 request,
                 tokenizer,
                 request.messages,

View File

@@ -12,9 +12,7 @@ import torch
 from pydantic import Field
 from vllm.config import ModelConfig
-from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
-from vllm.inputs.data import TextPrompt as EngineTextPrompt
-from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.tokenizers import TokenizerLike
 from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@@ -97,7 +95,7 @@ class BaseRenderer(ABC):
         *,
         prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
         config: RenderConfig,
-    ) -> list[EngineTokensPrompt]:
+    ) -> list[TokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -115,7 +113,7 @@ class BaseRenderer(ABC):
             (e.g., tokenization and length handling).
         Returns:
-            list[EngineTokensPrompt]: Engine-ready token prompts.
+            list[TokensPrompt]: Engine-ready token prompts.
         Raises:
             ValueError: If input formats are invalid or length limits exceeded.
@@ -129,7 +127,7 @@ class BaseRenderer(ABC):
         prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
         prompt_embeds: bytes | list[bytes] | None = None,
         config: RenderConfig,
-    ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]:
+    ) -> list[TokensPrompt | EmbedsPrompt]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
         engine-ready prompt objects using a unified RenderConfig.
@@ -146,7 +144,7 @@ class BaseRenderer(ABC):
             (e.g., tokenization and length handling).
         Returns:
-            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
+            list[Union[TokensPrompt, EmbedsPrompt]]:
                 Engine-ready prompt objects.
         Raises:
@@ -161,14 +159,14 @@ class BaseRenderer(ABC):
         prompt_embeds: bytes | list[bytes],
         truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
         cache_salt: str | None = None,
-    ) -> list[EngineEmbedsPrompt]:
+    ) -> list[EmbedsPrompt]:
         """Load and validate base64-encoded embeddings into prompt objects."""
         if not self.model_config.enable_prompt_embeds:
             raise ValueError(
                 "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
             )
-        def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
+        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
             tensor = torch.load(
                 io.BytesIO(pybase64.b64decode(embed, validate=True)),
                 weights_only=True,
@@ -185,7 +183,7 @@ class BaseRenderer(ABC):
             assert tensor.dim() == 2
             if truncate_prompt_tokens is not None:
                 tensor = tensor[-truncate_prompt_tokens:]
-            embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
+            embeds_prompt = EmbedsPrompt(prompt_embeds=tensor)
             if cache_salt is not None:
                 embeds_prompt["cache_salt"] = cache_salt
             return embeds_prompt
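
The context around this hunk shows how prompt embeddings arrive: as base64-encoded, torch-serialized 2-D tensors. A self-contained round-trip of that encoding, using the standard `base64` module instead of `pybase64` and plain `torch`; this mirrors the surrounding code rather than reproducing it:

import base64
import io

import torch

# Client side: serialize a [num_tokens, hidden_size] tensor and base64-encode it.
tensor = torch.zeros(4, 8)
buf = io.BytesIO()
torch.save(tensor, buf)
encoded: bytes = base64.b64encode(buf.getvalue())

# Server side (as in _load_and_validate_embed above): decode and load weights only.
decoded = torch.load(io.BytesIO(base64.b64decode(encoded, validate=True)), weights_only=True)
assert decoded.dim() == 2 and decoded.shape == (4, 8)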
@@ -213,7 +211,7 @@ class CompletionRenderer(BaseRenderer):
         *,
         prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
         config: RenderConfig,
-    ) -> list[EngineTokensPrompt]:
+    ) -> list[TokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
         Uses async tokenizer pooling for improved performance. See base class
@@ -240,7 +238,7 @@ class CompletionRenderer(BaseRenderer):
         prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
         prompt_embeds: bytes | list[bytes] | None = None,
         config: RenderConfig,
-    ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]:
+    ) -> list[TokensPrompt | EmbedsPrompt]:
         """
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
@@ -249,7 +247,7 @@ class CompletionRenderer(BaseRenderer):
         if truncate_prompt_tokens == 0:
             return []
-        rendered: list[EngineTokensPrompt | EngineEmbedsPrompt] = []
+        rendered: list[TokensPrompt | EmbedsPrompt] = []
         if prompt_embeds is not None:
             rendered.extend(
@@ -281,10 +279,10 @@ class CompletionRenderer(BaseRenderer):
     async def _create_prompt(
         self,
-        prompt_input: EngineTextPrompt | EngineTokensPrompt,
+        prompt_input: TextPrompt | TokensPrompt,
         config: RenderConfig,
         truncate_prompt_tokens: int | None,
-    ) -> EngineTokensPrompt:
+    ) -> TokensPrompt:
         prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
         if prompt_token_ids is not None:
@@ -317,7 +315,7 @@ class CompletionRenderer(BaseRenderer):
         truncate_prompt_tokens: int | None,
         add_special_tokens: bool,
         cache_salt: str | None,
-    ) -> EngineTokensPrompt:
+    ) -> TokensPrompt:
         """Tokenize text input asynchronously."""
         async_tokenizer = self._get_async_tokenizer()
@@ -350,7 +348,7 @@ class CompletionRenderer(BaseRenderer):
         truncate_prompt_tokens: int | None,
         cache_salt: str | None,
         needs_detokenization: bool | None = False,
-    ) -> EngineTokensPrompt:
+    ) -> TokensPrompt:
         """Optionally detokenize token IDs and build a tokens prompt."""
         token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)
@@ -392,8 +390,8 @@ class CompletionRenderer(BaseRenderer):
         max_length: int | None = None,
         cache_salt: str | None = None,
         prompt: str | None = None,
-    ) -> EngineTokensPrompt:
-        """Create validated EngineTokensPrompt."""
+    ) -> TokensPrompt:
+        """Create validated TokensPrompt."""
         if max_length is not None and len(token_ids) > max_length:
             raise ValueError(
                 f"This model's maximum context length is {max_length} tokens. "
@@ -401,7 +399,7 @@ class CompletionRenderer(BaseRenderer):
                 "Please reduce the length of the input messages."
             )
-        tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
+        tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
         if cache_salt is not None:
             tokens_prompt["cache_salt"] = cache_salt
         if prompt is not None:

View File

@@ -27,7 +27,7 @@ from vllm.entrypoints.serve.disagg.protocol import (
     GenerateResponse,
     GenerateResponseChoice,
 )
-from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
@@ -99,7 +99,7 @@ class ServingTokens(OpenAIServing):
         # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
         # completed
-        engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
+        engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids)
         if request.features is not None:
             engine_prompt["multi_modal_data"] = None
@@ -115,7 +115,7 @@ class ServingTokens(OpenAIServing):
         self._log_inputs(
             request_id,
-            request.token_ids,
+            TokensPrompt(prompt_token_ids=request.token_ids),
             params=sampling_params,
             lora_request=lora_request,
         )
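
`_log_inputs` no longer accepts a bare list of token IDs, so call sites that used to pass `request.token_ids` or `request.tokens` (here and in the tokenization endpoint below) wrap them in a `TokensPrompt` first. A minimal sketch of that convention with a stand-in logger and prompt dict:

from typing import TypedDict


class TokensPrompt(TypedDict, total=False):
    # Simplified stand-in for vllm.inputs.data.TokensPrompt.
    prompt_token_ids: list[int]
    prompt: str


def log_inputs(request_id: str, inputs: TokensPrompt) -> None:
    # Hypothetical logger: with a single prompt shape, no per-type branching is needed.
    text = inputs.get("prompt")
    ids = inputs.get("prompt_token_ids")
    print(f"[{request_id}] prompt={text!r} token_ids={ids}")


log_inputs("req-0", TokensPrompt(prompt_token_ids=[5, 6, 7]))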

View File

@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.renderer import RenderConfig
+from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.tokenizers import TokenizerLike
@@ -80,11 +81,8 @@ class OpenAIServingTokenization(OpenAIServing):
         )
         if error_check_ret is not None:
             return error_check_ret
-        (
-            _,
-            _,
-            engine_prompts,
-        ) = await self._preprocess_chat(
+        _, engine_prompts = await self._preprocess_chat(
             request,
             tokenizer,
             request.messages,
@@ -141,7 +139,10 @@ class OpenAIServingTokenization(OpenAIServing):
         tokenizer = await self.engine_client.get_tokenizer()
         self._log_inputs(
-            request_id, request.tokens, params=None, lora_request=lora_request
+            request_id,
+            TokensPrompt(prompt_token_ids=request.tokens),
+            params=None,
+            lora_request=lora_request,
         )
         prompt_input = await self._tokenize_prompt_input_async(