Migrate AriaImagePixelInputs to TensorSchema for shape validation (#21620)

Signed-off-by: Benji Beck <benjibeck@meta.com>

parent e98def439c
commit 9d197280fa
vllm/model_executor/models/aria.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Annotated, Optional, Union

 import torch
 import torch.nn as nn
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 # yapf: disable
 from .idefics2_vision_model import Idefics2VisionConfig
@@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)


-class AriaImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    pixel_mask: Optional[torch.Tensor]
-    """
-    Shape:
-        pixel_values: `(batch_size * num_images, num_channels, height, width)`
-        pixel_mask: `(batch_size * num_images, height, width)`
-    """
+class AriaImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - c: Number of channels
+        - h: Height of each image
+        - w: Width of each image
+    """
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "h", "w"),
+    ]
+
+    pixel_mask: Annotated[
+        Optional[torch.Tensor],
+        TensorShape("bn", "h", "w"),
+    ]


 class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
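Note: a minimal sketch of how the new schema behaves, under the assumption (implied by the simplified parsing code later in this diff) that TensorSchema validates shapes when the class is instantiated and resolves shared symbolic dimensions such as "h" and "w" consistently across fields. The exact exception type raised on a mismatch is an assumption, not something this commit pins down.

import torch

from vllm.model_executor.models.aria import AriaImagePixelInputs

# 2 images, 3 channels, 490x490 pixels: matches TensorShape("bn", 3, "h", "w").
inputs = AriaImagePixelInputs(
    pixel_values=torch.zeros(2, 3, 490, 490),
    pixel_mask=torch.zeros(2, 490, 490, dtype=torch.bool),  # ("bn", "h", "w")
)

# A mask whose spatial size disagrees with pixel_values should be rejected,
# since "h" and "w" must resolve to the same values in both fields.
try:
    AriaImagePixelInputs(
        pixel_values=torch.zeros(2, 3, 490, 490),
        pixel_mask=torch.zeros(2, 448, 448, dtype=torch.bool),
    )
except Exception as exc:  # exception type is an assumption
    print(f"rejected: {exc}")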
@@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.vocab_size, logit_scale)

-    def _validate_image_sizes(
-            self, images: list[torch.Tensor]) -> list[torch.Tensor]:
-        if not all(img.shape == images[0].shape for img in images):
-            raise ValueError("All images must be the same size")
-        return images
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
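Note: the deleted _validate_image_sizes is not lost behavior. Matching every image against the same symbolic "h"/"w" dimensions fails exactly when image sizes differ, so the schema subsumes the manual check. The following self-contained toy (not vLLM's implementation) illustrates that equivalence.

import torch


def check_same_size(images: list[torch.Tensor]) -> None:
    """Toy analogue of matching each image against ("c", "h", "w")."""
    dims: dict[str, int] = {}
    for img in images:
        for name, size in zip(("c", "h", "w"), img.shape):
            if dims.setdefault(name, size) != size:
                raise ValueError(
                    f"dim {name!r}: expected {dims[name]}, got {size}")


check_same_size([torch.zeros(3, 490, 490)] * 2)  # ok: sizes agree
try:
    check_same_size([torch.zeros(3, 490, 490), torch.zeros(3, 448, 448)])
except ValueError as exc:
    print(f"rejected: {exc}")  # mirrors "All images must be the same size"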
@@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         if pixel_values is None:
             return None

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = self._validate_image_sizes(pixel_values)
-        pixel_values = flatten_bn(pixel_values, concat=True)
-
-        if pixel_mask is not None:
-            if not isinstance(pixel_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel mask. "
-                                 f"Got type: {type(pixel_mask)}")
-
-            pixel_mask = flatten_bn(pixel_mask, concat=True)
-
         return AriaImagePixelInputs(
-            pixel_values=pixel_values,
-            pixel_mask=pixel_mask,
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            pixel_mask=flatten_bn(pixel_mask, concat=True),
         )

     def _create_patch_attention_mask(
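Note: the simplified return path leans on flatten_bn(..., concat=True), imported from .utils at the top of the file. Below is a hedged sketch of the behavior this hunk relies on, inferred from its use here rather than quoted from the utils source; flatten_bn_sketch is a hypothetical stand-in name.

import torch


def flatten_bn_sketch(x, *, concat: bool = False):
    # An already-batched (b, n, ...) tensor is flattened to (b * n, ...).
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    # With concat=True, a per-prompt list of (n_i, ...) tensors is merged
    # into a single (sum of n_i, ...) tensor along dim 0.
    if concat:
        return torch.cat(list(x), dim=0)
    return x


prompt_a = torch.zeros(2, 3, 490, 490)  # prompt with 2 images
prompt_b = torch.zeros(1, 3, 490, 490)  # prompt with 1 image
merged = flatten_bn_sketch([prompt_a, prompt_b], concat=True)
print(merged.shape)  # torch.Size([3, 3, 490, 490])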