Migrate ChameleonImagePixelInputs to TensorSchema (#21657)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-07-26 19:34:25 -07:00 committed by GitHub
parent 3339cba3ff
commit 20950b29fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import Any, Literal, Optional, TypedDict, Union from typing import Annotated, Any, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -38,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant) SupportsQuant)
@ -48,10 +49,16 @@ from .utils import (flatten_bn, is_pp_missing_parameter,
logger = init_logger(__name__) logger = init_logger(__name__)
class ChameleonImagePixelInputs(TypedDict): class ChameleonImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class ChameleonProcessingInfo(BaseProcessingInfo): class ChameleonProcessingInfo(BaseProcessingInfo):
@ -962,19 +969,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
vq_config: ChameleonVQVAEConfig = self.config.vq_config
expected_dims = (3, vq_config.resolution, vq_config.resolution)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]: self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
@ -982,16 +976,16 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None: if pixel_values is None:
return None return None
if not isinstance(pixel_values, (torch.Tensor, list)): vq_config: ChameleonVQVAEConfig = self.config.vq_config
raise ValueError("Incorrect type of pixel values. " expected_h = expected_w = vq_config.resolution
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True) return ChameleonImagePixelInputs(type="pixel_values",
data=flatten_bn(pixel_values,
return ChameleonImagePixelInputs( concat=True),
type="pixel_values", resolve_bindings={
data=self._validate_pixel_values(pixel_values), "h": expected_h,
) "w": expected_w
})
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model