Migrate GLMVImagePixelInputs to TensorSchema (#21679)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck 2025-07-27 22:36:11 -07:00 committed by GitHub
parent 88e46c7c8d
commit 304dcdf575


@@ -6,7 +6,7 @@
 """Inference-only CogAgent model compatible with THUDM weights."""
 from argparse import Namespace
 from collections.abc import Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union

 import torch
 from torch import nn
@@ -38,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import ChatGLMConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -45,10 +46,16 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
 from .utils import flatten_bn, merge_multimodal_embeddings


-class GLMVImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+class GLMVImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - c: Number of channels (3)
+        - h: Height of image
+        - w: Width of image
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]


 class EVA2CLIPPatchEmbedding(nn.Module):
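
Note (not part of the diff): a minimal sketch of how a TensorSchema subclass like the one added above can be exercised on its own, assuming shape checking runs when the object is constructed; the ExamplePixelInputs name is hypothetical and stands in for GLMVImagePixelInputs.

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExamplePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - h: Height of image
        - w: Width of image
    """
    type: Literal["pixel_values"] = "pixel_values"
    # The literal 3 pins the channel dimension; "b", "h", "w" stay symbolic.
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]


# A (2, 3, 336, 336) tensor matches the declared ("b", 3, "h", "w") layout,
# so construction should succeed (assuming TensorSchema validates at
# construction time, which is what lets this commit drop the manual check).
inputs = ExamplePixelInputs(data=torch.zeros(2, 3, 336, 336))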
@@ -562,19 +569,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         self.transformer: GLM4VModel

-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config["image_size"]
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[GLMVImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -584,11 +578,14 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            return GLMVImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
-            )
+            expected_h = expected_w = self.config.vision_config["image_size"]
+            return GLMVImagePixelInputs(type="pixel_values",
+                                        data=flatten_bn(pixel_values,
+                                                        concat=True),
+                                        resolve_bindings={
+                                            "h": expected_h,
+                                            "w": expected_w
+                                        })

         return None
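
Note (not part of the diff): resolve_bindings pins the symbolic "h"/"w" dimensions to the configured image size, so the check previously done in _validate_pixel_values now happens inside the schema constructor. A rough usage sketch, continuing the hypothetical ExamplePixelInputs from the note above; the 336 image size and the broad except clause are illustrative, not taken from this commit.

import torch

image_size = 336  # stand-in for self.config.vision_config["image_size"]

# Shapes that agree with the bindings construct cleanly.
good = ExamplePixelInputs(data=torch.zeros(2, 3, image_size, image_size),
                          resolve_bindings={"h": image_size, "w": image_size})

# A mismatched spatial size is expected to be rejected by the schema itself,
# replacing the hand-written _validate_pixel_values method removed above.
try:
    ExamplePixelInputs(data=torch.zeros(2, 3, 224, 224),
                       resolve_bindings={"h": image_size, "w": image_size})
except Exception as exc:  # exact exception type depends on TensorSchema
    print(f"rejected by TensorSchema: {exc}")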