[Model] Add support for transformer-based Ultravox v0.7 projector (#30089)

Signed-off-by: Peter Salas <peter@fixie.ai>
Peter Salas committed 2025-12-05 20:55:43 -08:00 (via GitHub)
parent e3fbb6f152
commit e858bc4d14
2 changed files with 91 additions and 7 deletions

View File

@@ -4,15 +4,21 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""

+import copy
 from collections.abc import Iterable, Mapping, Sequence
+from types import SimpleNamespace
 from typing import Annotated, Any, Literal, TypeAlias

 import torch
 from torch import nn
 from torch.nn import functional as F
 from transformers import BatchFeature, ProcessorMixin
+from transformers.modeling_utils import ModuleUtilsMixin
 from transformers.models.whisper import WhisperFeatureExtractor
-from transformers.models.whisper.modeling_whisper import WhisperEncoder
+from transformers.models.whisper.modeling_whisper import (
+    WhisperEncoder,
+    WhisperEncoderLayer,
+)

 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -282,7 +288,7 @@ class StackAudioFrames(nn.Module):
         return audio_embeds


-class UltravoxProjector(nn.Module):
+class UltravoxFeedForwardProjector(nn.Module):
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
@@ -310,7 +316,9 @@ class UltravoxProjector(nn.Module):
             self.ln_mid = nn.Identity()
             self.ln_post = RMSNorm(dim_out)

-    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
@@ -321,6 +329,70 @@ class UltravoxProjector(nn.Module):
         return hidden_states


+class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
+    def __init__(self, config: UltravoxConfig):
+        super().__init__()
+        self.config = SimpleNamespace(is_decoder=False)
+        self._pad_and_stack = StackAudioFrames(config.stack_factor)
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+
+        projector_audio_config = copy.deepcopy(config.audio_config)
+
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)
+        self.embed_positions = nn.Embedding(
+            projector_audio_config.max_source_positions,
+            projector_audio_config.d_model,
+        )
+        self.layers = nn.ModuleList(
+            [
+                WhisperEncoderLayer(projector_audio_config)
+                for _ in range(config.num_projector_layers)
+            ]
+        )
+        self.ln_post = RMSNorm(projector_audio_config.d_model)
+        self.linear_out = nn.Linear(
+            projector_audio_config.d_model, config.text_config.hidden_size
+        )
+
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
+        audio_features = self._pad_and_stack(audio_features)
+        max_len_stacked = audio_features.shape[1]
+        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
+            None, :
+        ].lt(audio_token_len[:, None])
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, attention_mask.shape, audio_features.dtype
+        )
+
+        hidden_states = self.ln_pre(audio_features)
+        hidden_states = self.linear_in(hidden_states)
+
+        positions = self.embed_positions(
+            torch.arange(hidden_states.size(1), device=hidden_states.device)
+        )
+        hidden_states = hidden_states + positions
+
+        for layer in self.layers:
+            layer_outputs = layer(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                layer_head_mask=None,
+            )
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.ln_post(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
+        return hidden_states
+
+
 class ModifiedWhisperEncoder(WhisperEncoder):
     """
     Encoder portion of OpenAI's Whisper model.
@@ -464,7 +536,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
                 prefix="audio_tower.",
             )
         )
-        self.multi_modal_projector = UltravoxProjector(config)
+        if config.num_projector_layers > 0:
+            self.multi_modal_projector = UltravoxTransformerProjector(config)
+        else:
+            self.multi_modal_projector = UltravoxFeedForwardProjector(config)
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.wrapped_model_config,
@@ -496,7 +571,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         )

     def _audio_features_to_embeddings(
-        self, input_features: torch.Tensor, audio_lens: torch.Tensor
+        self,
+        input_features: torch.Tensor,
+        audio_lens: torch.Tensor,
+        audio_token_len: torch.Tensor,
     ) -> torch.Tensor:
         audio_features = input_features.to(self.audio_tower.dtype)
         batch_size = audio_features.size(0)
@@ -512,7 +590,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             batch_features = batch_features.to(self.audio_tower.dtype)

             # Process through projector
-            batch_embeddings = self.multi_modal_projector(batch_features)
+            batch_embeddings = self.multi_modal_projector(
+                batch_features, audio_token_len[start:end]
+            )
             audio_embeddings.append(batch_embeddings)

         # Concatenate results
@@ -559,7 +639,9 @@
         audio_lens = audio_input["lens"]
         audio_token_len = audio_input["token_len"]

-        embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
+        embeddings = self._audio_features_to_embeddings(
+            audio_features, audio_lens, audio_token_len
+        )

         # We should flatten and concatenate embeddings based on token lengths
         # For example, with token_len = [4, 2, 3], flattened_embeddings will be

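For readers skimming the diff, here is a minimal standalone sketch (toy tensors only, not part of the commit) of the length-based attention mask that UltravoxTransformerProjector.forward derives from audio_token_len before handing it to ModuleUtilsMixin.get_extended_attention_mask:

import torch

# Batch of 3 stacked-audio sequences padded to a common length of 5;
# the lengths below are made up purely for illustration.
audio_token_len = torch.tensor([5, 2, 3])
max_len_stacked = 5

# True for real stacked frames, False for padding; shape (batch, max_len_stacked).
attention_mask = torch.arange(max_len_stacked)[None, :].lt(audio_token_len[:, None])
print(attention_mask)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False]])

# get_extended_attention_mask then converts this boolean mask into an additive
# float mask (0.0 where True, a large negative value where False) that is
# broadcast onto the attention scores inside each WhisperEncoderLayer.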
View File

@@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         norm_init: float = 0.4,
         projector_act: str = "swiglu",
         projector_ln_mid: bool = False,
+        num_projector_layers: int = 0,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         self.norm_init = norm_init
         self.projector_act = projector_act
         self.projector_ln_mid = projector_ln_mid
+        self.num_projector_layers = num_projector_layers

         # N.B. May set the wrapped_model_config below.
         self.text_model_id = text_model_id
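
As a usage note, the new field defaults to 0, so existing checkpoints keep the feed-forward projector; only configs that set num_projector_layers > 0 get the transformer projector introduced here for Ultravox v0.7. A hypothetical sketch of that selection, mirroring the __init__ branch in the first file (the config object and its value are invented for illustration):

from types import SimpleNamespace

# Stand-in for UltravoxConfig; only the field relevant to projector selection
# is modeled, and the value 2 is purely illustrative.
config = SimpleNamespace(num_projector_layers=2)

# Same branch as UltravoxModel.__init__ in the diff above: a positive layer
# count selects the transformer projector, 0 keeps the feed-forward projector.
projector_name = (
    "UltravoxTransformerProjector"
    if config.num_projector_layers > 0
    else "UltravoxFeedForwardProjector"
)
print(projector_name)  # UltravoxTransformerProjector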