[Renderer] Move Processor out of LLMEngine (#26165)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-03 23:08:22 +08:00 committed by GitHub
parent 73a99cc2a5
commit d78fda7cda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 114 additions and 59 deletions

View File

@ -37,6 +37,7 @@ from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args) log_non_default_args)
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt, from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt) TokensPrompt)
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
@ -49,10 +50,13 @@ from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams) SamplingParams)
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, as_iter, is_list_of from vllm.utils import Counter, Device, as_iter, is_list_of
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.engine.processor import Processor
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING: if TYPE_CHECKING:
@ -312,6 +316,10 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config, self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin) io_processor_plugin)
@property
def model_config(self):
    """Expose the ``model_config`` held by the underlying ``llm_engine``."""
    engine = self.llm_engine
    return engine.model_config
def get_tokenizer(self) -> AnyTokenizer:
    """Return whatever the underlying engine's ``get_tokenizer()`` reports."""
    engine = self.llm_engine
    return engine.get_tokenizer()
@ -324,6 +332,16 @@ class LLM:
else: else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def _get_processor(self) -> Processor:
    """Return this instance's ``Processor``, creating it on first use.

    The processor is built from the engine's ``vllm_config`` and stored on
    ``self._processor``; subsequent calls hand back the stored object.
    When ``model_config.skip_tokenizer_init`` is set, the processor is
    constructed with ``None`` in place of a tokenizer.
    """
    # Fast path: a processor was already built by an earlier call.
    if hasattr(self, "_processor"):
        return self._processor

    engine_config = self.llm_engine.vllm_config
    tok = (None if self.model_config.skip_tokenizer_init else
           init_tokenizer_from_configs(self.model_config))

    self._processor = Processor(engine_config, tok)
    return self._processor
def get_default_sampling_params(self) -> SamplingParams: def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None: if self.default_sampling_params is None:
self.default_sampling_params = ( self.default_sampling_params = (
@ -1497,8 +1515,6 @@ class LLM:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests") it = tqdm_func(it, desc="Adding requests")
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it): for i, prompt in enumerate(it):
if isinstance(prompt, dict): if isinstance(prompt, dict):
@ -1506,17 +1522,9 @@ class LLM:
prompt.get("multi_modal_data"), prompt.get("multi_modal_data"),
prompt.get("multi_modal_uuids")) prompt.get("multi_modal_uuids"))
param = params[i] if isinstance(params, Sequence) else params
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
self._add_request( self._add_request(
prompt, prompt,
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
priority=priority[i] if priority else 0, priority=priority[i] if priority else 0,
@ -1557,23 +1565,59 @@ class LLM:
raise ValueError(f"Multi-modal data for {modality} is None" raise ValueError(f"Multi-modal data for {modality} is None"
f" but UUID is not provided") f" but UUID is not provided")
def _process_inputs(
    self,
    request_id: str,
    engine_prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    *,
    lora_request: Optional[LoRARequest],
    priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
    """Turn a prompt into an engine request using the ``Processor``.

    Returns:
        A pair ``(engine_request, tokenization_kwargs)``: the object
        produced by ``Processor.process_inputs`` and the keyword dict that
        was handed to ``_validate_truncation_size`` and then to the
        processor.
    """
    # Validation runs before the processor is fetched, so a rejected
    # ``truncate_prompt_tokens`` never triggers processor construction.
    tok_kwargs: dict[str, Any] = {}
    _validate_truncation_size(
        self.model_config.max_model_len,
        params.truncate_prompt_tokens,
        tok_kwargs,
    )

    core_request = self._get_processor().process_inputs(
        request_id,
        engine_prompt,
        params,
        lora_request=lora_request,
        tokenization_kwargs=tok_kwargs,
        priority=priority,
    )
    return engine_request_pair(core_request, tok_kwargs) if False else (
        core_request, tok_kwargs)
def _add_request(
    self,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    lora_request: Optional[LoRARequest] = None,
    priority: int = 0,
) -> None:
    """Process one prompt and submit the result to the engine.

    The prompt is converted by ``_process_inputs`` first; the resulting
    request is then passed to ``llm_engine.add_request`` together with
    the prompt's text component.
    """
    # Only the text component is needed here; token IDs and embeddings
    # from the parsed prompt are ignored.
    prompt_text = get_prompt_components(prompt).text

    # Request IDs come from the instance-wide counter.
    request_id = str(next(self.request_counter))

    processed = self._process_inputs(
        request_id,
        prompt,
        params,
        lora_request=lora_request,
        priority=priority,
    )
    engine_request, tokenization_kwargs = processed

    self.llm_engine.add_request(
        request_id,
        engine_request,
        params,
        lora_request=lora_request,
        tokenization_kwargs=tokenization_kwargs,
        priority=priority,
        prompt_text=prompt_text,
    )
def _run_engine( def _run_engine(
self, self,

View File

@ -7,8 +7,7 @@ import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus from http import HTTPStatus
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional, from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
TypeVar, Union)
import torch import torch
from fastapi import Request from fastapi import Request
@ -69,6 +68,7 @@ from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
# yapf: enable # yapf: enable
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import PromptComponents, get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -140,12 +140,6 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
and "prompt_embeds" in prompt) and "prompt_embeds" in prompt)
class PromptComponents(NamedTuple):
    """The parts a prompt may carry; every field defaults to ``None``."""

    # Prompt text, when present.
    text: Optional[str] = None
    # Token IDs of the prompt, when present.
    token_ids: Optional[list[int]] = None
    # Prompt embeddings tensor, when present.
    embeds: Optional[torch.Tensor] = None
RequestT = TypeVar("RequestT", bound=AnyRequest) RequestT = TypeVar("RequestT", bound=AnyRequest)
@ -876,25 +870,23 @@ class OpenAIServing:
self, self,
request_id: str, request_id: str,
engine_prompt: PromptType, engine_prompt: PromptType,
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
*, *,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]], trace_headers: Optional[Mapping[str, str]],
priority: int, priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]: ) -> tuple[EngineCoreRequest, dict[str, Any]]:
""" """Use the Processor to process inputs for AsyncLLM."""
using the Processor to process inputs for AsyncLLM
"""
tokenization_kwargs: dict[str, Any] = {} tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len, _validate_truncation_size(self.max_model_len,
sampling_params.truncate_prompt_tokens, params.truncate_prompt_tokens,
tokenization_kwargs) tokenization_kwargs)
processor = await self._get_processor() processor = await self._get_processor()
engine_request = processor.process_inputs( engine_request = processor.process_inputs(
request_id, request_id,
engine_prompt, engine_prompt,
sampling_params, params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
@ -973,25 +965,12 @@ class OpenAIServing:
def _get_prompt_components(
    self,
    prompt: Union[RequestPrompt, PromptType],
) -> PromptComponents:
    """Split a request or engine prompt into its components.

    A plain list is treated as token IDs; every other prompt is delegated
    to the shared ``get_prompt_components`` helper.
    """
    if not isinstance(prompt, list):
        return get_prompt_components(prompt)  # type: ignore[arg-type]

    return PromptComponents(token_ids=prompt)
def _log_inputs( def _log_inputs(
self, self,

View File

@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, Optional, TypedDict, Union, cast, overload from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
Union, cast, overload)
from typing_extensions import TypeIs from typing_extensions import TypeIs
@ -11,6 +12,9 @@ from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, TextPrompt, PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
TokensPrompt) TokensPrompt)
if TYPE_CHECKING:
import torch
class ParsedText(TypedDict): class ParsedText(TypedDict):
content: str content: str
@ -149,3 +153,23 @@ def split_enc_dec_inputs(
) )
return None, inputs return None, inputs
class PromptComponents(NamedTuple):
    """The parts a prompt may carry; every field defaults to ``None``."""

    # Prompt text, when present.
    text: Optional[str] = None
    # Token IDs of the prompt, when present.
    token_ids: Optional[list[int]] = None
    # Prompt embeddings; quoted because ``torch`` is imported only under
    # ``TYPE_CHECKING`` in this module.
    embeds: Optional["torch.Tensor"] = None
def get_prompt_components(prompt: PromptType) -> PromptComponents:
    """Extract the text, token IDs and embeddings carried by a prompt.

    Args:
        prompt: Either a plain string or a dict-style prompt. For a dict
            that has an ``"encoder_prompt"`` entry, the components are
            taken from that encoder prompt.

    Returns:
        A ``PromptComponents`` tuple. A string prompt fills only ``text``;
        for a dict prompt each field is the value stored under
        ``"prompt"``, ``"prompt_token_ids"`` and ``"prompt_embeds"``
        respectively, or ``None`` when the key is absent.
    """
    if isinstance(prompt, str):
        return PromptComponents(text=prompt)

    # Test for presence, not truthiness: an empty-string encoder prompt is
    # still the prompt to report (text == ""), whereas a truthiness check
    # would fall through and report no text at all.
    encoder_prompt = prompt.get("encoder_prompt")
    if encoder_prompt is not None:
        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]

    return PromptComponents(
        text=prompt.get("prompt"),  # type: ignore[arg-type]
        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        embeds=prompt.get("prompt_embeds"),
    )

View File

@ -27,6 +27,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs) init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
@ -213,13 +214,14 @@ class LLMEngine:
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
prompt: PromptType, prompt: Union[EngineCoreRequest, PromptType],
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
prompt_text: Optional[str] = None,
) -> None: ) -> None:
# Validate the request_id type. # Validate the request_id type.
if not isinstance(request_id, str): if not isinstance(request_id, str):
@ -227,12 +229,18 @@ class LLMEngine:
f"request_id must be a string, got {type(request_id)}") f"request_id must be a string, got {type(request_id)}")
# Process raw inputs into the request. # Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params, if isinstance(prompt, EngineCoreRequest):
arrival_time, lora_request, request = prompt
tokenization_kwargs, else:
trace_headers, priority) assert prompt_text is None
prompt_text = prompt if isinstance(prompt, logger.warning_once("Processor has been moved under LLM and will "
str) else prompt.get("prompt") "be removed from LLMEngine in v0.13.")
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority)
prompt_text = (prompt if isinstance(prompt, str) else
prompt.get("prompt"))
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1