mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 21:15:46 +08:00
[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
07bf813fb5
commit
247181536f
@ -29,7 +29,7 @@ def test_processor_override(
|
|||||||
num_imgs: int,
|
num_imgs: int,
|
||||||
kwargs_on_init: bool,
|
kwargs_on_init: bool,
|
||||||
):
|
):
|
||||||
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
|
"""Ensure Idefics3MultiModalProcessor handles num_crops properly."""
|
||||||
# Same as the previous test - don't initialize mm_processor_kwargs
|
# Same as the previous test - don't initialize mm_processor_kwargs
|
||||||
# in this test and assume that the kwargs will be correctly expanded by
|
# in this test and assume that the kwargs will be correctly expanded by
|
||||||
# the partial when calling the custom input processor.
|
# the partial when calling the custom input processor.
|
||||||
|
|||||||
@ -30,7 +30,7 @@ def test_processor_override(
|
|||||||
num_imgs: int,
|
num_imgs: int,
|
||||||
kwargs_on_init: bool,
|
kwargs_on_init: bool,
|
||||||
):
|
):
|
||||||
"""Ensure input_processor_for_phi3v handles num_crops properly."""
|
"""Ensure Phi3VMultiModalProcessor handles num_crops properly."""
|
||||||
# Avoid initializing CUDA early
|
# Avoid initializing CUDA early
|
||||||
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
|
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
|
||||||
|
|
||||||
|
|||||||
@ -665,7 +665,7 @@ class EngineArgs:
|
|||||||
type=nullable_kvs,
|
type=nullable_kvs,
|
||||||
default=EngineArgs.limit_mm_per_prompt,
|
default=EngineArgs.limit_mm_per_prompt,
|
||||||
# The default value is given in
|
# The default value is given in
|
||||||
# MultiModalRegistry.init_mm_limits_per_prompt
|
# MultiModalConfig.get_limit_per_prompt
|
||||||
help=('For each multimodal plugin, limit how many '
|
help=('For each multimodal plugin, limit how many '
|
||||||
'input instances to allow for each prompt. '
|
'input instances to allow for each prompt. '
|
||||||
'Expects a comma-separated list of items, '
|
'Expects a comma-separated list of items, '
|
||||||
|
|||||||
@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import (
|
|||||||
get_logits_processors as get_openai_logits_processors)
|
get_logits_processors as get_openai_logits_processors)
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||||
PromptType, SingletonInputsAdapter)
|
PromptType)
|
||||||
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
|
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logits_process import get_bad_words_logits_processors
|
from vllm.logits_process import get_bad_words_logits_processors
|
||||||
@ -609,12 +609,7 @@ class LLMEngine:
|
|||||||
seq_id = next(self.seq_counter)
|
seq_id = next(self.seq_counter)
|
||||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||||
|
|
||||||
if is_encoder_decoder_inputs(processed_inputs):
|
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||||
decoder_inputs = processed_inputs["decoder"]
|
|
||||||
encoder_inputs = processed_inputs["encoder"]
|
|
||||||
else:
|
|
||||||
decoder_inputs = processed_inputs
|
|
||||||
encoder_inputs = None
|
|
||||||
|
|
||||||
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
||||||
lora_request, prompt_adapter_request)
|
lora_request, prompt_adapter_request)
|
||||||
@ -2031,15 +2026,16 @@ class LLMEngine:
|
|||||||
|
|
||||||
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
||||||
lora_request: Optional[LoRARequest]):
|
lora_request: Optional[LoRARequest]):
|
||||||
if is_encoder_decoder_inputs(inputs):
|
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
|
||||||
# For encoder-decoder multimodal models, the max_prompt_len
|
|
||||||
# restricts the decoder prompt length
|
|
||||||
prompt_inputs = inputs["decoder" if self.model_config.
|
|
||||||
is_multimodal_model else "encoder"]
|
|
||||||
else:
|
|
||||||
prompt_inputs = inputs
|
|
||||||
|
|
||||||
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
|
# For encoder-decoder multimodal models, the max_prompt_len
|
||||||
|
# restricts the decoder prompt length
|
||||||
|
if self.model_config.is_multimodal_model:
|
||||||
|
prompt_inputs = decoder_inputs
|
||||||
|
else:
|
||||||
|
prompt_inputs = encoder_inputs or decoder_inputs
|
||||||
|
|
||||||
|
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||||
|
|
||||||
if prompt_ids is None or len(prompt_ids) == 0:
|
if prompt_ids is None or len(prompt_ids) == 0:
|
||||||
raise ValueError("Prompt cannot be empty")
|
raise ValueError("Prompt cannot be empty")
|
||||||
|
|||||||
@ -1,15 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal, TypedDict, Union, cast, overload
|
from typing import Literal, Optional, TypedDict, Union, cast, overload
|
||||||
|
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
|
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||||
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
|
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
|
||||||
TokensPrompt)
|
|
||||||
|
|
||||||
|
|
||||||
class ParsedText(TypedDict):
|
class ParsedText(TypedDict):
|
||||||
@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
|
|||||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||||
|
|
||||||
|
|
||||||
def is_encoder_decoder_inputs(
|
def split_enc_dec_inputs(
|
||||||
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
|
inputs: ProcessorInputs,
|
||||||
return "encoder" in inputs and "decoder" in inputs
|
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
||||||
|
if "encoder" in inputs and "decoder" in inputs:
|
||||||
|
# NOTE: This passes pyright but not mypy
|
||||||
|
return (
|
||||||
|
inputs["encoder"], # type: ignore[typeddict-item]
|
||||||
|
inputs["decoder"], # type: ignore[typeddict-item]
|
||||||
|
)
|
||||||
|
|
||||||
|
return None, inputs
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
|
|||||||
resolve_mm_processor_kwargs)
|
resolve_mm_processor_kwargs)
|
||||||
|
|
||||||
from .data import ProcessorInputs, SingletonInputs
|
from .data import ProcessorInputs, SingletonInputs
|
||||||
from .parse import is_encoder_decoder_inputs
|
from .parse import split_enc_dec_inputs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@ -462,13 +462,11 @@ class InputRegistry:
|
|||||||
**mm_processor_kwargs,
|
**mm_processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_encoder_decoder_inputs(processed_inputs):
|
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||||
self._ensure_mm_kwargs(processed_inputs["encoder"],
|
if encoder_inputs is not None:
|
||||||
mm_processor_kwargs)
|
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
|
||||||
self._ensure_mm_kwargs(processed_inputs["decoder"],
|
if decoder_inputs is not None:
|
||||||
mm_processor_kwargs)
|
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
|
||||||
else:
|
|
||||||
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
|
|
||||||
|
|
||||||
return processed_inputs
|
return processed_inputs
|
||||||
|
|
||||||
|
|||||||
@ -232,7 +232,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Idefics3MultimodalProcessor(
|
class Idefics3MultiModalProcessor(
|
||||||
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
|
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
|
||||||
|
|
||||||
def _call_hf_processor(
|
def _call_hf_processor(
|
||||||
@ -575,7 +575,7 @@ class Idefics3Model(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
Idefics3MultimodalProcessor,
|
Idefics3MultiModalProcessor,
|
||||||
info=Idefics3ProcessingInfo,
|
info=Idefics3ProcessingInfo,
|
||||||
dummy_inputs=Idefics3DummyInputsBuilder)
|
dummy_inputs=Idefics3DummyInputsBuilder)
|
||||||
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from typing import Optional, Union
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||||
PromptType, SingletonInputsAdapter)
|
PromptType, SingletonInputsAdapter)
|
||||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
from vllm.inputs.parse import split_enc_dec_inputs
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||||
@ -209,14 +209,8 @@ class Processor:
|
|||||||
|
|
||||||
self._validate_model_inputs(processed_inputs, lora_request)
|
self._validate_model_inputs(processed_inputs, lora_request)
|
||||||
|
|
||||||
if is_encoder_decoder_inputs(processed_inputs):
|
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||||
decoder_inputs = SingletonInputsAdapter(
|
decoder_inputs = SingletonInputsAdapter(decoder_inputs)
|
||||||
processed_inputs["decoder"])
|
|
||||||
encoder_inputs = SingletonInputsAdapter(
|
|
||||||
processed_inputs["encoder"])
|
|
||||||
else:
|
|
||||||
decoder_inputs = SingletonInputsAdapter(processed_inputs)
|
|
||||||
encoder_inputs = None
|
|
||||||
|
|
||||||
# TODO: Impl encoder-decoder
|
# TODO: Impl encoder-decoder
|
||||||
if encoder_inputs is not None:
|
if encoder_inputs is not None:
|
||||||
@ -301,15 +295,16 @@ class Processor:
|
|||||||
def _validate_model_inputs(self,
|
def _validate_model_inputs(self,
|
||||||
inputs: ProcessorInputs,
|
inputs: ProcessorInputs,
|
||||||
lora_request: Optional[LoRARequest] = None):
|
lora_request: Optional[LoRARequest] = None):
|
||||||
if is_encoder_decoder_inputs(inputs):
|
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
|
||||||
# For encoder-decoder multimodal models, the max_prompt_len
|
|
||||||
# restricts the decoder prompt length
|
|
||||||
prompt_inputs = inputs["decoder" if self.model_config.
|
|
||||||
is_multimodal_model else "encoder"]
|
|
||||||
else:
|
|
||||||
prompt_inputs = inputs
|
|
||||||
|
|
||||||
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
|
# For encoder-decoder multimodal models, the max_prompt_len
|
||||||
|
# restricts the decoder prompt length
|
||||||
|
if self.model_config.is_multimodal_model:
|
||||||
|
prompt_inputs = decoder_inputs
|
||||||
|
else:
|
||||||
|
prompt_inputs = encoder_inputs or decoder_inputs
|
||||||
|
|
||||||
|
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||||
|
|
||||||
if prompt_ids is None or len(prompt_ids) == 0:
|
if prompt_ids is None or len(prompt_ids) == 0:
|
||||||
raise ValueError("Prompt cannot be empty")
|
raise ValueError("Prompt cannot be empty")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user