Migrate MiniCPMOAudioInputs to TensorSchema (#21847)
Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent 0ba1b54ac6
commit 998720859c
@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
                        MiniCPMVDummyInputsBuilder,
@@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
 CPU_DEVICE = torch.device("cpu")
 
 
-class MiniCPMOAudioFeatureInputs(TypedDict):
-    type: Literal["audio_features"]
-    audio_features: Union[torch.Tensor, list[torch.Tensor]]
+class MiniCPMOAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - bns: Batch size * number of audios * number of slices
+        - bn: Batch size * number of audios
+        - c: Number of channels
+        - l: Length
+        - s: Number of slices
+    """
+    type: Literal["audio_features"] = "audio_features"
+
+    audio_features: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bns", "c", "l", dynamic_dims={"l"}),
+    ]
     """
     Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
-    which is the same as image.
-    Padding is used therefore `audio_features` is `torch.Tensor`.
+    which is the same as image. Padding is used therefore `audio_features` is
+    `torch.Tensor`.
     """
 
-    audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
+    audio_feature_lens: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "s"),
+    ]
     """
     Shape: `(batch_size * num_audios, num_slices)`
 
     This should be feature length of each audio slice,
     which equals to `audio_features.shape[-1]`
     """
 
 
-class MiniCPMOAudioEmbeddingInputs(TypedDict):
-    type: Literal["audio_embeds"]
-    audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
+class MiniCPMOAudioEmbeddingInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_audios, num_slices, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-
-    Length of each slice may vary, so pass it as a list
-    instead of a batched tensor.
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: Number of slices
+        - h: Hidden size (must match language model backbone)
+
+    Length of each slice may vary, so pass it as a list.
     """
+    type: Literal["audio_embeds"] = "audio_embeds"
+
+    audio_embeds: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "s", "h", dynamic_dims={"s"}),
+    ]
 
 
 MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
                             MiniCPMOAudioEmbeddingInputs]
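To make the shape checking this migration enables concrete, here is a minimal, self-contained sketch of the Annotated + TensorShape pattern. This is not vLLM's implementation: the real TensorSchema and TensorShape come from vllm.utils.tensor_schema and also handle lists of tensors, among other things. The toy base classes and the AudioFeatureInputs example below are assumptions for illustration only.

# Toy sketch of the Annotated[..., TensorShape(...)] validation pattern.
# Assumption: validation happens at construction time, as the real
# vllm.utils.tensor_schema.TensorSchema does for parsed multimodal inputs.
from typing import Annotated, Optional, get_type_hints

import torch


class TensorShape:
    """Symbolic shape declaration attached to a field via typing.Annotated."""

    def __init__(self, *dims: str, dynamic_dims: Optional[set[str]] = None):
        self.dims = dims
        self.dynamic_dims = dynamic_dims or set()


class TensorSchema:
    """Toy base class: checks annotated tensor fields when instantiated."""

    def __init__(self, **kwargs):
        hints = get_type_hints(type(self), include_extras=True)
        bound: dict[str, int] = {}  # symbol -> concrete size seen so far
        for name, hint in hints.items():
            if name not in kwargs:
                continue
            value = kwargs[name]
            shape = next((m for m in getattr(hint, "__metadata__", ())
                          if isinstance(m, TensorShape)), None)
            if shape is not None and isinstance(value, torch.Tensor):
                # Rank must match the number of declared symbolic dims.
                if value.ndim != len(shape.dims):
                    raise ValueError(
                        f"{name}: expected rank {len(shape.dims)} "
                        f"{shape.dims}, got shape {tuple(value.shape)}")
                for sym, size in zip(shape.dims, value.shape):
                    if sym in shape.dynamic_dims:
                        continue  # e.g. padded length 'l' may vary
                    # A static symbol must bind consistently across fields.
                    if bound.setdefault(sym, size) != size:
                        raise ValueError(
                            f"{name}: dim '{sym}'={size} conflicts with "
                            f"previously bound '{sym}'={bound[sym]}")
            setattr(self, name, value)


class AudioFeatureInputs(TensorSchema):
    """Mirrors the shape declarations from the diff above (illustrative)."""
    audio_features: Annotated[
        torch.Tensor, TensorShape("bns", "c", "l", dynamic_dims={"l"})]
    audio_feature_lens: Annotated[torch.Tensor, TensorShape("bn", "s")]


# Passes: ranks match and no static symbol is bound inconsistently.
ok = AudioFeatureInputs(
    audio_features=torch.zeros(4, 80, 3000),  # (bns, c, l)
    audio_feature_lens=torch.zeros(2, 2),     # (bn, s)
)

# Raises: a 2-D tensor cannot satisfy TensorShape("bns", "c", "l").
try:
    AudioFeatureInputs(
        audio_features=torch.zeros(4, 80),
        audio_feature_lens=torch.zeros(2, 2),
    )
except ValueError as e:
    print(e)

The design point of the commit: the old TypedDict classes carried shape information only in docstrings, whereas the TensorShape annotations turn the expected rank, the shared symbolic dimensions, and the explicitly dynamic ones (dynamic_dims={"l"}, dynamic_dims={"s"}) into constraints that can be validated when the inputs are built.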