diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index 492d4bfb7d3e6..6d973a964de04 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch @@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal @@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module): return loaded_params -class Phi4MMImagePixelInputs(TypedDict): +class Phi4MMImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask + """ + type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] + + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] + + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] + + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] + + +class Phi4MMImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - f: Image feature size + - h: Hidden size (must match language model backbone) """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. - """ - - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" - - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - - -class Phi4MMImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", "h"), + ] + + +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - f: Number of Mel filterbank bins (80) + - t: Time frames (M) """ - -class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] -class Phi4MMAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ + type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] @@ -1170,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - return Phi4MMAudioFeatureInputs(type="audio_features", data=flatten_bn(audio_features)) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1259,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ @@ -1269,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index ac0efc2771752..352ae4064cc61 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch @@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal @@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module): return img_set_tensor -class Phi4MMImagePixelInputs(TypedDict): +class Phi4MMImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask + """ + type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] + + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] + + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] + + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] + + +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - t: Time frames (M) """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. - """ - - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" - - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - - -class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] -class Phi4MMAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] @@ -985,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - return Phi4MMAudioFeatureInputs(type="audio_features", data=flatten_bn(audio_features)) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1074,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ @@ -1084,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values",