mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 04:04:29 +08:00
[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
07bf813fb5
commit
247181536f
@ -29,7 +29,7 @@ def test_processor_override(
|
||||
num_imgs: int,
|
||||
kwargs_on_init: bool,
|
||||
):
|
||||
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
|
||||
"""Ensure Idefics3MultiModalProcessor handles num_crops properly."""
|
||||
# Same as the previous test - don't initialize mm_processor_kwargs
|
||||
# in this test and assume that the kwargs will be correctly expanded by
|
||||
# the partial when calling the custom input processor.
|
||||
|
||||
@ -30,7 +30,7 @@ def test_processor_override(
|
||||
num_imgs: int,
|
||||
kwargs_on_init: bool,
|
||||
):
|
||||
"""Ensure input_processor_for_phi3v handles num_crops properly."""
|
||||
"""Ensure Phi3VMultiModalProcessor handles num_crops properly."""
|
||||
# Avoid initializing CUDA early
|
||||
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
|
||||
|
||||
|
||||
@ -665,7 +665,7 @@ class EngineArgs:
|
||||
type=nullable_kvs,
|
||||
default=EngineArgs.limit_mm_per_prompt,
|
||||
# The default value is given in
|
||||
# MultiModalRegistry.init_mm_limits_per_prompt
|
||||
# MultiModalConfig.get_limit_per_prompt
|
||||
help=('For each multimodal plugin, limit how many '
|
||||
'input instances to allow for each prompt. '
|
||||
'Expects a comma-separated list of items, '
|
||||
|
||||
@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import (
|
||||
get_logits_processors as get_openai_logits_processors)
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType, SingletonInputsAdapter)
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
|
||||
PromptType)
|
||||
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import get_bad_words_logits_processors
|
||||
@ -609,12 +609,7 @@ class LLMEngine:
|
||||
seq_id = next(self.seq_counter)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
decoder_inputs = processed_inputs["decoder"]
|
||||
encoder_inputs = processed_inputs["encoder"]
|
||||
else:
|
||||
decoder_inputs = processed_inputs
|
||||
encoder_inputs = None
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
|
||||
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
||||
lora_request, prompt_adapter_request)
|
||||
@ -2031,15 +2026,16 @@ class LLMEngine:
|
||||
|
||||
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
||||
lora_request: Optional[LoRARequest]):
|
||||
if is_encoder_decoder_inputs(inputs):
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
prompt_inputs = inputs["decoder" if self.model_config.
|
||||
is_multimodal_model else "encoder"]
|
||||
else:
|
||||
prompt_inputs = inputs
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
|
||||
|
||||
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
if self.model_config.is_multimodal_model:
|
||||
prompt_inputs = decoder_inputs
|
||||
else:
|
||||
prompt_inputs = encoder_inputs or decoder_inputs
|
||||
|
||||
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, TypedDict, Union, cast, overload
|
||||
from typing import Literal, Optional, TypedDict, Union, cast, overload
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
|
||||
TokensPrompt)
|
||||
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def is_encoder_decoder_inputs(
|
||||
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
|
||||
return "encoder" in inputs and "decoder" in inputs
|
||||
def split_enc_dec_inputs(
|
||||
inputs: ProcessorInputs,
|
||||
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
||||
if "encoder" in inputs and "decoder" in inputs:
|
||||
# NOTE: This passes pyright but not mypy
|
||||
return (
|
||||
inputs["encoder"], # type: ignore[typeddict-item]
|
||||
inputs["decoder"], # type: ignore[typeddict-item]
|
||||
)
|
||||
|
||||
return None, inputs
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
|
||||
resolve_mm_processor_kwargs)
|
||||
|
||||
from .data import ProcessorInputs, SingletonInputs
|
||||
from .parse import is_encoder_decoder_inputs
|
||||
from .parse import split_enc_dec_inputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -462,13 +462,11 @@ class InputRegistry:
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
self._ensure_mm_kwargs(processed_inputs["encoder"],
|
||||
mm_processor_kwargs)
|
||||
self._ensure_mm_kwargs(processed_inputs["decoder"],
|
||||
mm_processor_kwargs)
|
||||
else:
|
||||
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
if encoder_inputs is not None:
|
||||
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
|
||||
if decoder_inputs is not None:
|
||||
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
|
||||
|
||||
return processed_inputs
|
||||
|
||||
|
||||
@ -232,7 +232,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
||||
)
|
||||
|
||||
|
||||
class Idefics3MultimodalProcessor(
|
||||
class Idefics3MultiModalProcessor(
|
||||
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
@ -575,7 +575,7 @@ class Idefics3Model(nn.Module):
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Idefics3MultimodalProcessor,
|
||||
Idefics3MultiModalProcessor,
|
||||
info=Idefics3ProcessingInfo,
|
||||
dummy_inputs=Idefics3DummyInputsBuilder)
|
||||
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Optional, Union
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType, SingletonInputsAdapter)
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
@ -209,14 +209,8 @@ class Processor:
|
||||
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
decoder_inputs = SingletonInputsAdapter(
|
||||
processed_inputs["decoder"])
|
||||
encoder_inputs = SingletonInputsAdapter(
|
||||
processed_inputs["encoder"])
|
||||
else:
|
||||
decoder_inputs = SingletonInputsAdapter(processed_inputs)
|
||||
encoder_inputs = None
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
decoder_inputs = SingletonInputsAdapter(decoder_inputs)
|
||||
|
||||
# TODO: Impl encoder-decoder
|
||||
if encoder_inputs is not None:
|
||||
@ -301,15 +295,16 @@ class Processor:
|
||||
def _validate_model_inputs(self,
|
||||
inputs: ProcessorInputs,
|
||||
lora_request: Optional[LoRARequest] = None):
|
||||
if is_encoder_decoder_inputs(inputs):
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
prompt_inputs = inputs["decoder" if self.model_config.
|
||||
is_multimodal_model else "encoder"]
|
||||
else:
|
||||
prompt_inputs = inputs
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
|
||||
|
||||
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
if self.model_config.is_multimodal_model:
|
||||
prompt_inputs = decoder_inputs
|
||||
else:
|
||||
prompt_inputs = encoder_inputs or decoder_inputs
|
||||
|
||||
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user