diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 9c0a6ba92389..1c7ddd7df7f8 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -46,7 +46,7 @@ import copy import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch from torch import nn @@ -79,6 +79,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .utils import is_pp_missing_parameter, maybe_prefix @@ -118,15 +119,22 @@ class KimiVLMultiModalProjector(nn.Module): return hidden_states -class KimiVLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class KimiVLImagePixelInputs(TensorSchema): """ - Shape:`(num_patches, num_channels, patch_size, patch_size)` + Dimensions: + - nc: Number of channels + - np: Number of patches + - ps: Patch size + - ni: Number of images """ + type: Literal["pixel_values"] = "pixel_values" - image_grid_hws: torch.Tensor - """Shape:`(num_images, 2)`""" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("np", 3, "ps", "ps"), + ] + + image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)] # TODO: support embeds too @@ -348,8 +356,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): pixel_values = pixel_values.reshape(-1, num_channels, patch_size, patch_size) pixel_values = pixel_values.to(self.vision_tower.dtype) - # image_grid_hws.shape = (N, 2) - assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}" return KimiVLImagePixelInputs( type="pixel_values",