Migrate skyworkr1v inputs to TensorSchema (#23499)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-08-24 21:43:21 -07:00 committed by GitHub
parent 99f8094400
commit a5203d04df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,7 +8,7 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@ -48,27 +49,42 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225) IMAGENET_STD = (0.229, 0.224, 0.225)
class SkyworkR1VImagePixelInputs(TypedDict): class SkyworkR1VImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values_flat: torch.Tensor
""" """
Shape: Dimensions:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)` - bnp: Batch size * number of images * (1 + num_patches)
- c: Number of channels (3)
- h: Height
- w: Width
- bn: Batch size * number of images
""" """
type: Literal["pixel_values"] = "pixel_values"
num_patches: torch.Tensor pixel_values_flat: Annotated[
"""Shape: `(batch_size * num_images)`""" torch.Tensor,
TensorShape("bnp", 3, "h", "w"),
]
num_patches: Annotated[
torch.Tensor,
TensorShape("bn"),
]
class SkyworkR1VImageEmbeddingInputs(TypedDict): class SkyworkR1VImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
""" """
Dimensions:
- ni: Number of images
- ifs: Image feature size
- hs: Hidden size (must match the hidden size of language model
backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("ni", "ifs", "hs"),
]
SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
@ -731,26 +747,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
vit_embeds = self.mlp1(vit_embeds) vit_embeds = self.mlp1(vit_embeds)
return vit_embeds return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None) pixel_values_flat = kwargs.pop("pixel_values_flat", None)
@ -788,10 +784,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return SkyworkR1VImagePixelInputs( return SkyworkR1VImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values_flat=self._validate_pixel_values( pixel_values_flat=pixel_values_flat,
pixel_values_flat),
num_patches=image_num_patches, num_patches=image_num_patches,
) resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size,
})
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")