mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 18:17:55 +08:00
[Bugfix] Fix glm4.1v video_grid_thw tensor shape scheme (#21744)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
65e8466c37
commit
0ae970ed15
@@ -126,7 +126,6 @@ class Glm4vVideoPixelInputs(TensorSchema):
|
|||||||
- np: Number of patches
|
- np: Number of patches
|
||||||
- ctpp: Number of channels * temporal_patch_size *
|
- ctpp: Number of channels * temporal_patch_size *
|
||||||
patch_size * patch_size
|
patch_size * patch_size
|
||||||
- nv: Number of videos
|
|
||||||
- f: Number of frames
|
- f: Number of frames
|
||||||
- g: Grid dimensions (3 for grid_t which is usually 1 for processed
|
- g: Grid dimensions (3 for grid_t which is usually 1 for processed
|
||||||
video, grid_h, grid_w)
|
video, grid_h, grid_w)
|
||||||
@@ -134,8 +133,7 @@ class Glm4vVideoPixelInputs(TensorSchema):
|
|||||||
type: Literal["pixel_values_videos"] = "pixel_values_videos"
|
type: Literal["pixel_values_videos"] = "pixel_values_videos"
|
||||||
|
|
||||||
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
|
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
|
||||||
# video_metadata: Union[list[VideoMetadata], list[dict]]
|
video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
|
||||||
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", "f", 3)]
|
|
||||||
|
|
||||||
|
|
||||||
class Glm4vVideoEmbeddingInputs(TensorSchema):
|
class Glm4vVideoEmbeddingInputs(TensorSchema):
|
||||||
@@ -143,14 +141,14 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
|
|||||||
Dimensions:
|
Dimensions:
|
||||||
- p: Number of video patches across all frames
|
- p: Number of video patches across all frames
|
||||||
- h: Hidden size (must match language model backbone)
|
- h: Hidden size (must match language model backbone)
|
||||||
- n: Number of videos
|
- f: Number of frames
|
||||||
- g: Grid dimensions (3 for grid_t which is usually 1 for processed
|
- g: Grid dimensions (3 for grid_t which is usually 1 for processed
|
||||||
video, grid_h, grid_w)
|
video, grid_h, grid_w)
|
||||||
"""
|
"""
|
||||||
type: Literal["video_embeds"] = "video_embeds"
|
type: Literal["video_embeds"] = "video_embeds"
|
||||||
|
|
||||||
video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
|
video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
|
||||||
video_grid_thw: Annotated[torch.Tensor, TensorShape("n", 1, 3)]
|
video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
|
||||||
|
|
||||||
|
|
||||||
Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
|
Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
|
||||||
@@ -1348,7 +1346,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
return Glm4vVideoPixelInputs(
|
return Glm4vVideoPixelInputs(
|
||||||
type="pixel_values_videos",
|
type="pixel_values_videos",
|
||||||
# video_metadata=video_metadata,
|
|
||||||
pixel_values_videos=pixel_values_videos,
|
pixel_values_videos=pixel_values_videos,
|
||||||
video_grid_thw=video_grid_thw,
|
video_grid_thw=video_grid_thw,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user