Migrate MiniCPMOAudioInputs to TensorSchema (#21847)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Benji Beck 2025-08-22 01:43:29 -07:00 committed by GitHub
parent 0ba1b54ac6
commit 998720859c

@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
                        MiniCPMVDummyInputsBuilder,
@@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
 CPU_DEVICE = torch.device("cpu")
 
 
-class MiniCPMOAudioFeatureInputs(TypedDict):
-    type: Literal["audio_features"]
-    audio_features: Union[torch.Tensor, list[torch.Tensor]]
+class MiniCPMOAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - bns: Batch size * number of audios * number of slices
+        - bn: Batch size * number of audios
+        - c: Number of channels
+        - l: Length
+        - s: Number of slices
+    """
+    type: Literal["audio_features"] = "audio_features"
+
+    audio_features: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bns", "c", "l", dynamic_dims={"l"}),
+    ]
     """
-    Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
-    which is the same as image.
-    Padding is used therefore `audio_features` is `torch.Tensor`.
+    which is the same as image. Padding is used therefore `audio_features` is
+    `torch.Tensor`.
     """
-    audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
+
+    audio_feature_lens: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "s"),
+    ]
     """
-    Shape: `(batch_size * num_audios, num_slices)`
     This should be feature length of each audio slice,
     which equals to `audio_features.shape[-1]`
     """
 
 
-class MiniCPMOAudioEmbeddingInputs(TypedDict):
-    type: Literal["audio_embeds"]
-    audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
+class MiniCPMOAudioEmbeddingInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_audios, num_slices, hidden_size)`
-    `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: Number of slices
+        - h: Hidden size (must match language model backbone)
+
+    Length of each slice may vary, so pass it as a list.
     """
+    type: Literal["audio_embeds"] = "audio_embeds"
+
+    audio_embeds: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "s", "h", dynamic_dims={"s"}),
+    ]
 
 
 MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
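
For readers new to the TensorSchema pattern, here is a minimal sketch of what the migration buys. It assumes, based on how these input classes are used elsewhere in vLLM, that a TensorSchema subclass is constructed with its fields as keyword arguments and that each Annotated field is checked against its declared TensorShape at construction time; the ExampleAudioFeatureInputs class and all concrete sizes below are illustrative, not part of this commit.

import torch
from typing import Annotated, Literal, Union

from vllm.utils.tensor_schema import TensorSchema, TensorShape


# Hypothetical mirror of MiniCPMOAudioFeatureInputs, defined locally so the
# snippet does not have to import the full MiniCPM-O model module.
class ExampleAudioFeatureInputs(TensorSchema):
    type: Literal["audio_features"] = "audio_features"

    # "bns", "c", "l" are symbolic dims resolved from the supplied tensor;
    # "l" is declared dynamic because audio length varies between inputs.
    audio_features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bns", "c", "l", dynamic_dims={"l"}),
    ]
    audio_feature_lens: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "s"),
    ]


# A padded (bns, c, l) feature tensor plus a (bn, s) length tensor satisfy
# the declared shapes (the sizes here are arbitrary example values).
ok = ExampleAudioFeatureInputs(
    audio_features=torch.zeros(4, 128, 3000),
    audio_feature_lens=torch.full((2, 2), 3000, dtype=torch.long),
)

# A tensor of the wrong rank (channel dim missing) should be rejected as soon
# as the schema object is built, instead of failing later inside the model.
try:
    ExampleAudioFeatureInputs(
        audio_features=torch.zeros(4, 3000),
        audio_feature_lens=torch.full((2, 2), 3000, dtype=torch.long),
    )
except Exception as exc:  # exact exception type is a TensorSchema detail
    print(f"shape validation failed: {exc}")

The same idea applies to MiniCPMOAudioEmbeddingInputs above, where dynamic_dims={"s"} allows the number of slices to differ per audio while "h" stays fixed.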