From 88e46c7c8dfa7c761259795773b758114721169a Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sun, 27 Jul 2025 22:36:08 -0700 Subject: [PATCH] Migrate Glm4vImageInputs, Glm4vVideoInputs to TensorSchema (#21678) Signed-off-by: Benji Beck Signed-off-by: DarkLight1337 --- tests/standalone_tests/test_tensor_schema.py | 24 +++- vllm/model_executor/models/glm4_1v.py | 119 ++++++++----------- 2 files changed, 73 insertions(+), 70 deletions(-) diff --git a/tests/standalone_tests/test_tensor_schema.py b/tests/standalone_tests/test_tensor_schema.py index b276b88fac1f..e98aa3f53fb5 100644 --- a/tests/standalone_tests/test_tensor_schema.py +++ b/tests/standalone_tests/test_tensor_schema.py @@ -5,6 +5,7 @@ import pytest import torch from vllm.model_executor.models.fuyu import FuyuImagePatchInputs +from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs @@ -145,4 +146,25 @@ def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length(): FuyuImagePatchInputs( flat_data=flat_data, patches_per_image=patches_per_image, - ) \ No newline at end of file + ) + + +def test_valid_tensor_schema_with_static_last_dim(): + image_embeds = torch.randn(256, 1024) + image_grid_thw = torch.randint(0, 4, (2, 3)) + + Glm4vImageEmbeddingInputs( + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + +def test_invalid_tensor_schema_with_static_last_dim(): + image_embeds = torch.randn(256, 1024) + image_grid_thw = torch.randint(0, 4, (2, 4)) # Wrong last dim + + with pytest.raises(ValueError, match="dim\\[1\\] expected 3, got 4"): + Glm4vImageEmbeddingInputs( + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 0996bcf60aa1..773b95c2d780 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -29,7 +29,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch @@ -70,6 +70,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.utils.tensor_schema import TensorSchema, TensorShape from ..layers.activation import SiluAndMul from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -88,80 +89,68 @@ _MAX_FRAMES_PER_VIDEO = 600 # === Vision Inputs === # -class Glm4vImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` +class Glm4vImagePixelInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. 
+ Dimensions: + - np: Number of patches + - cpp: Number of channels * patch_size * patch_size + - ni: Number of images + - g: Grid dimensions (3 for grid_t, grid_h, grid_w) """ + type: Literal["pixel_values"] = "pixel_values" + + pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] -class Glm4vImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - List[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). - - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. +class Glm4vImageEmbeddingInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - f: Number of image features (varies based on image resolution) + - h: Hidden size (must match language model backbone) + - n: Number of images + - g: Grid dimensions (3 for grid_t, grid_h, grid_w) """ + type: Literal["image_embeds"] = "image_embeds" + + image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)] Glm4vImageInputs = Union[Glm4vImagePixelInputs, Glm4vImageEmbeddingInputs] -class Glm4vVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` +class Glm4vVideoPixelInputs(TensorSchema): """ + Dimensions: + - np: Number of patches + - ctpp: Number of channels * temporal_patch_size * + patch_size * patch_size + - nv: Number of videos + - f: Number of frames + - g: Grid dimensions (3 for grid_t which is usually 1 for processed + video, grid_h, grid_w) + """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" + + pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")] # video_metadata: Union[list[VideoMetadata], list[dict]] - video_grid_thw: Union[list[torch.Tensor], torch.Tensor] - """Shape: `(num_videos, num_frames, 3)` or `(1, num_frames, 3)` - for single video. - Each entry represents [grid_t, grid_h, grid_w] format where: - - grid_t: Temporal grid size (usually 1 for processed video) - - grid_h: Height grid size - - grid_w: Width grid size - This describes the grid structure of the video patches. 
- """ + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", "f", 3)] -class Glm4vVideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] +class Glm4vVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - p: Number of video patches across all frames + - h: Hidden size (must match language model backbone) + - n: Number of videos + - g: Grid dimensions (3 for grid_t which is usually 1 for processed + video, grid_h, grid_w) + """ + type: Literal["video_embeds"] = "video_embeds" - video_embeds: torch.Tensor - """ - Tensor shape: `(num_video_patches, hidden_size)` - - `num_video_patches`: Total number of video patches across all frames - - `hidden_size`: Must match the hidden size of language model backbone - """ - - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 1, 3)` or `(1, 1, 3)` for single video - Each entry represents [grid_t, grid_h, grid_w] format where: - - grid_t: Temporal grid size (usually 1 for processed video) - - grid_h: Height grid size - - grid_w: Width grid size - This describes the grid structure of the video patches. - """ + video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("n", 1, 3)] Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs] @@ -1324,10 +1313,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - return Glm4vImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -1340,9 +1325,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return Glm4vImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -1354,8 +1336,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) + if pixel_values_videos is None and video_embeds is None: return None + if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( pixel_values_videos, "video pixel values") @@ -1375,9 +1359,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") return Glm4vVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds,