[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-03-28 01:36:32 +08:00 committed by GitHub
parent 07bf813fb5
commit 247181536f
8 changed files with 49 additions and 54 deletions
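The refactor is mechanical: every call site that branched on is_encoder_decoder_inputs and unpacked the "encoder"/"decoder" keys by hand now makes a single tuple-unpacking call. A minimal runnable sketch of the before/after pattern, condensed from the hunks below (plain dicts stand in for vLLM's ProcessorInputs/SingletonInputs TypedDicts; the function names here are illustrative only):

    from typing import Any, Optional

    def split_enc_dec_inputs(
        inputs: dict[str, Any],
    ) -> tuple[Optional[dict[str, Any]], dict[str, Any]]:
        # Same logic as the helper this commit adds to vllm/inputs/parse.py.
        if "encoder" in inputs and "decoder" in inputs:
            return inputs["encoder"], inputs["decoder"]
        return None, inputs

    def unpack_before(processed_inputs: dict[str, Any]):
        # Old pattern, duplicated at each call site.
        if "encoder" in processed_inputs and "decoder" in processed_inputs:
            encoder_inputs = processed_inputs["encoder"]
            decoder_inputs = processed_inputs["decoder"]
        else:
            encoder_inputs = None
            decoder_inputs = processed_inputs
        return encoder_inputs, decoder_inputs

    def unpack_after(processed_inputs: dict[str, Any]):
        # New pattern: one shared helper, one line per call site.
        return split_enc_dec_inputs(processed_inputs)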


@@ -29,7 +29,7 @@ def test_processor_override(
     num_imgs: int,
     kwargs_on_init: bool,
 ):
-    """Ensure input_processor_for_idefics3 handles num_crops properly."""
+    """Ensure Idefics3MultiModalProcessor handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
     # the partial when calling the custom input processor.


@@ -30,7 +30,7 @@ def test_processor_override(
     num_imgs: int,
     kwargs_on_init: bool,
 ):
-    """Ensure input_processor_for_phi3v handles num_crops properly."""
+    """Ensure Phi3VMultiModalProcessor handles num_crops properly."""
     # Avoid initializing CUDA early
     from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID


@@ -665,7 +665,7 @@ class EngineArgs:
             type=nullable_kvs,
             default=EngineArgs.limit_mm_per_prompt,
             # The default value is given in
-            # MultiModalRegistry.init_mm_limits_per_prompt
+            # MultiModalConfig.get_limit_per_prompt
             help=('For each multimodal plugin, limit how many '
                   'input instances to allow for each prompt. '
                   'Expects a comma-separated list of items, '


@@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import (
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType, SingletonInputsAdapter)
-from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
+                         PromptType)
+from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.logits_process import get_bad_words_logits_processors
@@ -609,12 +609,7 @@ class LLMEngine:
         seq_id = next(self.seq_counter)
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

-        if is_encoder_decoder_inputs(processed_inputs):
-            decoder_inputs = processed_inputs["decoder"]
-            encoder_inputs = processed_inputs["encoder"]
-        else:
-            decoder_inputs = processed_inputs
-            encoder_inputs = None
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

         seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                        lora_request, prompt_adapter_request)
@@ -2031,15 +2026,16 @@ class LLMEngine:
     def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest]):
-        if is_encoder_decoder_inputs(inputs):
-            # For encoder-decoder multimodal models, the max_prompt_len
-            # restricts the decoder prompt length
-            prompt_inputs = inputs["decoder" if self.model_config.
-                                   is_multimodal_model else "encoder"]
-        else:
-            prompt_inputs = inputs
-
-        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
+
+        # For encoder-decoder multimodal models, the max_prompt_len
+        # restricts the decoder prompt length
+        if self.model_config.is_multimodal_model:
+            prompt_inputs = decoder_inputs
+        else:
+            prompt_inputs = encoder_inputs or decoder_inputs
+
+        prompt_ids = prompt_inputs["prompt_token_ids"]

         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")

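For reference, the selection rule the rewritten _validate_model_inputs applies can be expressed as a small standalone function (a sketch, not vLLM code; the plain-dict type and boolean flag are simplifications of the real config objects):

    from typing import Any, Optional

    def select_prompt_for_validation(
        encoder_inputs: Optional[dict[str, Any]],
        decoder_inputs: dict[str, Any],
        is_multimodal_model: bool,
    ) -> dict[str, Any]:
        if is_multimodal_model:
            # For encoder-decoder multimodal models, max_prompt_len
            # restricts the decoder prompt length.
            return decoder_inputs
        # Otherwise validate the encoder prompt when one exists;
        # decoder-only models have no encoder side, so fall back.
        return encoder_inputs or decoder_inputs

Note that `encoder_inputs or decoder_inputs` falls back on any falsy value, not just None; that is harmless here because split_enc_dec_inputs returns either None or a populated inputs mapping for the encoder side.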

@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections.abc import Sequence
-from typing import Literal, TypedDict, Union, cast, overload
+from typing import Literal, Optional, TypedDict, Union, cast, overload

 from typing_extensions import TypeIs

 from vllm.utils import is_list_of

-from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
-                   ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt)
+from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
+                   SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)


 class ParsedText(TypedDict):
@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
     return isinstance(prompt, dict) and "encoder_prompt" in prompt


-def is_encoder_decoder_inputs(
-        inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
-    return "encoder" in inputs and "decoder" in inputs
+def split_enc_dec_inputs(
+    inputs: ProcessorInputs,
+) -> tuple[Optional[SingletonInputs], SingletonInputs]:
+    if "encoder" in inputs and "decoder" in inputs:
+        # NOTE: This passes pyright but not mypy
+        return (
+            inputs["encoder"],  # type: ignore[typeddict-item]
+            inputs["decoder"],  # type: ignore[typeddict-item]
+        )
+
+    return None, inputs
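As a quick behavior check, the helper splits the two input shapes like this (the token-ID payloads are made-up examples, not taken from the diff):

    enc_dec_inputs = {
        "encoder": {"prompt_token_ids": [101, 102]},
        "decoder": {"prompt_token_ids": [0]},
    }
    decoder_only_inputs = {"prompt_token_ids": [101, 102, 103]}

    encoder, decoder = split_enc_dec_inputs(enc_dec_inputs)
    assert encoder == {"prompt_token_ids": [101, 102]}
    assert decoder == {"prompt_token_ids": [0]}

    encoder, decoder = split_enc_dec_inputs(decoder_only_inputs)
    assert encoder is None
    assert decoder is decoder_only_inputs  # returned unchanged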


@@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
                         resolve_mm_processor_kwargs)

 from .data import ProcessorInputs, SingletonInputs
-from .parse import is_encoder_decoder_inputs
+from .parse import split_enc_dec_inputs

 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -462,13 +462,11 @@ class InputRegistry:
             **mm_processor_kwargs,
         )

-        if is_encoder_decoder_inputs(processed_inputs):
-            self._ensure_mm_kwargs(processed_inputs["encoder"],
-                                   mm_processor_kwargs)
-            self._ensure_mm_kwargs(processed_inputs["decoder"],
-                                   mm_processor_kwargs)
-        else:
-            self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        if encoder_inputs is not None:
+            self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
+        if decoder_inputs is not None:
+            self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)

         return processed_inputs


@@ -232,7 +232,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
     )


-class Idefics3MultimodalProcessor(
+class Idefics3MultiModalProcessor(
         BaseMultiModalProcessor[Idefics3ProcessingInfo]):

     def _call_hf_processor(
@@ -575,7 +575,7 @@ class Idefics3Model(nn.Module):


 @MULTIMODAL_REGISTRY.register_processor(
-    Idefics3MultimodalProcessor,
+    Idefics3MultiModalProcessor,
     info=Idefics3ProcessingInfo,
     dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,


@@ -7,7 +7,7 @@ from typing import Optional, Union

 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
-from vllm.inputs.parse import is_encoder_decoder_inputs
+from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
@@ -209,14 +209,8 @@ class Processor:

         self._validate_model_inputs(processed_inputs, lora_request)

-        if is_encoder_decoder_inputs(processed_inputs):
-            decoder_inputs = SingletonInputsAdapter(
-                processed_inputs["decoder"])
-            encoder_inputs = SingletonInputsAdapter(
-                processed_inputs["encoder"])
-        else:
-            decoder_inputs = SingletonInputsAdapter(processed_inputs)
-            encoder_inputs = None
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        decoder_inputs = SingletonInputsAdapter(decoder_inputs)

         # TODO: Impl encoder-decoder
         if encoder_inputs is not None:
@@ -301,15 +295,16 @@ class Processor:
     def _validate_model_inputs(self,
                                inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest] = None):
-        if is_encoder_decoder_inputs(inputs):
-            # For encoder-decoder multimodal models, the max_prompt_len
-            # restricts the decoder prompt length
-            prompt_inputs = inputs["decoder" if self.model_config.
-                                   is_multimodal_model else "encoder"]
-        else:
-            prompt_inputs = inputs
-
-        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
+
+        # For encoder-decoder multimodal models, the max_prompt_len
+        # restricts the decoder prompt length
+        if self.model_config.is_multimodal_model:
+            prompt_inputs = decoder_inputs
+        else:
+            prompt_inputs = encoder_inputs or decoder_inputs
+
+        prompt_ids = prompt_inputs["prompt_token_ids"]

         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")