diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index f91c4ddb6e83..c88306580527 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -4,7 +4,7 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
 _MAX_ENCODER_BATCH_SIZE = 16
 
 
-class UltravoxAudioFeatureInputs(TypedDict):
+class UltravoxAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: batch size
+        - n: number of chunks
+        - t: Time frames (M)
+        - nmb: Number of mel bins
+    """
     type: Literal["audio_features"]
-    data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
-    """Shape: `(batch_size, num_chunks, 80, M)`"""
-    lens: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Length of the audio frames. Used for attention mask in WhisperEncoder.
-    Shape: `(batch_size, num_chunks)`
-    """
-    token_len: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Length of the audio tokens. Used for flattening the audio features.
-    Shape: `(batch_size, num_chunks)`
-    """
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor],
+                          list[list[torch.Tensor]]],
+                    TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})]
+    lens: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("b", "n", dynamic_dims={"n"})]
+    """Length of the audio frames. Used for attention mask in WhisperEncoder."""
+    token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                         TensorShape("b", "n", dynamic_dims={"n"})]
+    """Length of the audio tokens. Used for flattening the audio features."""
 
 
-class UltravoxAudioEmbeddingInputs(TypedDict):
+class UltravoxAudioEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: batch size
+        - na: number of audios
+        - afs: audio feature size
+        - hs: hidden size
+    """
     type: Literal["audio_embeds"]
-    data: NestedTensors
-    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("b", "na", "afs", "hs")]
 
 
 UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
@@ -484,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             return None
 
         if audio_features is not None:
-            if not isinstance(audio_features, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio features. "
-                                 f"Got type: {type(audio_features)}")
-            if not isinstance(audio_lens, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_lens. "
-                                 f"Got type: {type(audio_features)}")
-            if not isinstance(audio_token_len, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_token_len. "
-                                 f"Got type: {type(audio_features)}")
-
             return UltravoxAudioFeatureInputs(type="audio_features",
                                               data=audio_features,
                                               lens=audio_lens,
                                               token_len=audio_token_len)
 
         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio embeds. "
-                                 f"Got type: {type(audio_embeds)}")
-
             return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                 data=audio_embeds)