diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index 2950ca664a98f..90200f319464b 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -11,7 +11,7 @@ import math
 import unicodedata
 from collections.abc import Collection, Mapping, Sequence, Set
 from functools import lru_cache, partial
-from typing import Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Callable, Literal, Optional, Union
 
 import regex as re
 import torch
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -47,26 +48,34 @@ from .qwen import QWenBaseModel, QWenModel
 from .utils import flatten_bn, merge_multimodal_embeddings
 
 
-class QwenImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
+class QwenImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, 3, image_size, image_size)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
     Note that image_size is the value in the vision config to which we resize
     the image to in the normalization transform. Currently multi-image support
     can only be leveraged by passing image embeddings directly.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class QwenImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, 256, hidden_size)`
-
+class QwenImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size (256)
+        - hs: Hidden size
+
     `hidden_size` must match the hidden size of the language model backbone
     and is stored in the visual config of the model if we have one.
     """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")]
 
 
 QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
@@ -697,19 +706,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
 
         self.transformer: QwenVLModel
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.visual["image_size"]
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[QwenImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -720,10 +716,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            expected_h = expected_w = self.config.visual["image_size"]
+            resolve_bindings = {"h": expected_h, "w": expected_w}
+
             return QwenImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                data=flatten_bn(pixel_values, concat=True),
+                resolve_bindings=resolve_bindings,
             )
 
         if image_embeds is not None: