diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 26a8355cd22b..2444159b2ad6 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -4,15 +4,21 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
 
+import copy
 from collections.abc import Iterable, Mapping, Sequence
+from types import SimpleNamespace
 from typing import Annotated, Any, Literal, TypeAlias
 
 import torch
 from torch import nn
 from torch.nn import functional as F
 from transformers import BatchFeature, ProcessorMixin
+from transformers.modeling_utils import ModuleUtilsMixin
 from transformers.models.whisper import WhisperFeatureExtractor
-from transformers.models.whisper.modeling_whisper import WhisperEncoder
+from transformers.models.whisper.modeling_whisper import (
+    WhisperEncoder,
+    WhisperEncoderLayer,
+)
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -282,7 +288,7 @@ class StackAudioFrames(nn.Module):
         return audio_embeds
 
 
-class UltravoxProjector(nn.Module):
+class UltravoxFeedForwardProjector(nn.Module):
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
@@ -310,7 +316,9 @@ class UltravoxProjector(nn.Module):
             self.ln_mid = nn.Identity()
             self.ln_post = RMSNorm(dim_out)
 
-    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
@@ -321,6 +329,70 @@ class UltravoxProjector(nn.Module):
         return hidden_states
 
 
+class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
+    def __init__(self, config: UltravoxConfig):
+        super().__init__()
+        self.config = SimpleNamespace(is_decoder=False)
+
+        self._pad_and_stack = StackAudioFrames(config.stack_factor)
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+
+        projector_audio_config = copy.deepcopy(config.audio_config)
+
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)
+
+        self.embed_positions = nn.Embedding(
+            projector_audio_config.max_source_positions,
+            projector_audio_config.d_model,
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                WhisperEncoderLayer(projector_audio_config)
+                for _ in range(config.num_projector_layers)
+            ]
+        )
+
+        self.ln_post = RMSNorm(projector_audio_config.d_model)
+        self.linear_out = nn.Linear(
+            projector_audio_config.d_model, config.text_config.hidden_size
+        )
+
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
+        audio_features = self._pad_and_stack(audio_features)
+
+        max_len_stacked = audio_features.shape[1]
+        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
+            None, :
+        ].lt(audio_token_len[:, None])
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, attention_mask.shape, audio_features.dtype
+        )
+
+        hidden_states = self.ln_pre(audio_features)
+        hidden_states = self.linear_in(hidden_states)
+
+        positions = self.embed_positions(
+            torch.arange(hidden_states.size(1), device=hidden_states.device)
+        )
+        hidden_states = hidden_states + positions
+
+        for layer in self.layers:
+            layer_outputs = layer(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                layer_head_mask=None,
+            )
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.ln_post(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
+        return hidden_states
+
+
 class ModifiedWhisperEncoder(WhisperEncoder):
     """
     Encoder portion of OpenAI's Whisper model.
@@ -464,7 +536,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
                 prefix="audio_tower.",
            )
        )
-        self.multi_modal_projector = UltravoxProjector(config)
+        if config.num_projector_layers > 0:
+            self.multi_modal_projector = UltravoxTransformerProjector(config)
+        else:
+            self.multi_modal_projector = UltravoxFeedForwardProjector(config)
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.wrapped_model_config,
@@ -496,7 +571,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         )
 
     def _audio_features_to_embeddings(
-        self, input_features: torch.Tensor, audio_lens: torch.Tensor
+        self,
+        input_features: torch.Tensor,
+        audio_lens: torch.Tensor,
+        audio_token_len: torch.Tensor,
     ) -> torch.Tensor:
         audio_features = input_features.to(self.audio_tower.dtype)
         batch_size = audio_features.size(0)
@@ -512,7 +590,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             batch_features = batch_features.to(self.audio_tower.dtype)
 
             # Process through projector
-            batch_embeddings = self.multi_modal_projector(batch_features)
+            batch_embeddings = self.multi_modal_projector(
+                batch_features, audio_token_len[start:end]
+            )
             audio_embeddings.append(batch_embeddings)
 
         # Concatenate results
@@ -559,7 +639,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         audio_lens = audio_input["lens"]
         audio_token_len = audio_input["token_len"]
 
-        embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
+        embeddings = self._audio_features_to_embeddings(
+            audio_features, audio_lens, audio_token_len
+        )
 
         # We should flatten and concatenate embeddings based on token lengths
         # For example, with token_len = [4, 2, 3], flattened_embeddings will be
diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py
index fc0360a9ecb4..395b3130d40a 100644
--- a/vllm/transformers_utils/configs/ultravox.py
+++ b/vllm/transformers_utils/configs/ultravox.py
@@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         norm_init: float = 0.4,
         projector_act: str = "swiglu",
         projector_ln_mid: bool = False,
+        num_projector_layers: int = 0,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         self.norm_init = norm_init
         self.projector_act = projector_act
         self.projector_ln_mid = projector_ln_mid
+        self.num_projector_layers = num_projector_layers
 
         # N.B. May set the wrapped_model_config below.
         self.text_model_id = text_model_id
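
For reviewers, a standalone sketch (not part of the patch) of the padding-mask step that UltravoxTransformerProjector.forward performs before running the WhisperEncoderLayer stack; the batch size, frame counts, and lengths below are made-up values for illustration only:

import torch

# Hypothetical batch: 2 audio clips, padded to 6 stacked frames each.
max_len_stacked = 6
audio_token_len = torch.tensor([6, 3])  # valid (non-padding) frames per clip

# Boolean key-padding mask, True for real frames and False for padding,
# mirroring torch.arange(...)[None, :].lt(audio_token_len[:, None]) in the diff.
attention_mask = torch.arange(max_len_stacked)[None, :].lt(audio_token_len[:, None])
# tensor([[ True,  True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False, False]])

# ModuleUtilsMixin.get_extended_attention_mask then converts this into an
# additive mask broadcastable over attention scores (0 for real positions,
# a large negative value for padding); a hand-rolled equivalent:
extended = (1.0 - attention_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min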