diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 862f383e4ecb2..705a72f657a2d 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -37,6 +37,7 @@ from vllm.entrypoints.utils import (_validate_truncation_size,
                                     log_non_default_args)
 from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                          TokensPrompt)
+from vllm.inputs.parse import get_prompt_components
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -49,10 +50,13 @@ from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                   SamplingParams)
 from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
-                                               get_cached_tokenizer)
+                                               get_cached_tokenizer,
+                                               init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, Device, as_iter, is_list_of
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.llm_engine import LLMEngine
+from vllm.v1.engine.processor import Processor
 from vllm.v1.sample.logits_processor import LogitsProcessor
 
 if TYPE_CHECKING:
@@ -312,6 +316,10 @@ class LLM:
         self.io_processor = get_io_processor(self.llm_engine.vllm_config,
                                              io_processor_plugin)
 
+    @property
+    def model_config(self):
+        return self.llm_engine.model_config
+
     def get_tokenizer(self) -> AnyTokenizer:
         return self.llm_engine.get_tokenizer()
 
@@ -324,6 +332,16 @@ class LLM:
         else:
             self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
 
+    def _get_processor(self) -> Processor:
+        if not hasattr(self, "_processor"):
+            vllm_config = self.llm_engine.vllm_config
+            if self.model_config.skip_tokenizer_init:
+                tokenizer = None
+            else:
+                tokenizer = init_tokenizer_from_configs(self.model_config)
+            self._processor = Processor(vllm_config, tokenizer)
+        return self._processor
+
     def get_default_sampling_params(self) -> SamplingParams:
         if self.default_sampling_params is None:
             self.default_sampling_params = (
@@ -1497,8 +1515,6 @@ class LLM:
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")
 
-        model_config = self.llm_engine.model_config
-
         for i, prompt in enumerate(it):
 
             if isinstance(prompt, dict):
@@ -1506,17 +1522,9 @@ class LLM:
                     prompt.get("multi_modal_data"),
                     prompt.get("multi_modal_uuids"))
 
-            param = params[i] if isinstance(params, Sequence) else params
-
-            tokenization_kwargs: dict[str, Any] = {}
-            _validate_truncation_size(model_config.max_model_len,
-                                      param.truncate_prompt_tokens,
-                                      tokenization_kwargs)
-
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
-                tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
                 priority=priority[i] if priority else 0,
@@ -1557,23 +1565,59 @@ class LLM:
                     raise ValueError(f"Multi-modal data for {modality} is None"
                                      f" but UUID is not provided")
 
-    def _add_request(
+    def _process_inputs(
         self,
-        prompt: PromptType,
+        request_id: str,
+        engine_prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        lora_request: Optional[LoRARequest] = None,
-        priority: int = 0,
-    ) -> None:
-        request_id = str(next(self.request_counter))
-        self.llm_engine.add_request(
+        *,
+        lora_request: Optional[LoRARequest],
+        priority: int,
+    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
+        """Use the Processor to process inputs for LLMEngine."""
+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.model_config.max_model_len,
+                                  params.truncate_prompt_tokens,
+                                  tokenization_kwargs)
+
+        processor = self._get_processor()
+        engine_request = processor.process_inputs(
             request_id,
-            prompt,
+            engine_prompt,
             params,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
+        return engine_request, tokenization_kwargs
+
+    def _add_request(
+        self,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[LoRARequest] = None,
+        priority: int = 0,
+    ) -> None:
+        prompt_text, _, _ = get_prompt_components(prompt)
+        request_id = str(next(self.request_counter))
+
+        engine_request, tokenization_kwargs = self._process_inputs(
+            request_id,
+            prompt,
+            params,
+            lora_request=lora_request,
+            priority=priority,
+        )
+
+        self.llm_engine.add_request(
+            request_id,
+            engine_request,
+            params,
+            lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
+            priority=priority,
+            prompt_text=prompt_text,
+        )
 
     def _run_engine(
         self,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index dc41723800d0d..0e5279baed29f 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -7,8 +7,7 @@ import traceback
 from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from http import HTTPStatus
-from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
-                    TypeVar, Union)
+from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
 from fastapi import Request
@@ -69,6 +68,7 @@ from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
 # yapf: enable
 from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.parse import PromptComponents, get_prompt_components
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
@@ -140,12 +140,6 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
             and "prompt_embeds" in prompt)
 
 
-class PromptComponents(NamedTuple):
-    text: Optional[str] = None
-    token_ids: Optional[list[int]] = None
-    embeds: Optional[torch.Tensor] = None
-
-
 RequestT = TypeVar("RequestT", bound=AnyRequest)
 
 
@@ -876,25 +870,23 @@ class OpenAIServing:
         self,
         request_id: str,
         engine_prompt: PromptType,
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
         *,
         lora_request: Optional[LoRARequest],
         trace_headers: Optional[Mapping[str, str]],
         priority: int,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
-        """
-        using the Processor to process inputs for AsyncLLM
-        """
+        """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.max_model_len,
-                                  sampling_params.truncate_prompt_tokens,
+                                  params.truncate_prompt_tokens,
                                   tokenization_kwargs)
 
         processor = await self._get_processor()
         engine_request = processor.process_inputs(
             request_id,
             engine_prompt,
-            sampling_params,
+            params,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
@@ -973,25 +965,12 @@ class OpenAIServing:
 
     def _get_prompt_components(
         self,
-        inputs: Union[RequestPrompt, PromptType],
+        prompt: Union[RequestPrompt, PromptType],
     ) -> PromptComponents:
-        if isinstance(inputs, str):
-            return PromptComponents(text=inputs)
-        if isinstance(inputs, list):
-            return PromptComponents(token_ids=inputs)
-        if isinstance(inputs, dict):
-            return PromptComponents(
-                text=inputs.get("prompt"),  # type: ignore[arg-type]
-                token_ids=inputs.get(
-                    "prompt_token_ids"),  # type: ignore[arg-type]
-                embeds=inputs.get("prompt_embeds"),
-            )
+        if isinstance(prompt, list):
+            return PromptComponents(token_ids=prompt)
 
-        return PromptComponents(
-            text=getattr(inputs, "prompt", None),
-            token_ids=getattr(inputs, "prompt_token_ids", None),
-            embeds=getattr(inputs, "prompt_embeds", None),
-        )
+        return get_prompt_components(prompt)  # type: ignore[arg-type]
 
     def _log_inputs(
         self,
diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py
index 8c3700799e4ab..123c811731208 100644
--- a/vllm/inputs/parse.py
+++ b/vllm/inputs/parse.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
-from typing import Literal, Optional, TypedDict, Union, cast, overload
+from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
+                    Union, cast, overload)
 
 from typing_extensions import TypeIs
 
@@ -11,6 +12,9 @@ from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
                    PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
                    TokensPrompt)
 
+if TYPE_CHECKING:
+    import torch
+
 
 class ParsedText(TypedDict):
     content: str
@@ -149,3 +153,23 @@ def split_enc_dec_inputs(
         )
 
     return None, inputs
+
+
+class PromptComponents(NamedTuple):
+    text: Optional[str] = None
+    token_ids: Optional[list[int]] = None
+    embeds: Optional["torch.Tensor"] = None
+
+
+def get_prompt_components(prompt: PromptType) -> PromptComponents:
+    if isinstance(prompt, str):
+        return PromptComponents(text=prompt)
+
+    if (encoder_prompt := prompt.get("encoder_prompt")):
+        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]
+
+    return PromptComponents(
+        text=prompt.get("prompt"),  # type: ignore[arg-type]
+        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
+        embeds=prompt.get("prompt_embeds"),
+    )
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index f81427161d7d7..3734c208004a5 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -27,6 +27,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                                init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -213,13 +214,14 @@ class LLMEngine:
     def add_request(
         self,
         request_id: str,
-        prompt: PromptType,
+        prompt: Union[EngineCoreRequest, PromptType],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        prompt_text: Optional[str] = None,
     ) -> None:
         # Validate the request_id type.
         if not isinstance(request_id, str):
@@ -227,12 +229,18 @@ class LLMEngine:
                 f"request_id must be a string, got {type(request_id)}")
 
         # Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                tokenization_kwargs,
-                                                trace_headers, priority)
-        prompt_text = prompt if isinstance(prompt,
-                                           str) else prompt.get("prompt")
+        if isinstance(prompt, EngineCoreRequest):
+            request = prompt
+        else:
+            assert prompt_text is None
+            logger.warning_once("Processor has been moved under LLM and will "
+                                "be removed from LLMEngine in v0.13.")
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    tokenization_kwargs,
+                                                    trace_headers, priority)
+            prompt_text = (prompt if isinstance(prompt, str) else
+                           prompt.get("prompt"))
 
         n = params.n if isinstance(params, SamplingParams) else 1
 