mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 12:55:02 +08:00
Migrate KimiVLImagePixelInputs to TensorSchema (#21769)
Signed-off-by: Benji Beck <benjibeck@meta.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
d1bf1b9711
commit
05fae02175
@ -46,7 +46,7 @@ import copy
|
|||||||
import math
|
import math
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal, Optional, TypedDict, Union
|
from typing import Annotated, Any, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -79,6 +79,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
|
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
|
||||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .utils import is_pp_missing_parameter, maybe_prefix
|
from .utils import is_pp_missing_parameter, maybe_prefix
|
||||||
|
|
||||||
@ -118,15 +119,22 @@ class KimiVLMultiModalProjector(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class KimiVLImagePixelInputs(TypedDict):
|
class KimiVLImagePixelInputs(TensorSchema):
|
||||||
type: Literal["pixel_values"]
|
|
||||||
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
|
|
||||||
"""
|
"""
|
||||||
Shape:`(num_patches, num_channels, patch_size, patch_size)`
|
Dimensions:
|
||||||
|
- nc: Number of channels
|
||||||
|
- np: Number of patches
|
||||||
|
- ps: Patch size
|
||||||
|
- ni: Number of images
|
||||||
"""
|
"""
|
||||||
|
type: Literal["pixel_values"] = "pixel_values"
|
||||||
|
|
||||||
image_grid_hws: torch.Tensor
|
pixel_values: Annotated[
|
||||||
"""Shape:`(num_images, 2)`"""
|
Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
TensorShape("np", 3, "ps", "ps"),
|
||||||
|
]
|
||||||
|
|
||||||
|
image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
|
||||||
|
|
||||||
|
|
||||||
# TODO: support embeds too
|
# TODO: support embeds too
|
||||||
@ -348,8 +356,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
|
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
|
||||||
patch_size)
|
patch_size)
|
||||||
pixel_values = pixel_values.to(self.vision_tower.dtype)
|
pixel_values = pixel_values.to(self.vision_tower.dtype)
|
||||||
# image_grid_hws.shape = (N, 2)
|
|
||||||
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
|
|
||||||
|
|
||||||
return KimiVLImagePixelInputs(
|
return KimiVLImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user