From 696259ca0180c4357cf437a334aaf0966be5cb4b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 27 May 2025 23:45:48 +0800 Subject: [PATCH] [Core] Automatically cast multi-modal input dtype (#18756) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/deepseek_vl2.py | 4 +--- vllm/model_executor/models/gemma3_mm.py | 5 ----- vllm/multimodal/inputs.py | 8 +++++++- vllm/spec_decode/draft_model_runner.py | 7 +++++-- vllm/v1/worker/gpu_model_runner.py | 12 +++++++++--- vllm/v1/worker/tpu_model_runner.py | 14 ++++++++++---- vllm/worker/cpu_enc_dec_model_runner.py | 7 +++++-- vllm/worker/cpu_model_runner.py | 5 ++++- vllm/worker/cpu_pooling_model_runner.py | 7 +++++-- vllm/worker/enc_dec_model_runner.py | 10 +++++++--- vllm/worker/model_runner.py | 7 +++++-- vllm/worker/multi_step_neuron_model_runner.py | 7 +++++-- ...ulti_step_neuronx_distributed_model_runner.py | 7 +++++-- vllm/worker/neuron_model_runner.py | 16 ++++++++++------ vllm/worker/pooling_model_runner.py | 10 +++++++--- vllm/worker/xpu_model_runner.py | 9 ++++++--- 16 files changed, 91 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 164fa40ffebe5..5c8793f59ffbe 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -210,9 +210,7 @@ class DeepseekVL2MultiModalProcessor( dict(prompt=prompt, **mm_data), mm_kwargs, ) - target_dtype = self.info.ctx.model_config.dtype - pixel_values = processed_outputs.pop("pixel_values").to( - target_dtype) + pixel_values = processed_outputs["pixel_values"] # split pixel values into patches corresponding to each image images_spatial_crop = processed_outputs["images_spatial_crop"] patches_per_image = [ diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 00a972d33b049..182cc86d3ca8f 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -263,11 +263,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): mm_data, mm_kwargs, ) - if "pixel_values" in processed_outputs: - # Cast pixel values to model dtype already here, - # so we need to transfer less data to the GPU - processed_outputs["pixel_values"] = processed_outputs[ - "pixel_values"].to(self.info.ctx.model_config.dtype) # HF processor pops the `num_crops` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 162dd52e3e73c..600a34d39ef68 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -746,11 +746,17 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): batched_inputs: BatchedTensorInputs, *, device: torch.types.Device, + dtype: Optional[torch.dtype] = None, ) -> BatchedTensorInputs: json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) + def maybe_cast_dtype(x: torch.Tensor): + # This mimics the behavior of transformers.BatchFeature + return x.to(dtype=dtype) if x.is_floating_point() else x + json_mapped = json_map_leaves( - lambda x: x.to(device, non_blocking=True), + # NOTE: Cast the dtype before sending it to device + lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True), json_inputs, ) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index a6276c5633945..991d2040a878a 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -294,8 +294,11 @@ 
class TP1DraftModelRunner(ModelRunnerWrapperBase): inputs_embeds=None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), + **MultiModalKwargs.as_kwargs( + multi_modal_kwargs, + dtype=self.model_runner.model_config.dtype, + device=self.device, + ), **model_execute_kwargs, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aa47ac253bb93..910c0e80bb31c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -929,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): encoder_outputs = [] for grouped_mm_inputs in grouped_mm_inputs_list: batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) - batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, - device=self.device) + batched_mm_inputs = MultiModalKwargs.as_kwargs( + batched_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) # Run the encoder. # `curr_group_outputs` is either of the following: @@ -1874,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): batched_dummy_mm_inputs = MultiModalKwargs.batch( [dummy_mm_kwargs] * max_num_mm_items) batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( - batched_dummy_mm_inputs, device=self.device) + batched_dummy_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) # Run multimodal encoder. dummy_encoder_outputs = self.model.get_multimodal_embeddings( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b13ff9f97e6fa..46bcf64ed0c39 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): encoder_outputs = [] for grouped_mm_inputs in grouped_mm_inputs_list: batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) - batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, - device=self.device) + batched_mm_inputs = MultiModalKwargs.as_kwargs( + batched_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) # Run the encoder. 
# `curr_group_outputs` is either of the following: @@ -1435,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * batch_size) - return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, - device=self.device) + return MultiModalKwargs.as_kwargs( + batched_dummy_mm_inputs, + dtype=self.model_config.dtype, + device=self.device, + ) def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index c2120c035175a..82eeeb570d222 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -297,8 +297,11 @@ class CPUEncoderDecoderModelRunner( model_input.encoder_input_tokens, "encoder_positions": model_input.encoder_input_positions, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), "intermediate_tensors": intermediate_tensors, } diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 710ca1a13b0c5..fb436a079f878 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -628,7 +628,10 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): multimodal_kwargs = {} if model_input.multi_modal_kwargs is not None: multimodal_kwargs = MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs, device=self.device) + model_input.multi_modal_kwargs, + dtype=self.model_config.dtype, + device=self.device, + ) execute_model_kwargs = {} if previous_hidden_states is not None: execute_model_kwargs.update( diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index 1ceb2557c6b3d..2a60e51261ad6 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -50,8 +50,11 @@ class CPUPoolingModelRunner( model_input.input_tokens, "positions": model_input.input_positions, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), **cross_enc_kwargs, "intermediate_tensors": intermediate_tensors, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4864163b0de2a..3957e5608524f 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -202,9 +202,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): encoder_input_ids=model_input.encoder_input_tokens, encoder_positions=model_input.encoder_input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + **MultiModalKwargs.as_kwargs( + multi_modal_kwargs, + dtype=self.model_config.dtype, + device=self.device, + ), + **seqlen_agnostic_kwargs, + ) logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 53e79adf9aaec..8c968faa78101 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1845,8 +1845,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): inputs_embeds=model_input.inputs_embeds, 
positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), + **MultiModalKwargs.as_kwargs( + multi_modal_kwargs, + dtype=self.model_config.dtype, + device=self.device, + ), **seqlen_agnostic_kwargs, **model_kwargs, ) diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py index 9618a4b49ff89..aafb7ab7cfb8d 100644 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -70,8 +70,11 @@ class MultiStepNeuronModelRunner(NeuronModelRunner): input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) output = self.model.sample( diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index b6a3492a493bb..3a9c0993e004f 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -49,8 +49,11 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner): positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) output = self.model.sample( diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index e97adf757cc12..968596471a26e 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -378,9 +378,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs - or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) elif current_platform.use_transformers_neuronx(): # [TODO] validate on-device sampling @@ -389,9 +391,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs - or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), ) # Compute the logits only if the on-device sampling is turned off as diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index fdb7353f2f9ce..912e04c435f54 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -119,10 +119,14 @@ class PoolingModelRunner( input_ids=model_input.input_tokens, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), + **MultiModalKwargs.as_kwargs( + multi_modal_kwargs, + dtype=self.model_config.dtype, + 
device=self.device, + ), **cross_enc_kwargs, - **seqlen_agnostic_kwargs) + **seqlen_agnostic_kwargs, + ) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 7042b575aa787..79fa7d2c73e88 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -562,9 +562,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): input_ids=model_input.input_tokens, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs - or {}, - device=self.device)) + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + dtype=self.model_config.dtype, + device=self.device, + ), + ) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states
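
Note (not part of the diff above): the core of this change is the new optional `dtype` argument on `MultiModalKwargs.as_kwargs`, whose `maybe_cast_dtype` helper casts only floating-point tensors to the model dtype before the host-to-device copy, mirroring `transformers.BatchFeature` and making the per-model casts in `deepseek_vl2.py` and `gemma3_mm.py` unnecessary. Below is a minimal standalone sketch of that casting rule; the function name `cast_and_move`, the nested-container traversal, and the sample inputs are illustrative only and do not appear in the patch.

from typing import Any, Optional

import torch


def cast_and_move(batched_inputs: Any,
                  *,
                  device: torch.device,
                  dtype: Optional[torch.dtype] = None) -> Any:
    """Cast floating-point tensors to `dtype`, then move everything to `device`."""

    def _apply(x: Any) -> Any:
        # Walk nested dicts/lists/tuples of tensors, similar in spirit to the
        # json_map_leaves traversal over a JSONTree in the real implementation.
        if isinstance(x, dict):
            return {k: _apply(v) for k, v in x.items()}
        if isinstance(x, (list, tuple)):
            return type(x)(_apply(v) for v in x)
        # Mirror transformers.BatchFeature: only floating-point tensors are
        # cast; integer/bool tensors (token ids, grid sizes, ...) are kept.
        if dtype is not None and x.is_floating_point():
            x = x.to(dtype=dtype)
        # Cast before the device copy so less data is transferred when the
        # model dtype is narrower than the processor's float32 output.
        return x.to(device=device, non_blocking=True)

    return _apply(batched_inputs)


if __name__ == "__main__":
    mm_inputs = {
        "pixel_values": torch.rand(1, 3, 224, 224, dtype=torch.float32),
        "image_grid_thw": torch.tensor([[1, 16, 16]], dtype=torch.int64),
    }
    moved = cast_and_move(mm_inputs,
                          device=torch.device("cpu"),
                          dtype=torch.bfloat16)
    print(moved["pixel_values"].dtype)    # torch.bfloat16 (floating point: cast)
    print(moved["image_grid_thw"].dtype)  # torch.int64 (integer: unchanged)

Running the sketch prints torch.bfloat16 for the floating-point pixel values and torch.int64 for the integer grid tensor, which is the behavior every runner in this patch now gets by passing `dtype=self.model_config.dtype` (or `self.model_runner.model_config.dtype` in the draft-model runner) to `MultiModalKwargs.as_kwargs`.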