Migrate Gemma3ImagePixelInputs to TensorSchema (#21676)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck <benjibeck@meta.com>
Date:   2025-07-27 22:36:05 -07:00
Commit: d8937de4c8 (parent e626d286f5)


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict
+from typing import Annotated, Any, Literal, Optional
 
 import torch
 from torch import nn
@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -42,18 +43,21 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
 logger = init_logger(__name__)
 
 
-class Gemma3ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """
-    Shape: `(num_patches_total, num_channels, height, width)`
-
-    `num_patches_total` is the total number of patches
-    over each image over each prompt in the batch.
-    """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+class Gemma3ImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - p: Number of patches total (over each image over each prompt in
+          the batch)
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+        - bn: Batch size * number of images
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+
+    pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
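
The mechanics behind this migration: TensorSchema fields annotated with TensorShape are checked against the actual tensor shapes, with integer entries (the 3 here) pinned to a fixed size and named dimensions ("p", "h", "w", "bn") either solved from the tensors or supplied via resolve_bindings. A heavily condensed, self-contained sketch of that idea follows; the Shape and validate_shapes names are hypothetical, not vLLM's implementation:

import torch
from typing import Annotated, get_args, get_origin, get_type_hints


class Shape:
    """Symbolic shape spec: ints pin a size, strings name a dimension."""
    def __init__(self, *dims):
        self.dims = dims


def validate_shapes(obj, resolve_bindings=None):
    """Check every Annotated[torch.Tensor, Shape(...)] field on obj,
    unifying named dims across fields and against resolve_bindings."""
    bindings = dict(resolve_bindings or {})
    hints = get_type_hints(type(obj), include_extras=True)
    for name, hint in hints.items():
        if get_origin(hint) is not Annotated:
            continue  # e.g. the Literal["pixel_values"] discriminator
        spec = next((m for m in get_args(hint)[1:] if isinstance(m, Shape)),
                    None)
        if spec is None:
            continue
        actual = tuple(getattr(obj, name).shape)
        if len(actual) != len(spec.dims):
            raise ValueError(f"{name}: expected {len(spec.dims)} dims, "
                             f"got shape {actual}")
        for dim, size in zip(spec.dims, actual):
            if isinstance(dim, int):
                if size != dim:
                    raise ValueError(
                        f"{name}: dim pinned to {dim}, got {size}")
            elif bindings.setdefault(dim, size) != size:
                raise ValueError(f"{name}: dim '{dim}' is {size}, "
                                 f"expected {bindings[dim]}")

Because named dimensions are unified across fields, a schema like this also catches inconsistencies between tensors, something the per-tensor check it replaces could not express.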
@@ -523,15 +527,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def dtype(self):
         return next(self.parameters()).dtype
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        image_size = self.config.vision_config.image_size
-        expected_dims = (3, image_size, image_size)
-
-        if data.shape[1:] != expected_dims:
-            raise ValueError(
-                "The expected shape of pixel values per image per batch is "
-                f"{expected_dims}. You supplied {tuple(data.shape)}.")
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
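
Deleting _validate_pixel_values is safe because the same constraint now lives in the pixel_values annotation: each patch must be (3, h, w), with h and w bound to vision_config.image_size at construction (see the next hunk). A hedged sketch of the failure mode, assuming TensorSchema validates when the inputs object is built, which is what this commit relies on; _PixelDemo is a hypothetical stand-in:

import torch
from typing import Annotated, Literal

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class _PixelDemo(TensorSchema):
    # Hypothetical stand-in mirroring Gemma3ImagePixelInputs above.
    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]


# A 4-channel input was previously caught by _validate_pixel_values;
# the schema is now expected to reject it at construction time.
try:
    _PixelDemo(pixel_values=torch.zeros(2, 4, 224, 224),
               resolve_bindings={"h": 224, "w": 224})
except Exception as exc:  # exact exception type is TensorSchema's choice
    print(f"rejected: {exc}")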
@@ -549,14 +544,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                 raise ValueError("Incorrect type of num_crops. "
                                  f"Got type: {type(num_crops)}")
 
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        num_crops = flatten_bn(num_crops, concat=True)
+        image_size = self.config.vision_config.image_size
 
         return Gemma3ImagePixelInputs(
-            type="pixel_values",
-            pixel_values=self._validate_pixel_values(pixel_values),
-            num_patches=num_crops + 1,
-        )
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            num_patches=flatten_bn(num_crops, concat=True) + 1,
+            resolve_bindings={
+                "h": image_size,
+                "w": image_size
+            })
 
     def _image_pixels_to_features(
         self,
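
For reference, the happy path of the new constructor: resolve_bindings pins the symbolic h and w to the model's configured image size, while p and bn are inferred from the tensors themselves. A minimal sketch with made-up sizes (896 happens to be a typical Gemma 3 image_size, but treat it as an example value):

import torch

image_size = 896  # illustrative; the real value comes from vision_config
inputs = Gemma3ImagePixelInputs(
    pixel_values=torch.zeros(2, 3, image_size, image_size),  # p = 2 patches
    num_patches=torch.tensor([1, 1]),                        # bn = 2 images
    resolve_bindings={"h": image_size, "w": image_size},
)

Since num_patches is num_crops + 1 per image, its sum should match the leading p dimension of pixel_values; the type field no longer needs to be passed because the schema gives it a default.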