From 247181536fc2cab728077f3e7489622e19671d2d Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Fri, 28 Mar 2025 01:36:32 +0800
Subject: [PATCH] [Misc] Replace `is_encoder_decoder_inputs` with
 `split_enc_dec_inputs` (#15620)

Signed-off-by: DarkLight1337
---
 .../multimodal/processing/test_idefics3.py |  2 +-
 .../multimodal/processing/test_phi3v.py    |  2 +-
 vllm/engine/arg_utils.py                   |  2 +-
 vllm/engine/llm_engine.py                  | 28 ++++++++----------
 vllm/inputs/parse.py                       | 22 +++++++++-----
 vllm/inputs/registry.py                    | 14 ++++-----
 vllm/model_executor/models/idefics3.py     |  4 +--
 vllm/v1/engine/processor.py                | 29 ++++++++-----------
 8 files changed, 49 insertions(+), 54 deletions(-)

diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py
index fdbe2f17692f7..4cff429a53941 100644
--- a/tests/models/multimodal/processing/test_idefics3.py
+++ b/tests/models/multimodal/processing/test_idefics3.py
@@ -29,7 +29,7 @@ def test_processor_override(
     num_imgs: int,
     kwargs_on_init: bool,
 ):
-    """Ensure input_processor_for_idefics3 handles num_crops properly."""
+    """Ensure Idefics3MultiModalProcessor handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
     # the partial when calling the custom input processor.
diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py
index 2f0c8e7e5492c..dd5f30a23176b 100644
--- a/tests/models/multimodal/processing/test_phi3v.py
+++ b/tests/models/multimodal/processing/test_phi3v.py
@@ -30,7 +30,7 @@ def test_processor_override(
     num_imgs: int,
     kwargs_on_init: bool,
 ):
-    """Ensure input_processor_for_phi3v handles num_crops properly."""
+    """Ensure Phi3VMultiModalProcessor handles num_crops properly."""
     # Avoid initializing CUDA early
     from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 784ea35beb357..53af3e5717c52 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -665,7 +665,7 @@ class EngineArgs:
             type=nullable_kvs,
             default=EngineArgs.limit_mm_per_prompt,
             # The default value is given in
-            # MultiModalRegistry.init_mm_limits_per_prompt
+            # MultiModalConfig.get_limit_per_prompt
             help=('For each multimodal plugin, limit how many '
                   'input instances to allow for each prompt. '
                   'Expects a comma-separated list of items, '
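
Note on the flag whose comment is corrected above: `--limit-mm-per-prompt` is
parsed by `nullable_kvs` into a modality-to-count mapping, e.g.
`image=16,video=2`. A minimal sketch of the equivalent Python-API usage (the
model name is a hypothetical placeholder, not part of this patch):

    from vllm import LLM

    # Allow up to 16 images and 2 videos per prompt; per the corrected
    # comment, MultiModalConfig.get_limit_per_prompt supplies the default
    # for any modality omitted from this mapping.
    llm = LLM(
        model="org/some-multimodal-model",  # hypothetical model name
        limit_mm_per_prompt={"image": 16, "video": 2},
    )
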
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 3d019ea58c5e1..4856c3568319b 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import (
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType, SingletonInputsAdapter)
-from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
+                         PromptType)
+from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.logits_process import get_bad_words_logits_processors
@@ -609,12 +609,7 @@ class LLMEngine:
         seq_id = next(self.seq_counter)
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
-        if is_encoder_decoder_inputs(processed_inputs):
-            decoder_inputs = processed_inputs["decoder"]
-            encoder_inputs = processed_inputs["encoder"]
-        else:
-            decoder_inputs = processed_inputs
-            encoder_inputs = None
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
 
         seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                        lora_request, prompt_adapter_request)
@@ -2031,15 +2026,16 @@ class LLMEngine:
 
     def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest]):
-        if is_encoder_decoder_inputs(inputs):
-            # For encoder-decoder multimodal models, the max_prompt_len
-            # restricts the decoder prompt length
-            prompt_inputs = inputs["decoder" if self.model_config.
-                                   is_multimodal_model else "encoder"]
-        else:
-            prompt_inputs = inputs
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
 
-        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
+        # For encoder-decoder multimodal models, the max_prompt_len
+        # restricts the decoder prompt length
+        if self.model_config.is_multimodal_model:
+            prompt_inputs = decoder_inputs
+        else:
+            prompt_inputs = encoder_inputs or decoder_inputs
+
+        prompt_ids = prompt_inputs["prompt_token_ids"]
 
         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py
index ed1056948d807..28e207de1fd39 100644
--- a/vllm/inputs/parse.py
+++ b/vllm/inputs/parse.py
@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
-
 from collections.abc import Sequence
-from typing import Literal, TypedDict, Union, cast, overload
+from typing import Literal, Optional, TypedDict, Union, cast, overload
 
 from typing_extensions import TypeIs
 
 from vllm.utils import is_list_of
 
-from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
-                   ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt)
+from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
+                   SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
 
 
 class ParsedText(TypedDict):
@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
     return isinstance(prompt, dict) and "encoder_prompt" in prompt
 
 
-def is_encoder_decoder_inputs(
-        inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
-    return "encoder" in inputs and "decoder" in inputs
+def split_enc_dec_inputs(
+    inputs: ProcessorInputs,
+) -> tuple[Optional[SingletonInputs], SingletonInputs]:
+    if "encoder" in inputs and "decoder" in inputs:
+        # NOTE: This passes pyright but not mypy
+        return (
+            inputs["encoder"],  # type: ignore[typeddict-item]
+            inputs["decoder"],  # type: ignore[typeddict-item]
+        )
+
+    return None, inputs
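
The new helper folds away the encoder/decoder branch that every call site
previously spelled out. A before/after sketch of a typical call site
(illustrative only; `processed_inputs` stands for any `ProcessorInputs` value):

    # Before: each caller branched on the TypeIs-based predicate
    if is_encoder_decoder_inputs(processed_inputs):
        encoder_inputs = processed_inputs["encoder"]
        decoder_inputs = processed_inputs["decoder"]
    else:
        encoder_inputs = None
        decoder_inputs = processed_inputs

    # After: a single call; encoder_inputs is None for decoder-only inputs
    encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
    if encoder_inputs is not None:
        ...  # encoder-decoder request: the encoder prompt needs handling too
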
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index b6ceb5fb82d70..8b95db7a72522 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
                         resolve_mm_processor_kwargs)
 
 from .data import ProcessorInputs, SingletonInputs
-from .parse import is_encoder_decoder_inputs
+from .parse import split_enc_dec_inputs
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -462,13 +462,11 @@ class InputRegistry:
             **mm_processor_kwargs,
         )
 
-        if is_encoder_decoder_inputs(processed_inputs):
-            self._ensure_mm_kwargs(processed_inputs["encoder"],
-                                   mm_processor_kwargs)
-            self._ensure_mm_kwargs(processed_inputs["decoder"],
-                                   mm_processor_kwargs)
-        else:
-            self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        if encoder_inputs is not None:
+            self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
+        if decoder_inputs is not None:
+            self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
 
         return processed_inputs
 
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 234e4498f163b..432f26141048b 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -232,7 +232,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
         )
 
 
-class Idefics3MultimodalProcessor(
+class Idefics3MultiModalProcessor(
         BaseMultiModalProcessor[Idefics3ProcessingInfo]):
 
     def _call_hf_processor(
@@ -575,7 +575,7 @@ class Idefics3Model(nn.Module):
 
 
 @MULTIMODAL_REGISTRY.register_processor(
-    Idefics3MultimodalProcessor,
+    Idefics3MultiModalProcessor,
     info=Idefics3ProcessingInfo,
     dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index e281781675769..065ac0920af77 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -7,7 +7,7 @@ from typing import Optional, Union
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
-from vllm.inputs.parse import is_encoder_decoder_inputs
+from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
@@ -209,14 +209,8 @@ class Processor:
 
         self._validate_model_inputs(processed_inputs, lora_request)
 
-        if is_encoder_decoder_inputs(processed_inputs):
-            decoder_inputs = SingletonInputsAdapter(
-                processed_inputs["decoder"])
-            encoder_inputs = SingletonInputsAdapter(
-                processed_inputs["encoder"])
-        else:
-            decoder_inputs = SingletonInputsAdapter(processed_inputs)
-            encoder_inputs = None
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        decoder_inputs = SingletonInputsAdapter(decoder_inputs)
 
         # TODO: Impl encoder-decoder
         if encoder_inputs is not None:
@@ -301,15 +295,16 @@ class Processor:
 
     def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest] = None):
-        if is_encoder_decoder_inputs(inputs):
-            # For encoder-decoder multimodal models, the max_prompt_len
-            # restricts the decoder prompt length
-            prompt_inputs = inputs["decoder" if self.model_config.
-                                   is_multimodal_model else "encoder"]
-        else:
-            prompt_inputs = inputs
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
 
-        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
+        # For encoder-decoder multimodal models, the max_prompt_len
+        # restricts the decoder prompt length
+        if self.model_config.is_multimodal_model:
+            prompt_inputs = decoder_inputs
+        else:
+            prompt_inputs = encoder_inputs or decoder_inputs
+
+        prompt_ids = prompt_inputs["prompt_token_ids"]
 
         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
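
One behavioral detail of the rewritten `_validate_model_inputs`:
`encoder_inputs or decoder_inputs` relies on truthiness, so text-only
encoder-decoder models validate the encoder prompt against max_prompt_len,
while multimodal models always validate the decoder prompt. A self-contained
sketch of that selection rule (toy dicts stand in for `SingletonInputs`; the
function name is invented for illustration):

    from typing import Optional

    def select_prompt_inputs(encoder_inputs: Optional[dict],
                             decoder_inputs: dict,
                             is_multimodal_model: bool) -> dict:
        # Mirrors the patched logic: for encoder-decoder multimodal models,
        # max_prompt_len restricts the decoder prompt; otherwise prefer the
        # encoder prompt whenever one exists.
        if is_multimodal_model:
            return decoder_inputs
        return encoder_inputs or decoder_inputs

    # Decoder-only model: no encoder inputs, the decoder prompt is checked.
    assert select_prompt_inputs(None, {"prompt_token_ids": [1, 2]},
                                False) == {"prompt_token_ids": [1, 2]}

    # Text-only encoder-decoder model: the encoder prompt is checked.
    enc = {"prompt_token_ids": [3]}
    dec = {"prompt_token_ids": [4]}
    assert select_prompt_inputs(enc, dec, False) is enc
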