[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-03-28 01:36:32 +08:00 committed by GitHub
parent 07bf813fb5
commit 247181536f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 49 additions and 54 deletions

View File

@ -29,7 +29,7 @@ def test_processor_override(
num_imgs: int, num_imgs: int,
kwargs_on_init: bool, kwargs_on_init: bool,
): ):
"""Ensure input_processor_for_idefics3 handles num_crops properly.""" """Ensure Idefics3MultiModalProcessor handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs # Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by # in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor. # the partial when calling the custom input processor.

View File

@ -30,7 +30,7 @@ def test_processor_override(
num_imgs: int, num_imgs: int,
kwargs_on_init: bool, kwargs_on_init: bool,
): ):
"""Ensure input_processor_for_phi3v handles num_crops properly.""" """Ensure Phi3VMultiModalProcessor handles num_crops properly."""
# Avoid initializing CUDA early # Avoid initializing CUDA early
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID

View File

@ -665,7 +665,7 @@ class EngineArgs:
type=nullable_kvs, type=nullable_kvs,
default=EngineArgs.limit_mm_per_prompt, default=EngineArgs.limit_mm_per_prompt,
# The default value is given in # The default value is given in
# MultiModalRegistry.init_mm_limits_per_prompt # MultiModalConfig.get_limit_per_prompt
help=('For each multimodal plugin, limit how many ' help=('For each multimodal plugin, limit how many '
'input instances to allow for each prompt. ' 'input instances to allow for each prompt. '
'Expects a comma-separated list of items, ' 'Expects a comma-separated list of items, '

View File

@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors) get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors from vllm.logits_process import get_bad_words_logits_processors
@ -609,12 +609,7 @@ class LLMEngine:
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request) lora_request, prompt_adapter_request)
@ -2031,15 +2026,16 @@ class LLMEngine:
def _validate_model_inputs(self, inputs: ProcessorInputs, def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]): lora_request: Optional[LoRARequest]):
if is_encoder_decoder_inputs(inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_inputs = inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")

View File

@ -1,15 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, TypedDict, Union, cast, overload from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: inputs: ProcessorInputs,
return "encoder" in inputs and "decoder" in inputs ) -> tuple[Optional[SingletonInputs], SingletonInputs]:
if "encoder" in inputs and "decoder" in inputs:
# NOTE: This passes pyright but not mypy
return (
inputs["encoder"], # type: ignore[typeddict-item]
inputs["decoder"], # type: ignore[typeddict-item]
)
return None, inputs

View File

@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs from .parse import split_enc_dec_inputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -462,13 +462,11 @@ class InputRegistry:
**mm_processor_kwargs, **mm_processor_kwargs,
) )
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._ensure_mm_kwargs(processed_inputs["encoder"], if encoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"], if decoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
return processed_inputs return processed_inputs

View File

@ -232,7 +232,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
) )
class Idefics3MultimodalProcessor( class Idefics3MultiModalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]): BaseMultiModalProcessor[Idefics3ProcessingInfo]):
def _call_hf_processor( def _call_hf_processor(
@ -575,7 +575,7 @@ class Idefics3Model(nn.Module):
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
Idefics3MultimodalProcessor, Idefics3MultiModalProcessor,
info=Idefics3ProcessingInfo, info=Idefics3ProcessingInfo,
dummy_inputs=Idefics3DummyInputsBuilder) dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,

View File

@ -7,7 +7,7 @@ from typing import Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
@ -209,14 +209,8 @@ class Processor:
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
decoder_inputs = SingletonInputsAdapter( decoder_inputs = SingletonInputsAdapter(decoder_inputs)
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
# TODO: Impl encoder-decoder # TODO: Impl encoder-decoder
if encoder_inputs is not None: if encoder_inputs is not None:
@ -301,15 +295,16 @@ class Processor:
def _validate_model_inputs(self, def _validate_model_inputs(self,
inputs: ProcessorInputs, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None): lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_inputs = inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")