Migrate DeepseekVL2ImageInputs to TensorSchema (#21658)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck 2025-07-26 19:34:11 -07:00 committed by GitHub
parent ccf27cc4d4
commit 0b8caf9095


@@ -5,7 +5,7 @@
 """Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 import torch
 import torch.nn as nn
@@ -36,6 +36,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import (
     DeepseekVLV2Processor)
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@@ -46,25 +47,30 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
 _IMAGE_TOKEN = "<image>"
-class DeepseekVL2ImagePixelInputs(TypedDict):
+class DeepseekVL2ImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each image
+        - w: Width of each image
+    """
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-    """
-    images_spatial_crop: torch.Tensor
-    """
-    Shape: `(batch_size * num_images, 2)`
-    """
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("bn", 3, "h", "w")]
+    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
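
For orientation, a minimal usage sketch of the new schema, assuming (as this migration implies) that TensorSchema checks every Annotated field against its TensorShape when the instance is built and that the symbolic "bn" dimension must agree across fields. The tensor sizes below are made up for illustration:

import torch

# Hypothetical smoke test: two images of 384x384 pixels.
pixel_inputs = DeepseekVL2ImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 384, 384),  # matches TensorShape("bn", 3, "h", "w")
    images_spatial_crop=torch.ones(2, 2, dtype=torch.long),  # matches TensorShape("bn", 2)
)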
-class DeepseekVL2VImageEmbeddingInputs(TypedDict):
+class DeepseekVL2VImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match language model backbone)
+    """
     type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("bn", "f", "h")]
 DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs,
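
The embedding variant follows the same pattern. Note that TensorShape can only check that "h" is self-consistent across fields; the docstring's requirement that it equal the language model's hidden size remains a convention the caller must honor. A sketch with made-up sizes:

import torch

# Hypothetical: two images, 576 image-feature tokens each, hidden size 1280.
embed_inputs = DeepseekVL2VImageEmbeddingInputs(
    type="image_embeds",
    data=torch.zeros(2, 576, 1280),  # matches TensorShape("bn", "f", "h")
)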
@@ -439,46 +445,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         model = model.to(dtype=torch.get_default_dtype())
         return model
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        h = w = self.vision_config.image_size
-        expected_dims = (3, h, w)
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-        for d in data:
-            _validate_shape(d)
-        return data
-    def _validate_images_spatial_crop(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        expected_dims = 2
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = d.size(-1)
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    f"The expected shape of image sizes per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-        for d in data:
-            _validate_shape(d)
-        return data
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
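
The two hand-rolled validators removed above become redundant because the same constraints now live on the schema classes. A negative-test sketch, assuming TensorSchema rejects mismatching shapes at construction time (the exact exception type is not pinned down by this diff):

import torch

try:
    DeepseekVL2ImagePixelInputs(
        type="pixel_values",
        data=torch.zeros(2, 4, 384, 384),  # 4 channels, violating the fixed 3
        images_spatial_crop=torch.ones(2, 2, dtype=torch.long),
    )
except Exception as exc:  # assumed: schema validation raises here
    print(f"rejected by schema: {exc}")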
@@ -489,25 +455,18 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             return None
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(images_spatial_crop)}")
-            return DeepseekVL2ImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_pixel_values(flatten_bn(pixel_values)),
-                images_spatial_crop=self._validate_images_spatial_crop(
-                    flatten_bn(images_spatial_crop, concat=True)))
+            expected_h = expected_w = self.vision_config.image_size
+            return DeepseekVL2ImagePixelInputs(type="pixel_values",
+                                               data=flatten_bn(pixel_values),
+                                               images_spatial_crop=flatten_bn(
+                                                   images_spatial_crop,
+                                                   concat=True),
+                                               resolve_bindings={
+                                                   "h": expected_h,
+                                                   "w": expected_w,
+                                               })
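
Here resolve_bindings pins the symbolic "h" and "w" dimensions to self.vision_config.image_size, which recovers the exact check the deleted _validate_pixel_values used to perform. A toy sketch under that assumption, with 384 standing in for the configured image size:

import torch

# Hypothetical: with the bindings below, a 224x224 tensor should be rejected
# even though it would satisfy the unbound ("bn", 3, "h", "w") pattern.
DeepseekVL2ImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 384, 384),
    images_spatial_crop=torch.ones(2, 2, dtype=torch.long),
    resolve_bindings={"h": 384, "w": 384},
)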
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return DeepseekVL2VImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds),
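
One note on the "bn" axis: flatten_bn collapses the per-prompt batching before the tensors reach the schema. My reading of that helper (it is not part of this diff) is sketched below with made-up sizes:

import torch

# Two prompts carrying 2 and 1 images respectively. flatten_bn(per_prompt,
# concat=True) is understood to concatenate these along dim 0, yielding a
# (3, 3, 384, 384) tensor whose leading axis is the "bn" dimension the
# schemas above refer to.
per_prompt = [torch.zeros(2, 3, 384, 384), torch.zeros(1, 3, 384, 384)]
expected_bn = sum(t.shape[0] for t in per_prompt)  # 3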