Migrate LlavaImageInputs to TensorSchema (#21770)
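
This change moves the LLaVA image-input TypedDicts (LlavaImagePixelInputs, PixtralHFImagePixelInputs, LlavaImageEmbeddingInputs) onto the TensorSchema base class: expected tensor shapes are now declared inline through Annotated[..., TensorShape(...)] fields, and the hand-rolled _validate_pixel_values helper is dropped in favour of binding the symbolic "h"/"w" dimensions at construction time via resolve_bindings.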

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck <benjibeck@meta.com> (committed via GitHub)
Date: 2025-08-10 19:29:19 -07:00
Parent: a554991748
Commit: 06da44f0cb


@@ -3,7 +3,7 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union, cast)
 
 import torch
@@ -33,6 +33,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,35 +45,46 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
 from .vision import get_vision_encoder_info
 
 
-class LlavaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class LlavaImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
 
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class PixtralHFImagePixelInputs(TypedDict):
-    type: Literal["pixel_values_pixtral"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class PixtralHFImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
 
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
+    pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                            TensorShape("bn", "c", "h", "w")]
 
 
-class LlavaImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+class LlavaImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 
 LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
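
Below is a minimal sketch of how the new schema classes are exercised, mirroring the call site later in this diff. It assumes TensorSchema runs the TensorShape checks when the object is constructed, that the class is importable from vllm.model_executor.models.llava, and that 336 merely stands in for vision_config.image_size:

    # Hypothetical usage sketch; shape validation is assumed to run in
    # TensorSchema.__init__ via the Annotated TensorShape metadata.
    import torch

    from vllm.model_executor.models.llava import LlavaImagePixelInputs

    inputs = LlavaImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.zeros(2, 3, 336, 336),  # (bn, 3, h, w)
        resolve_bindings={"h": 336, "w": 336},     # pin the symbolic dims
    )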
@@ -547,19 +559,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
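
For parity with the deleted helper, a mis-sized tensor should now be rejected by the schema itself rather than by a bespoke check. Continuing the sketch above (that a plain ValueError is raised is an assumption):

    # With h/w bound to 336, a 224x224 input should fail the schema's
    # shape check, replacing the ValueError _validate_pixel_values raised.
    try:
        LlavaImagePixelInputs(
            type="pixel_values",
            pixel_values=torch.zeros(2, 3, 224, 224),
            resolve_bindings={"h": 336, "w": 336},
        )
    except ValueError as exc:
        print(f"rejected by TensorSchema: {exc}")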
@@ -579,10 +578,14 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                     pixel_values=flatten_bn(pixel_values),
                 )
 
+            expected_h = expected_w = self.config.vision_config.image_size
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                pixel_values=flatten_bn(pixel_values, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w
+                },
             )
 
         if image_embeds is not None:
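
Binding "h" and "w" at the call site keeps the schema definition reusable across checkpoints with different image sizes: the expected values still come from config.vision_config.image_size, exactly as the deleted helper computed them, but the comparison itself now lives in the shared TensorSchema machinery.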