Migrate Phi4 inputs to TensorSchema (#23471)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-08-31 23:05:59 -07:00 committed by GitHub
parent 499b074bfd
commit 437c3ce026
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 81 deletions

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import torch
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module):
return loaded_params
class Phi4MMImagePixelInputs(TypedDict):
class Phi4MMImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- p: Number of patches (1 + num_patches)
- c: Number of channels (3)
- h: Height of each image patch
- w: Width of each image patch
- nc: Number of crops
- H_mask: Height of attention mask
- W_mask: Width of attention mask
"""
type: Literal["pixel_values"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
), # may be different per batch and image
]
image_sizes: Annotated[
torch.Tensor,
TensorShape("bn", 2), # (height, width)
]
num_img_tokens: Annotated[
list[int],
TensorShape("bn"),
]
image_attention_mask: Annotated[
torch.Tensor,
TensorShape("bn", "nc", 32, 32), # H_mask, W_mask
]
class Phi4MMImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match language model backbone)
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
num_img_tokens: list[int]
"""Shape: `(batch_size * num_images)`"""
image_attention_mask: torch.Tensor
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
class Phi4MMImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "f", "h"),
]
class Phi4MMAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of audios
- f: Number of Mel filterbank bins (80)
- t: Time frames (M)
"""
class Phi4MMAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)"""
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
]
class Phi4MMAudioEmbeddingInputs(TypedDict):
class Phi4MMAudioEmbeddingInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of audios
- f: Audio feature size
- h: Hidden size (must match language model backbone)
"""
type: Literal["audio_embeds"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
data: Annotated[
NestedTensors,
TensorShape("b", "n", "f", "h"),
]
Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
@ -1170,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return None
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
return Phi4MMAudioFeatureInputs(type="audio_features",
data=flatten_bn(audio_features))
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)
@ -1259,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(image_sizes, torch.Tensor):
image_sizes = image_sizes.flatten(0, 1)
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect image_sizes inputs")
if isinstance(num_img_tokens, list):
num_img_tokens = [
@ -1269,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(num_img_tokens, torch.Tensor):
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect num_img_tokens inputs")
return Phi4MMImagePixelInputs(
type="pixel_values",

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import torch
@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module):
return img_set_tensor
class Phi4MMImagePixelInputs(TypedDict):
class Phi4MMImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- p: Number of patches (1 + num_patches)
- c: Number of channels (3)
- h: Height of each image patch
- w: Width of each image patch
- nc: Number of crops
- H_mask: Height of attention mask
- W_mask: Width of attention mask
"""
type: Literal["pixel_values"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
), # may be different per batch and image
]
image_sizes: Annotated[
torch.Tensor,
TensorShape("bn", 2), # (height, width)
]
num_img_tokens: Annotated[
list[int],
TensorShape("bn"),
]
image_attention_mask: Annotated[
torch.Tensor,
TensorShape("bn", "nc", 32, 32), # H_mask, W_mask
]
class Phi4MMAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of audios
- f: Number of Mel filterbank bins (80)
- t: Time frames (M)
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
num_img_tokens: list[int]
"""Shape: `(batch_size * num_images)`"""
image_attention_mask: torch.Tensor
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
class Phi4MMAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)"""
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
]
class Phi4MMAudioEmbeddingInputs(TypedDict):
class Phi4MMAudioEmbeddingInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of audios
- f: Audio feature size
- h: Hidden size (must match language model backbone)
"""
type: Literal["audio_embeds"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
data: Annotated[
NestedTensors,
TensorShape("b", "n", "f", "h"),
]
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
@ -985,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return None
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
return Phi4MMAudioFeatureInputs(type="audio_features",
data=flatten_bn(audio_features))
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)
@ -1074,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(image_sizes, torch.Tensor):
image_sizes = image_sizes.flatten(0, 1)
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect image_sizes inputs")
if isinstance(num_img_tokens, list):
num_img_tokens = [
@ -1084,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(num_img_tokens, torch.Tensor):
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect num_img_tokens inputs")
return Phi4MMImagePixelInputs(
type="pixel_values",