[Model][Gemma3] Cast image pixel values already on CPU (#18732)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-05-27 06:42:54 +01:00 committed by GitHub
parent 1f1b1bc03b
commit b50602d5f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -263,6 +263,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
mm_data,
mm_kwargs,
)
if "pixel_values" in processed_outputs:
# Cast pixel values to model dtype already here,
# so we need to transfer less data to the GPU
processed_outputs["pixel_values"] = processed_outputs[
"pixel_values"].to(self.info.ctx.model_config.dtype)
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
@ -543,9 +548,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
) -> torch.Tensor:
target_dtype = vision_tower.get_input_embeddings().weight.dtype
image_features = vision_tower(pixel_values.to(dtype=target_dtype))
return image_features
return vision_tower(pixel_values)
def _process_image_input(
self,