Migrate LlavaNextImageInputs to TensorSchema (#21774)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Benji Beck committed 2025-08-10 09:05:21 -07:00 (via GitHub)
commit b4e2916721, parent 65a7917be4
2 changed files with 35 additions and 64 deletions

vllm/model_executor/models/llava_next.py

@@ -3,7 +3,7 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union)
 
 import torch
@@ -11,7 +11,6 @@ import torch.nn as nn
 from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
-from typing_extensions import NotRequired
 
 from vllm.config import VllmConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -19,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.multimodal.parse import ImageSize
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -30,32 +30,36 @@ from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
                     flatten_bn, init_vllm_registered_model, maybe_prefix)
 
-class LlavaNextImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class LlavaNextImagePixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
+    Dimensions:
+        - bn: Batch size * number of images
+        - np: Number of patches + 1
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
 
     Note that `num_patches` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values"] = "pixel_values"
 
-    image_sizes: NotRequired[torch.Tensor]
-    """
-    Shape: `(batch_size * num_images, 2)`
-
-    This should be in `(height, width)` format.
-    """
+    pixel_values: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
 
-
-class LlavaNextImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
+    # This should be in `(height, width)` format.
+
+
+class LlavaNextImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 
 LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
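
For context, a minimal usage sketch of the migrated class (not part of the diff). It assumes TensorSchema validates its Annotated TensorShape fields when the instance is constructed, that symbolic dims such as "h"/"w" are pinned via resolve_bindings (as the parsing code further down does), and that the import path and the 336-pixel resolution are purely illustrative.

import torch

from vllm.model_executor.models.llava_next import LlavaNextImagePixelInputs

# Two images with different patch counts, so pixel_values is a list rather
# than a single batched tensor ("np" is declared as a dynamic dim).
pixel_values = [
    torch.zeros(5, 3, 336, 336),  # 1 base image + 4 patches
    torch.zeros(3, 3, 336, 336),  # 1 base image + 2 patches
]
image_sizes = torch.tensor([[672, 672], [672, 336]])  # (height, width) per image

inputs = LlavaNextImagePixelInputs(
    pixel_values=pixel_values,
    image_sizes=image_sizes,
    resolve_bindings={"h": 336, "w": 336},  # concrete values for the "h"/"w" dims
)
print(inputs["pixel_values"][0].shape)  # torch.Size([5, 3, 336, 336])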
@@ -269,44 +273,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        expected_dims = (2, )
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    f"The expected shape of image sizes per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
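
The two removed helpers existed only to reject ill-shaped image_sizes and pixel_values. With the schema in place, the same mistake should instead surface when the inputs object is built. A rough sketch, assuming TensorSchema raises ValueError on a shape mismatch (the exact exception type is an assumption):

import torch

from vllm.model_executor.models.llava_next import LlavaNextImagePixelInputs

bad_pixel_values = [torch.zeros(5, 4, 336, 336)]  # 4 channels, schema expects 3
try:
    LlavaNextImagePixelInputs(
        pixel_values=bad_pixel_values,
        image_sizes=torch.tensor([[672, 672]]),
        resolve_bindings={"h": 336, "w": 336},
    )
except ValueError as exc:
    print(f"rejected by the TensorShape check: {exc}")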
@@ -325,13 +291,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            expected_h = expected_w = self.config.vision_config.image_size
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values)),
-                image_sizes=self._validate_image_sizes(
-                    flatten_bn(image_sizes, concat=True)),
-            )
+                pixel_values=flatten_bn(pixel_values),
+                image_sizes=flatten_bn(image_sizes, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
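
The resolve_bindings mapping is how config-dependent sizes reach the schema: rather than baking self.config.vision_config.image_size into a validator, the caller supplies concrete values for the symbolic "h" and "w" dims at parse time. A generic sketch of the same pattern (ExampleImageInputs is a made-up class for illustration, not part of vLLM):

from typing import Annotated

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleImageInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - h/w: Height / width, bound to the model config at parse time
    """
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


# "h" and "w" stay symbolic in the class; the parser binds them per model.
inputs = ExampleImageInputs(
    data=torch.zeros(2, 3, 336, 336),
    resolve_bindings={"h": 336, "w": 336},
)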

vllm/utils/tensor_schema.py

@@ -60,6 +60,9 @@ class TensorSchema:
     def __getitem__(self, item) -> Any:
         return getattr(self, item)
 
+    def get(self, item, default=None) -> Any:
+        return getattr(self, item, default)
+
     def _match_shape_with_dynamic(self, actual: tuple[int, ...],
                                   reference: tuple[int, ...],
                                   expected_shape: tuple[Union[int, str], ...],
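
The new get() mirrors dict.get() on top of the existing __getitem__, so optional fields can be read without an AttributeError. A small sketch, assuming TensorSchema accepts its fields as keyword arguments just as the LLaVA-NeXT parser above does (ExampleInputs is illustrative, not a vLLM class):

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleInputs(TensorSchema):
    type: Literal["example"] = "example"
    data: Annotated[torch.Tensor, TensorShape("bn", 4)]


inputs = ExampleInputs(data=torch.zeros(2, 4))
print(inputs["data"].shape)          # __getitem__ -> getattr(self, "data")
print(inputs.get("image_sizes"))     # absent field -> None instead of raising
print(inputs.get("image_sizes", 0))  # or an explicit default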