Migrate KimiVLImagePixelInputs to TensorSchema (#21769)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Benji Beck 2025-08-05 02:36:18 -07:00 committed by GitHub
parent d1bf1b9711
commit 05fae02175
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,7 +46,7 @@ import copy
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import torch
from torch import nn
@ -79,6 +79,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import is_pp_missing_parameter, maybe_prefix
@ -118,15 +119,22 @@ class KimiVLMultiModalProjector(nn.Module):
return hidden_states
class KimiVLImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs, declared as a `TensorSchema`.

    Dimensions:
        - np: Number of patches
        - ps: Patch size
        - ni: Number of images

    The channel axis of `pixel_values` is pinned to the literal `3` in its
    shape annotation; it is not a named (free) dimension.
    """
    # Literal tag for this input kind; defaults to its only permitted value.
    type: Literal["pixel_values"] = "pixel_values"

    # Shape: `(np, 3, ps, ps)`. May be a single tensor or a list of tensors.
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("np", 3, "ps", "ps"),
    ]

    # Shape: `(ni, 2)`.
    image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
# TODO: support embeds too
@ -348,8 +356,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
patch_size)
pixel_values = pixel_values.to(self.vision_tower.dtype)
# image_grid_hws.shape = (N, 2)
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
return KimiVLImagePixelInputs(
type="pixel_values",