From 1409ef913446aa282f6426efbb0ed02a59320467 Mon Sep 17 00:00:00 2001
From: Lukas Geiger
Date: Wed, 4 Jun 2025 04:24:56 +0100
Subject: [PATCH] [Core] Cast multimodal input in hf processor (#18862)

Signed-off-by: Lukas Geiger
---
 vllm/inputs/registry.py                       | 26 +++++++++++++++++--
 vllm/multimodal/inputs.py                     |  8 +-----
 vllm/spec_decode/draft_model_runner.py        |  1 -
 vllm/v1/worker/gpu_model_runner.py            |  2 --
 vllm/v1/worker/tpu_model_runner.py            |  2 --
 vllm/worker/cpu_enc_dec_model_runner.py       |  1 -
 vllm/worker/cpu_model_runner.py               |  1 -
 vllm/worker/cpu_pooling_model_runner.py       |  1 -
 vllm/worker/enc_dec_model_runner.py           |  1 -
 vllm/worker/model_runner.py                   |  1 -
 vllm/worker/multi_step_neuron_model_runner.py |  1 -
 ...i_step_neuronx_distributed_model_runner.py |  1 -
 vllm/worker/neuron_model_runner.py            |  2 --
 vllm/worker/pooling_model_runner.py           |  1 -
 vllm/worker/xpu_model_runner.py               |  1 -
 15 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 73d19aecde6c5..3dad021e31668 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -4,9 +4,12 @@
 from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
+import torch
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import TypeVar

+from vllm.jsontree import JSONTree, json_map_leaves
+from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_processor_from_config
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import resolve_mm_processor_kwargs
@@ -21,6 +24,8 @@
 _T = TypeVar("_T")
 _C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
 _P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)

+logger = init_logger(__name__)
+
 @dataclass(frozen=True)
 class InputContext:
@@ -134,7 +139,7 @@ class InputProcessingContext(InputContext):
         hf_processor: ProcessorMixin,
         data: Mapping[str, object],
         kwargs: Mapping[str, object] = {},
-    ) -> BatchFeature:
+    ) -> Union[BatchFeature, JSONTree]:
         """
         Call `hf_processor` on the prompt `data` (text, image, audio...)
         with configurable options `kwargs`.
@@ -154,8 +159,25 @@ class InputProcessingContext(InputContext):
             allow_var_kwargs=True,
         )

+        def maybe_cast_dtype(x):
+            # This mimics the behavior of transformers.BatchFeature
+            if isinstance(x, torch.Tensor) and x.is_floating_point():
+                return x.to(dtype=self.model_config.dtype)
+            return x
+
         try:
-            return hf_processor(**data, **merged_kwargs, return_tensors="pt")
+            output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
+            # this emulates output.to(dtype=self.model_config.dtype)
+            cast_output = json_map_leaves(maybe_cast_dtype, output)
+            if isinstance(output, BatchFeature):
+                return BatchFeature(cast_output)
+
+            logger.warning_once(
+                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
+                "Make sure to match the behaviour of `ProcessorMixin` when "
+                "implementing custom processors.")
+            return cast_output
+
         except Exception as exc:
             msg = (f"Failed to apply {type(hf_processor).__name__} "
                    f"on data={data} with kwargs={merged_kwargs}")
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 35d2a6e8c74ff..0bf5b1cf1c6c7 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -747,17 +747,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         batched_inputs: BatchedTensorInputs,
         *,
         device: torch.types.Device,
-        dtype: Optional[torch.dtype] = None,
     ) -> BatchedTensorInputs:
         json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

-        def maybe_cast_dtype(x: torch.Tensor):
-            # This mimics the behavior of transformers.BatchFeature
-            return x.to(dtype=dtype) if x.is_floating_point() else x
-
         json_mapped = json_map_leaves(
-            # NOTE: Cast the dtype before sending it to device
-            lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
+            lambda x: x.to(device=device, non_blocking=True),
             json_inputs,
         )

diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 8ccfefea1acbd..96646ec947186 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -297,7 +297,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
                 intermediate_tensors=intermediate_tensors,
                 **MultiModalKwargs.as_kwargs(
                     multi_modal_kwargs,
-                    dtype=self.model_runner.model_config.dtype,
                     device=self.device,
                 ),
                 **model_execute_kwargs,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6ea6bb020ed7f..9ac33a1499610 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -957,7 +957,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
             batched_mm_inputs = MultiModalKwargs.as_kwargs(
                 batched_mm_inputs,
-                dtype=self.model_config.dtype,
                 device=self.device,
             )

@@ -1951,7 +1950,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            [dummy_mm_kwargs] * max_num_mm_items)
         batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
             batched_dummy_mm_inputs,
-            dtype=self.model_config.dtype,
             device=self.device,
         )

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 73c445d14e38e..94e438fb44ec1 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -718,7 +718,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
             batched_mm_inputs = MultiModalKwargs.as_kwargs(
                 batched_mm_inputs,
-                dtype=self.model_config.dtype,
                 device=self.device,
             )

@@ -1560,7 +1559,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                                  batch_size)
         return MultiModalKwargs.as_kwargs(
             batched_dummy_mm_inputs,
-            dtype=self.model_config.dtype,
             device=self.device,
         )

diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py
index 677d66357a7fa..c99e2652a3972 100644
--- a/vllm/worker/cpu_enc_dec_model_runner.py
+++ b/vllm/worker/cpu_enc_dec_model_runner.py
@@ -300,7 +300,6 @@ class CPUEncoderDecoderModelRunner(
             model_input.encoder_input_positions,
             **MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs or {},
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
             "intermediate_tensors":
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 6213cf760ac55..68cdf65cafa79 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -630,7 +630,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
         if model_input.multi_modal_kwargs is not None:
             multimodal_kwargs = MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs,
-                dtype=self.model_config.dtype,
                 device=self.device,
             )
         execute_model_kwargs = {}
diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py
index 174f86f48b568..203fdf225a41a 100644
--- a/vllm/worker/cpu_pooling_model_runner.py
+++ b/vllm/worker/cpu_pooling_model_runner.py
@@ -53,7 +53,6 @@ class CPUPoolingModelRunner(
             model_input.input_positions,
             **MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs or {},
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
             **cross_enc_kwargs,
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index a3e7b0147961c..8d92edc5b386e 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -205,7 +205,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             intermediate_tensors=intermediate_tensors,
             **MultiModalKwargs.as_kwargs(
                 multi_modal_kwargs,
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
             **seqlen_agnostic_kwargs,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 75501e0f748ab..82db6617ba55f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1848,7 +1848,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             intermediate_tensors=intermediate_tensors,
             **MultiModalKwargs.as_kwargs(
                 multi_modal_kwargs,
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
             **seqlen_agnostic_kwargs,
diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py
index 336e41649df58..25f588077cb42 100644
--- a/vllm/worker/multi_step_neuron_model_runner.py
+++ b/vllm/worker/multi_step_neuron_model_runner.py
@@ -73,7 +73,6 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
             input_block_ids=model_input.input_block_ids,
             **MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs or {},
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
         )
diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py
index de9827723eecf..dd521dd67dad0 100644
--- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py
+++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py
@@ -52,7 +52,6 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
             sampling_params=sampling_params,
             **MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs or {},
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
         )
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 28855bb4698bc..7ccf1a2c0a876 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -395,7 +395,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
                 adapter_ids=model_input.adapter_ids,
                 **MultiModalKwargs.as_kwargs(
                     model_input.multi_modal_kwargs or {},
-                    dtype=self.model_config.dtype,
                     device=self.device,
                 ),
             )
@@ -408,7 +407,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
                 input_block_ids=model_input.input_block_ids,
                 **MultiModalKwargs.as_kwargs(
                     model_input.multi_modal_kwargs or {},
-                    dtype=self.model_config.dtype,
                     device=self.device,
                 ),
             )
diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py
index be6b3d1379fdc..f80955f71a5a3 100644
--- a/vllm/worker/pooling_model_runner.py
+++ b/vllm/worker/pooling_model_runner.py
@@ -122,7 +122,6 @@ class PoolingModelRunner(
             intermediate_tensors=intermediate_tensors,
             **MultiModalKwargs.as_kwargs(
                 multi_modal_kwargs,
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
             **cross_enc_kwargs,
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index ecbb63d912766..b2d3ce8526d51 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -565,7 +565,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
             intermediate_tensors=intermediate_tensors,
             **MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs or {},
-                dtype=self.model_config.dtype,
                 device=self.device,
             ),
         )
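
Note (not part of the patch): the sketch below illustrates the casting behaviour this change moves from MultiModalKwargs.as_kwargs into InputProcessingContext.call_hf_processor — floating-point tensor leaves of a nested processor output are cast to the model dtype, while integer tensors such as input_ids are left untouched, mimicking transformers.BatchFeature.to(dtype=...). The json_map_leaves helper here is a simplified stand-in for vllm.jsontree.json_map_leaves, whose exact signature may differ; the names cast_floats and maybe_cast are hypothetical.

    # Minimal, self-contained sketch; assumes only torch is installed.
    from typing import Any, Callable

    import torch


    def json_map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
        """Apply `fn` to every leaf of a nested dict/list/tuple structure.

        Simplified stand-in for vllm.jsontree.json_map_leaves.
        """
        if isinstance(tree, dict):
            return {k: json_map_leaves(fn, v) for k, v in tree.items()}
        if isinstance(tree, (list, tuple)):
            return type(tree)(json_map_leaves(fn, v) for v in tree)
        return fn(tree)


    def cast_floats(tree: Any, dtype: torch.dtype) -> Any:
        """Mimic transformers.BatchFeature.to(dtype=...): cast only
        floating-point tensors, leaving integer tensors unchanged."""

        def maybe_cast(x: Any) -> Any:
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.to(dtype=dtype)
            return x

        return json_map_leaves(maybe_cast, tree)


    if __name__ == "__main__":
        # A nested structure resembling a multimodal processor output.
        output = {
            "input_ids": torch.tensor([[1, 2, 3]]),
            "pixel_values": [torch.rand(3, 8, 8, dtype=torch.float32)],
        }
        cast = cast_floats(output, torch.float16)
        print(cast["input_ids"].dtype)        # torch.int64 (unchanged)
        print(cast["pixel_values"][0].dtype)  # torch.float16

Because the cast now happens once at processing time, the as_kwargs call sites in the worker model runners only move tensors to the target device, which is why every runner in this patch drops its dtype argument.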