diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py
index 98ea366d3a6e4..225668d87facb 100644
--- a/vllm/model_executor/models/minicpmo.py
+++ b/vllm/model_executor/models/minicpmo.py
@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
                        MiniCPMVDummyInputsBuilder,
@@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
 
 CPU_DEVICE = torch.device("cpu")
 
 
-class MiniCPMOAudioFeatureInputs(TypedDict):
-    type: Literal["audio_features"]
-    audio_features: Union[torch.Tensor, list[torch.Tensor]]
+class MiniCPMOAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - bns: Batch size * number of audios * number of slices
+        - bn: Batch size * number of audios
+        - c: Number of channels
+        - l: Length
+        - s: Number of slices
+    """
+    type: Literal["audio_features"] = "audio_features"
+
+    audio_features: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bns", "c", "l", dynamic_dims={"l"}),
+    ]
     """
-    Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
-    which is the same as image.
-    Padding is used therefore `audio_features` is `torch.Tensor`.
+    which is the same as image. Padding is used therefore `audio_features` is
+    `torch.Tensor`.
     """
 
-    audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
+    audio_feature_lens: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "s"),
+    ]
     """
-    Shape: `(batch_size * num_audios, num_slices)`
-
     This should be feature length of each audio slice, which equals to
     `audio_features.shape[-1]`
     """
 
 
-class MiniCPMOAudioEmbeddingInputs(TypedDict):
-    type: Literal["audio_embeds"]
-    audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
+class MiniCPMOAudioEmbeddingInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_audios, num_slices, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: Number of slices
+        - h: Hidden size (must match language model backbone)
+
     Length of each slice may vary, so pass it as a list.
     """
+    type: Literal["audio_embeds"] = "audio_embeds"
+
+    audio_embeds: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "s", "h", dynamic_dims={"s"}),
+    ]
 
 
 MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
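
Not part of the patch: a minimal usage sketch of the new schema classes, assuming that `TensorSchema` subclasses accept their fields as keyword arguments and check them against the declared `TensorShape` symbols at construction time. All concrete sizes below are made up for illustration.

# Hedged sketch only: constructor behaviour and validation semantics are assumed,
# not taken from the diff above.
import torch

from vllm.model_executor.models.minicpmo import (
    MiniCPMOAudioEmbeddingInputs, MiniCPMOAudioFeatureInputs)

bns, c, l = 4, 128, 3000  # batch*audios*slices, mel channels, padded length
bn, s, h = 2, 2, 1024     # batch*audios, slices per audio, placeholder hidden size

features = MiniCPMOAudioFeatureInputs(
    audio_features=torch.zeros(bns, c, l),      # matches TensorShape("bns", "c", "l")
    audio_feature_lens=torch.full((bn, s), l),  # matches TensorShape("bn", "s")
)

embeds = MiniCPMOAudioEmbeddingInputs(
    audio_embeds=torch.zeros(bn, s, h),         # matches TensorShape("bn", "s", "h")
)

Presumably, `dynamic_dims` marks the dimensions that are allowed to differ across list elements when a `list[torch.Tensor]` is passed instead of one batched tensor, which is why "l" is flagged for `audio_features` and "s" for `audio_embeds`.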