Migrate Mistral3ImagePixelInputs to TensorSchema (#21945)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Benji Beck 2025-08-20 10:37:29 -07:00 committed by GitHub
parent dfd2382039
commit c4477f55e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union)
import torch
@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
@ -42,15 +43,23 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .vision import get_vision_encoder_info
class Mistral3ImagePixelInputs(TypedDict):
type: Literal["pixel_values_pixtral"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
class Mistral3ImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
# Note that `height` or `width` may be different per batch and image,
# in which case the data is passed as a list instead of a batched tensor.
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
]
class Mistral3PatchMerger(nn.Module):
@ -456,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)