[Model][Gemma3] Simplify image input validation (#18710)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-05-27 04:13:37 +01:00 committed by GitHub
parent 27bebcd897
commit 0eebd74842
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -504,18 +504,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return next(self.parameters()).dtype
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
    """Check that every image in ``data`` has the configured shape.

    The leading dimension of ``data`` is treated as the batch of images;
    all remaining dimensions must equal ``(3, image_size, image_size)``,
    where ``image_size`` is read from ``self.config.vision_config``.

    Args:
        data: Stacked pixel values whose first dimension indexes images.

    Returns:
        ``data`` itself, unchanged, once the shape check has passed.

    Raises:
        ValueError: If the per-image shape differs from the expected one.
    """
    image_size = self.config.vision_config.image_size
    expected_dims = (3, image_size, image_size)

    # A single comparison on the trailing dimensions covers every image at
    # once: ``data[i].shape`` is always ``data.shape[1:]``, so looping over
    # the batch to compare each item would only repeat this same test.
    supplied_dims = tuple(data.shape[1:])
    if supplied_dims != expected_dims:
        # Report the per-image shape so that the "expected" and "supplied"
        # values in the message describe the same thing.
        raise ValueError(
            "The expected shape of pixel values per image per batch is "
            f"{expected_dims}. You supplied {supplied_dims}.")

    return data
def _parse_and_validate_image_input(