diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 7584b5188cf2a..537aeabf72d5a 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -6,7 +6,7 @@ """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace from collections.abc import Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -38,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .chatglm import ChatGLMBaseModel, ChatGLMModel from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -45,10 +46,16 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .utils import flatten_bn, merge_multimodal_embeddings -class GLMVImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size, num_channels, height, width)`""" +class GLMVImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - c: Number of channels (3) + - h: Height of image + - w: Width of image + """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] class EVA2CLIPPatchEmbedding(nn.Module): @@ -562,19 +569,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, self.transformer: GLM4VModel - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config["image_size"] - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[GLMVImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -584,11 +578,14 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - return GLMVImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), - ) + expected_h = expected_w = self.config.vision_config["image_size"] + return GLMVImagePixelInputs(type="pixel_values", + data=flatten_bn(pixel_values, + concat=True), + resolve_bindings={ + "h": expected_h, + "w": expected_w + }) return None