[Renderer] Move Processor out of LLMEngine (#26165)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-03 23:08:22 +08:00 committed by GitHub
parent 73a99cc2a5
commit d78fda7cda
4 changed files with 114 additions and 59 deletions

vllm/entrypoints/llm.py

@@ -37,6 +37,7 @@ from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -49,10 +50,13 @@ from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams)
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
get_cached_tokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, as_iter, is_list_of
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.engine.processor import Processor
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
@@ -312,6 +316,10 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
@property
def model_config(self):
return self.llm_engine.model_config
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer()
@@ -324,6 +332,16 @@
else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def _get_processor(self) -> Processor:
if not hasattr(self, "_processor"):
vllm_config = self.llm_engine.vllm_config
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
self._processor = Processor(vllm_config, tokenizer)
return self._processor
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
self.default_sampling_params = (
@@ -1497,8 +1515,6 @@ class LLM:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it):
if isinstance(prompt, dict):
@@ -1506,17 +1522,9 @@
prompt.get("multi_modal_data"),
prompt.get("multi_modal_uuids"))
param = params[i] if isinstance(params, Sequence) else params
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
priority=priority[i] if priority else 0,
@@ -1557,23 +1565,59 @@
raise ValueError(f"Multi-modal data for {modality} is None"
f" but UUID is not provided")
def _add_request(
def _process_inputs(
self,
prompt: PromptType,
request_id: str,
engine_prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
priority: int = 0,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
*,
lora_request: Optional[LoRARequest],
priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for LLMEngine."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs)
processor = self._get_processor()
engine_request = processor.process_inputs(
request_id,
prompt,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return engine_request, tokenization_kwargs
def _add_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
priority: int = 0,
) -> None:
prompt_text, _, _ = get_prompt_components(prompt)
request_id = str(next(self.request_counter))
engine_request, tokenization_kwargs = self._process_inputs(
request_id,
prompt,
params,
lora_request=lora_request,
priority=priority,
)
self.llm_engine.add_request(
request_id,
engine_request,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
prompt_text=prompt_text,
)
def _run_engine(
self,
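The net effect of the changes above is that LLM now owns prompt processing: _add_request extracts the prompt text, runs the lazily-built Processor through _process_inputs, and hands LLMEngine a ready-made EngineCoreRequest. A condensed restatement of that flow with explanatory comments (simplified from the diff above, not the verbatim code):

    # Inside LLM._add_request, per the hunks above:
    prompt_text, _, _ = get_prompt_components(prompt)   # raw text, forwarded alongside the request
    request_id = str(next(self.request_counter))
    # The Processor (created once by _get_processor) validates truncation
    # and tokenizes here, outside LLMEngine.
    engine_request, tokenization_kwargs = self._process_inputs(
        request_id, prompt, params,
        lora_request=lora_request, priority=priority)
    # LLMEngine now receives the pre-processed EngineCoreRequest instead of a raw prompt.
    self.llm_engine.add_request(
        request_id, engine_request, params,
        lora_request=lora_request,
        tokenization_kwargs=tokenization_kwargs,
        priority=priority, prompt_text=prompt_text)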

vllm/entrypoints/openai/serving_engine.py

@@ -7,8 +7,7 @@ import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
TypeVar, Union)
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
import torch
from fastapi import Request
@@ -69,6 +68,7 @@ from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
# yapf: enable
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import PromptComponents, get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
@@ -140,12 +140,6 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
and "prompt_embeds" in prompt)
class PromptComponents(NamedTuple):
text: Optional[str] = None
token_ids: Optional[list[int]] = None
embeds: Optional[torch.Tensor] = None
RequestT = TypeVar("RequestT", bound=AnyRequest)
@@ -876,25 +870,23 @@ class OpenAIServing:
self,
request_id: str,
engine_prompt: PromptType,
sampling_params: SamplingParams,
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]],
priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""
using the Processor to process inputs for AsyncLLM
"""
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len,
sampling_params.truncate_prompt_tokens,
params.truncate_prompt_tokens,
tokenization_kwargs)
processor = await self._get_processor()
engine_request = processor.process_inputs(
request_id,
engine_prompt,
sampling_params,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
@@ -973,25 +965,12 @@
def _get_prompt_components(
self,
inputs: Union[RequestPrompt, PromptType],
prompt: Union[RequestPrompt, PromptType],
) -> PromptComponents:
if isinstance(inputs, str):
return PromptComponents(text=inputs)
if isinstance(inputs, list):
return PromptComponents(token_ids=inputs)
if isinstance(inputs, dict):
return PromptComponents(
text=inputs.get("prompt"), # type: ignore[arg-type]
token_ids=inputs.get(
"prompt_token_ids"), # type: ignore[arg-type]
embeds=inputs.get("prompt_embeds"),
)
if isinstance(prompt, list):
return PromptComponents(token_ids=prompt)
return PromptComponents(
text=getattr(inputs, "prompt", None),
token_ids=getattr(inputs, "prompt_token_ids", None),
embeds=getattr(inputs, "prompt_embeds", None),
)
return get_prompt_components(prompt) # type: ignore[arg-type]
def _log_inputs(
self,
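With the sampling_params -> params rename, the server-side _process_inputs now accepts SamplingParams or PoolingParams alike, mirroring LLM._process_inputs above; the local prompt-component parsing moves to the shared helper in vllm.inputs.parse. A hedged sketch of a call site inside an async OpenAIServing handler (the enclosing handler and its request_id/engine_prompt locals are assumed, not part of this diff):

    from vllm.pooling_params import PoolingParams

    # Sketch only: pooling requests now take the same processing path as sampling.
    engine_request, tokenization_kwargs = await self._process_inputs(
        request_id,
        engine_prompt,
        PoolingParams(),
        lora_request=None,
        trace_headers=None,
        priority=0,
    )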

vllm/inputs/parse.py

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
Union, cast, overload)
from typing_extensions import TypeIs
@@ -11,6 +12,9 @@ from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
TokensPrompt)
if TYPE_CHECKING:
import torch
class ParsedText(TypedDict):
content: str
@@ -149,3 +153,23 @@ def split_enc_dec_inputs(
)
return None, inputs
class PromptComponents(NamedTuple):
text: Optional[str] = None
token_ids: Optional[list[int]] = None
embeds: Optional["torch.Tensor"] = None
def get_prompt_components(prompt: PromptType) -> PromptComponents:
if isinstance(prompt, str):
return PromptComponents(text=prompt)
if (encoder_prompt := prompt.get("encoder_prompt")):
return get_prompt_components(encoder_prompt) # type: ignore[arg-type]
return PromptComponents(
text=prompt.get("prompt"), # type: ignore[arg-type]
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=prompt.get("prompt_embeds"),
)
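The new helper gives both entrypoints a single way to pull the text, token IDs, or embeddings out of any prompt shape. An illustrative sketch of its behavior as defined above (the example values are made up):

    from vllm.inputs.parse import get_prompt_components

    get_prompt_components("Hello")
    #   -> PromptComponents(text="Hello", token_ids=None, embeds=None)
    get_prompt_components({"prompt_token_ids": [1, 2, 3]})
    #   -> PromptComponents(text=None, token_ids=[1, 2, 3], embeds=None)
    # Explicit encoder-decoder prompts recurse into their encoder half:
    get_prompt_components({"encoder_prompt": "Hello", "decoder_prompt": ""})
    #   -> PromptComponents(text="Hello", token_ids=None, embeds=None)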

vllm/v1/engine/llm_engine.py

@@ -27,6 +27,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -213,13 +214,14 @@ class LLMEngine:
def add_request(
self,
request_id: str,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
prompt_text: Optional[str] = None,
) -> None:
# Validate the request_id type.
if not isinstance(request_id, str):
@@ -227,12 +229,18 @@
f"request_id must be a string, got {type(request_id)}")
# Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority)
prompt_text = prompt if isinstance(prompt,
str) else prompt.get("prompt")
if isinstance(prompt, EngineCoreRequest):
request = prompt
else:
assert prompt_text is None
logger.warning_once("Processor has been moved under LLM and will "
"be removed from LLMEngine in v0.13.")
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority)
prompt_text = (prompt if isinstance(prompt, str) else
prompt.get("prompt"))
n = params.n if isinstance(params, SamplingParams) else 1
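On the engine side, add_request now accepts either a pre-processed EngineCoreRequest (the path the updated LLM class takes, passing prompt_text explicitly) or a raw PromptType, which is still processed internally but is scheduled for removal in v0.13. A hedged sketch of the two call shapes, assuming an existing LLMEngine instance engine, an already-processed engine_core_request, and a SamplingParams object sampling_params:

    # New path: the caller already ran the Processor and supplies the prompt text.
    engine.add_request("req-0", engine_core_request, sampling_params,
                       prompt_text="Hello")
    # Deprecated path: the raw prompt is processed inside LLMEngine (warns once);
    # prompt_text must be left as None, since the engine derives it itself.
    engine.add_request("req-1", {"prompt": "Hello"}, sampling_params)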