[Model] Add support for transformer-based Ultravox v0.7 projector (#30089)

Signed-off-by: Peter Salas <peter@fixie.ai>
Peter Salas 2025-12-05 20:55:43 -08:00 committed by GitHub
parent e3fbb6f152
commit e858bc4d14
2 changed files with 91 additions and 7 deletions


@@ -4,15 +4,21 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""

+import copy
 from collections.abc import Iterable, Mapping, Sequence
+from types import SimpleNamespace
 from typing import Annotated, Any, Literal, TypeAlias

 import torch
 from torch import nn
 from torch.nn import functional as F
 from transformers import BatchFeature, ProcessorMixin
+from transformers.modeling_utils import ModuleUtilsMixin
 from transformers.models.whisper import WhisperFeatureExtractor
-from transformers.models.whisper.modeling_whisper import WhisperEncoder
+from transformers.models.whisper.modeling_whisper import (
+    WhisperEncoder,
+    WhisperEncoderLayer,
+)

 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -282,7 +288,7 @@ class StackAudioFrames(nn.Module):
         return audio_embeds


-class UltravoxProjector(nn.Module):
+class UltravoxFeedForwardProjector(nn.Module):
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
@@ -310,7 +316,9 @@ class UltravoxProjector(nn.Module):
             self.ln_mid = nn.Identity()
             self.ln_post = RMSNorm(dim_out)

-    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
@@ -321,6 +329,70 @@ class UltravoxProjector(nn.Module):
         return hidden_states


+class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
+    def __init__(self, config: UltravoxConfig):
+        super().__init__()
+        self.config = SimpleNamespace(is_decoder=False)
+        self._pad_and_stack = StackAudioFrames(config.stack_factor)
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        projector_audio_config = copy.deepcopy(config.audio_config)
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)
+        self.embed_positions = nn.Embedding(
+            projector_audio_config.max_source_positions,
+            projector_audio_config.d_model,
+        )
+        self.layers = nn.ModuleList(
+            [
+                WhisperEncoderLayer(projector_audio_config)
+                for _ in range(config.num_projector_layers)
+            ]
+        )
+        self.ln_post = RMSNorm(projector_audio_config.d_model)
+        self.linear_out = nn.Linear(
+            projector_audio_config.d_model, config.text_config.hidden_size
+        )
+
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
+        audio_features = self._pad_and_stack(audio_features)
+        max_len_stacked = audio_features.shape[1]
+        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
+            None, :
+        ].lt(audio_token_len[:, None])
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, attention_mask.shape, audio_features.dtype
+        )
+        hidden_states = self.ln_pre(audio_features)
+        hidden_states = self.linear_in(hidden_states)
+        positions = self.embed_positions(
+            torch.arange(hidden_states.size(1), device=hidden_states.device)
+        )
+        hidden_states = hidden_states + positions
+        for layer in self.layers:
+            layer_outputs = layer(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                layer_head_mask=None,
+            )
+            hidden_states = layer_outputs[0]
+        hidden_states = self.ln_post(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
+        return hidden_states
+
+
 class ModifiedWhisperEncoder(WhisperEncoder):
     """
     Encoder portion of OpenAI's Whisper model.
@@ -464,7 +536,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
                 prefix="audio_tower.",
             )
         )
-        self.multi_modal_projector = UltravoxProjector(config)
+        if config.num_projector_layers > 0:
+            self.multi_modal_projector = UltravoxTransformerProjector(config)
+        else:
+            self.multi_modal_projector = UltravoxFeedForwardProjector(config)
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.wrapped_model_config,
@@ -496,7 +571,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         )

     def _audio_features_to_embeddings(
-        self, input_features: torch.Tensor, audio_lens: torch.Tensor
+        self,
+        input_features: torch.Tensor,
+        audio_lens: torch.Tensor,
+        audio_token_len: torch.Tensor,
     ) -> torch.Tensor:
         audio_features = input_features.to(self.audio_tower.dtype)
         batch_size = audio_features.size(0)
@@ -512,7 +590,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             batch_features = batch_features.to(self.audio_tower.dtype)

             # Process through projector
-            batch_embeddings = self.multi_modal_projector(batch_features)
+            batch_embeddings = self.multi_modal_projector(
+                batch_features, audio_token_len[start:end]
+            )
             audio_embeddings.append(batch_embeddings)

         # Concatenate results
@@ -559,7 +639,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         audio_lens = audio_input["lens"]
         audio_token_len = audio_input["token_len"]

-        embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
+        embeddings = self._audio_features_to_embeddings(
+            audio_features, audio_lens, audio_token_len
+        )

         # We should flatten and concatenate embeddings based on token lengths
         # For example, with token_len = [4, 2, 3], flattened_embeddings will be
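Note (not part of the diff): a minimal standalone sketch of how the padding mask built in UltravoxTransformerProjector.forward behaves; the tensor values below are made-up examples.

import torch

# Hypothetical lengths (in stacked frames) for two audio segments in a batch.
audio_token_len = torch.tensor([4, 2])
max_len_stacked = 5  # padded length after StackAudioFrames

# True where a position falls within the segment's real length, False on padding.
attention_mask = torch.arange(max_len_stacked)[None, :].lt(audio_token_len[:, None])
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True, False, False, False]])

# ModuleUtilsMixin.get_extended_attention_mask (the mixin the projector inherits,
# with SimpleNamespace(is_decoder=False) satisfying its config check) broadcasts
# this boolean mask to (batch, 1, 1, seq_len) and converts masked positions into a
# large negative additive bias that each WhisperEncoderLayer applies to its
# attention scores, so real frames do not attend to padded positions.

The remaining hunks below add the matching num_projector_layers field to UltravoxConfig.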


@@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         norm_init: float = 0.4,
         projector_act: str = "swiglu",
         projector_ln_mid: bool = False,
+        num_projector_layers: int = 0,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         self.norm_init = norm_init
         self.projector_act = projector_act
         self.projector_ln_mid = projector_ln_mid
+        self.num_projector_layers = num_projector_layers

         # N.B. May set the wrapped_model_config below.
         self.text_model_id = text_model_id
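Note (not part of the diff): num_projector_layers defaults to 0, so existing checkpoints keep the original feed-forward projector; a v0.7-style checkpoint enables the transformer projector by setting the field in its config. A minimal sketch of the selection logic, using hypothetical stand-in config objects (mirrors the branch added to UltravoxModel.__init__):

from types import SimpleNamespace

# Stand-ins for UltravoxConfig; real configs carry many more fields.
legacy_cfg = SimpleNamespace(num_projector_layers=0)
v07_cfg = SimpleNamespace(num_projector_layers=4)  # hypothetical value

def projector_kind(config):
    # Same condition as the new branch in UltravoxModel.__init__.
    return "transformer" if config.num_projector_layers > 0 else "feed-forward"

print(projector_kind(legacy_cfg))  # feed-forward
print(projector_kind(v07_cfg))     # transformer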