Migrate Glm4vImageInputs, Glm4vVideoInputs to TensorSchema (#21678)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Benji Beck 2025-07-27 22:36:08 -07:00 committed by GitHub
parent d8937de4c8
commit 88e46c7c8d
2 changed files with 73 additions and 70 deletions
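The change replaces TypedDict-based input classes, whose expected shapes lived only in docstrings, with TensorSchema classes whose fields carry TensorShape annotations, so mismatched shapes are rejected when the input object is built. A minimal usage sketch of the migrated image schema (the concrete sizes are illustrative only, not taken from this PR):

import torch

from vllm.model_executor.models.glm4_1v import Glm4vImagePixelInputs

# Symbolic dims ("np", "cpp", "ni") may take any size; the static last dim
# of image_grid_thw must be exactly 3, otherwise construction raises ValueError.
inputs = Glm4vImagePixelInputs(
    pixel_values=torch.randn(64, 3 * 14 * 14),   # (np, cpp)
    image_grid_thw=torch.randint(1, 4, (2, 3)),  # (ni, 3)
)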

View File

@@ -5,6 +5,7 @@ import pytest
import torch
from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
@@ -145,4 +146,25 @@ def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
        FuyuImagePatchInputs(
            flat_data=flat_data,
            patches_per_image=patches_per_image,
        )
    )


def test_valid_tensor_schema_with_static_last_dim():
    image_embeds = torch.randn(256, 1024)
    image_grid_thw = torch.randint(0, 4, (2, 3))

    Glm4vImageEmbeddingInputs(
        image_embeds=image_embeds,
        image_grid_thw=image_grid_thw,
    )


def test_invalid_tensor_schema_with_static_last_dim():
    image_embeds = torch.randn(256, 1024)
    image_grid_thw = torch.randint(0, 4, (2, 4))  # Wrong last dim

    with pytest.raises(ValueError, match="dim\\[1\\] expected 3, got 4"):
        Glm4vImageEmbeddingInputs(
            image_embeds=image_embeds,
            image_grid_thw=image_grid_thw,
        )
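A similar check could be added for the video schemas introduced in the model file below; this is only a sketch that assumes Glm4vVideoEmbeddingInputs is imported from vllm.model_executor.models.glm4_1v like the image schema above and validated the same way, with illustrative sizes:

from vllm.model_executor.models.glm4_1v import Glm4vVideoEmbeddingInputs


def test_video_embedding_schema_with_static_last_dim():
    video_embeds = torch.randn(512, 1024)            # (p, h)
    video_grid_thw = torch.randint(1, 4, (1, 1, 3))  # (n, 1, 3)

    # Should construct without raising; a last dim other than 3 would fail.
    Glm4vVideoEmbeddingInputs(
        video_embeds=video_embeds,
        video_grid_thw=video_grid_thw,
    )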

View File

@@ -29,7 +29,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Any, Callable, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union
import numpy as np
import torch
@@ -70,6 +70,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from ..layers.activation import SiluAndMul
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -88,80 +89,68 @@ _MAX_FRAMES_PER_VIDEO = 600
# === Vision Inputs === #
class Glm4vImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
class Glm4vImagePixelInputs(TensorSchema):
    """
    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    Dimensions:
        - np: Number of patches
        - cpp: Number of channels * patch_size * patch_size
        - ni: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """
    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
class Glm4vImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    image_embeds: torch.Tensor
    """Supported types:
    - List[`torch.Tensor`]: A list of tensors holding all images' features.
        Each tensor holds an image's features.
    - `torch.Tensor`: A tensor holding all images' features
        (concatenation of all images' feature tensors).
    Tensor shape: `(num_image_features, hidden_size)`
    - `num_image_features` varies based on
        the number and resolution of the images.
    - `hidden_size` must match the hidden size of language model backbone.
class Glm4vImageEmbeddingInputs(TensorSchema):
    """
    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    Dimensions:
        - f: Number of image features (varies based on image resolution)
        - h: Hidden size (must match language model backbone)
        - n: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """
    type: Literal["image_embeds"] = "image_embeds"
    image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)]


Glm4vImageInputs = Union[Glm4vImagePixelInputs, Glm4vImageEmbeddingInputs]
class Glm4vVideoPixelInputs(TypedDict):
    type: Literal["pixel_values_videos"]
    pixel_values_videos: torch.Tensor
    """Shape:
    `(num_patches,
     num_channels * temporal_patch_size * patch_size * patch_size)`
class Glm4vVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - np: Number of patches
        - ctpp: Number of channels * temporal_patch_size *
          patch_size * patch_size
        - nv: Number of videos
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """
    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
    # video_metadata: Union[list[VideoMetadata], list[dict]]
    video_grid_thw: Union[list[torch.Tensor], torch.Tensor]
    """Shape: `(num_videos, num_frames, 3)` or `(1, num_frames, 3)`
    for single video.
    Each entry represents [grid_t, grid_h, grid_w] format where:
        - grid_t: Temporal grid size (usually 1 for processed video)
        - grid_h: Height grid size
        - grid_w: Width grid size
    This describes the grid structure of the video patches.
    """
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", "f", 3)]
class Glm4vVideoEmbeddingInputs(TypedDict):
    type: Literal["video_embeds"]
class Glm4vVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - p: Number of video patches across all frames
        - h: Hidden size (must match language model backbone)
        - n: Number of videos
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """
    type: Literal["video_embeds"] = "video_embeds"
    video_embeds: torch.Tensor
    """
    Tensor shape: `(num_video_patches, hidden_size)`
    - `num_video_patches`: Total number of video patches across all frames
    - `hidden_size`: Must match the hidden size of language model backbone
    """
    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 1, 3)` or `(1, 1, 3)` for single video
    Each entry represents [grid_t, grid_h, grid_w] format where:
        - grid_t: Temporal grid size (usually 1 for processed video)
        - grid_h: Height grid size
        - grid_w: Width grid size
    This describes the grid structure of the video patches.
    """
    video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("n", 1, 3)]


Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
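Since each schema keeps a `type` literal with a default value, call sites can still discriminate the Glm4vImageInputs and Glm4vVideoInputs unions as before. A sketch of that pattern, assuming TensorSchema preserves the dict-style field access the TypedDict versions had (attribute access would be the equivalent otherwise); the helper name is hypothetical:

def _describe_image_input(image_input: Glm4vImageInputs) -> str:
    # Hypothetical helper, for illustration only; not part of this PR.
    if image_input["type"] == "image_embeds":
        return f"embeds {tuple(image_input['image_embeds'].shape)}"  # (f, h)
    return f"pixels {tuple(image_input['pixel_values'].shape)}"      # (np, cpp)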
@@ -1324,10 +1313,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return Glm4vImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
@@ -1340,9 +1325,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Glm4vImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
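The removed isinstance checks are not lost behaviour: constructing the TensorSchema class now rejects a wrongly typed or wrongly shaped payload on its own. A rough sketch of the effect (the exact exception type and message come from TensorSchema's validation and are assumed here rather than quoted):

try:
    Glm4vImageEmbeddingInputs(
        image_embeds="not a tensor",                 # wrong type
        image_grid_thw=torch.randint(0, 4, (2, 3)),
    )
except (TypeError, ValueError) as exc:
    # Stands in for the old manual "Incorrect type of image embeddings" error.
    print(exc)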
@@ -1354,8 +1336,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
@@ -1375,9 +1359,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            if not isinstance(video_embeds, torch.Tensor):
                raise ValueError("Incorrect type of video embeddings. "
                                 f"Got type: {type(video_embeds)}")

            return Glm4vVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,