From b50602d5f04677e75158c0d2e0e8b51793a5d545 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 27 May 2025 06:42:54 +0100 Subject: [PATCH] [Model][Gemma3] Cast image pixel values already on CPU (#18732) Signed-off-by: Lukas Geiger --- vllm/model_executor/models/gemma3_mm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index c4ae5b50c4514..00a972d33b049 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -263,6 +263,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): mm_data, mm_kwargs, ) + if "pixel_values" in processed_outputs: + # Cast pixel values to model dtype already here, + # so we need to transfer less data to the GPU + processed_outputs["pixel_values"] = processed_outputs[ + "pixel_values"].to(self.info.ctx.model_config.dtype) # HF processor pops the `num_crops` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: @@ -543,9 +548,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, vision_tower: SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: - target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_features = vision_tower(pixel_values.to(dtype=target_dtype)) - return image_features + return vision_tower(pixel_values) def _process_image_input( self,