Migrate ultravox inputs to TensorSchema (#23503)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck, 2025-09-03 23:09:11 -07:00 (committed by GitHub)
parent 712b273f65
commit cb55ad86fe

Changed file: vllm/model_executor/models/ultravox.py

@@ -4,7 +4,7 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
 _MAX_ENCODER_BATCH_SIZE = 16
 
-class UltravoxAudioFeatureInputs(TypedDict):
+class UltravoxAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: batch size
+        - n: number of chunks
+        - t: time frames (M)
+        - nmb: number of mel bins
+    """
     type: Literal["audio_features"]
-    data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
-    """Shape: `(batch_size, num_chunks, 80, M)`"""
-    lens: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Length of the audio frames. Used for attention mask in WhisperEncoder.
-    Shape: `(batch_size, num_chunks)`
-    """
-    token_len: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Length of the audio tokens. Used for flattening the audio features.
-    Shape: `(batch_size, num_chunks)`
-    """
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor],
+                          list[list[torch.Tensor]]],
+                    TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})]
+    lens: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("b", "n", dynamic_dims={"n"})]
+    """Length of the audio frames. Used for attention mask in WhisperEncoder."""
+    token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                         TensorShape("b", "n", dynamic_dims={"n"})]
+    """Length of the audio tokens. Used for flattening the audio features."""
 
-class UltravoxAudioEmbeddingInputs(TypedDict):
+class UltravoxAudioEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: batch size
+        - na: number of audios
+        - afs: audio feature size
+        - hs: hidden size
+    """
     type: Literal["audio_embeds"]
-    data: NestedTensors
-    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("b", "na", "afs", "hs")]
 
 UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
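
For context on the new pattern: TensorSchema replaces the old free-text shape docstrings with machine-checked annotations. Each TensorShape symbol ("b", "nmb", ...) must resolve to the same size everywhere it appears, while dynamic_dims={"n"} exempts the per-audio chunk count from that cross-element consistency check. Below is a minimal sketch of what the schema now enforces at construction time; the sizes are illustrative, and the exact resolution rules (e.g. matching a list's length against the leading dimension) live in vllm.utils.tensor_schema:

import torch

from vllm.model_executor.models.ultravox import UltravoxAudioFeatureInputs

# Two audios with different chunk counts: "n" is dynamic, so n=2 vs. n=5
# is accepted, but the mel-bin dim "nmb" and time dim "t" must still agree.
inputs = UltravoxAudioFeatureInputs(
    type="audio_features",
    data=[torch.zeros(2, 80, 3000),
          torch.zeros(5, 80, 3000)],  # per-audio (n, nmb, t)
    lens=[torch.ones(2, dtype=torch.long),
          torch.ones(5, dtype=torch.long)],
    token_len=[torch.ones(2, dtype=torch.long),
               torch.ones(5, dtype=torch.long)],
)

Validation runs inside the TensorSchema constructor, so a wrong mel-bin count or a data/lens length mismatch raises immediately rather than deep inside WhisperEncoder.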
@@ -484,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             return None
 
         if audio_features is not None:
-            if not isinstance(audio_features, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio features. "
-                                 f"Got type: {type(audio_features)}")
-
-            if not isinstance(audio_lens, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_lens. "
-                                 f"Got type: {type(audio_features)}")
-
-            if not isinstance(audio_token_len, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_token_len. "
-                                 f"Got type: {type(audio_features)}")
             return UltravoxAudioFeatureInputs(type="audio_features",
                                               data=audio_features,
                                               lens=audio_lens,
                                               token_len=audio_token_len)
 
         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio embeds. "
-                                 f"Got type: {type(audio_embeds)}")
-
             return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                 data=audio_embeds)
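
The deleted isinstance checks are exactly what the schema migration makes redundant: TensorSchema validates field types and shapes when the inputs object is constructed, so malformed inputs now fail at the return UltravoxAudioFeatureInputs(...) call itself. A rough sketch of the new failure path, assuming TensorSchema raises ValueError on a type or shape violation:

import torch

from vllm.model_executor.models.ultravox import UltravoxAudioFeatureInputs

try:
    UltravoxAudioFeatureInputs(
        type="audio_features",
        data="not a tensor",  # wrong type: rejected by the schema
        lens=torch.ones(1, 1, dtype=torch.long),
        token_len=torch.ones(1, 1, dtype=torch.long),
    )
except ValueError as exc:
    print(exc)  # one central error path replaces three hand-rolled checks

As a side effect, the centralized validation also retires a copy-paste slip in the deleted code, where all three error messages interpolated type(audio_features).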