[Model] Add support for transformer-based Ultravox v0.7 projector (#30089)
Signed-off-by: Peter Salas <peter@fixie.ai>
commit e858bc4d14
parent e3fbb6f152
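This commit adds a transformer-based projector for Ultravox v0.7: `UltravoxTransformerProjector` runs the stacked audio frames through a stack of `WhisperEncoderLayer` blocks (learned positions plus a length-based attention mask), while the existing MLP projector is renamed to `UltravoxFeedForwardProjector`. Which projector is built depends on the new `num_projector_layers` config field, and both projector forwards now receive `audio_token_len` so padded frames can be masked. The snippet below is a condensed, hypothetical sketch of that selection rule; the stand-in classes and the `select_projector` helper are illustrative only, and the real wiring lives in `UltravoxModel.__init__` in the diff.

```python
# Condensed sketch of the selection logic added to UltravoxModel.__init__.
# Stand-in classes and helper name are illustrative, not vLLM code.
class UltravoxFeedForwardProjector: ...   # pre-v0.7 MLP projector (was UltravoxProjector)
class UltravoxTransformerProjector: ...   # new: stack of WhisperEncoderLayer blocks


def select_projector(num_projector_layers: int) -> type:
    # Any positive layer count opts into the transformer projector.
    return (
        UltravoxTransformerProjector
        if num_projector_layers > 0
        else UltravoxFeedForwardProjector
    )


assert select_projector(0) is UltravoxFeedForwardProjector
assert select_projector(4) is UltravoxTransformerProjector
```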
@@ -4,15 +4,21 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""

+import copy
 from collections.abc import Iterable, Mapping, Sequence
+from types import SimpleNamespace
 from typing import Annotated, Any, Literal, TypeAlias

 import torch
 from torch import nn
 from torch.nn import functional as F
 from transformers import BatchFeature, ProcessorMixin
+from transformers.modeling_utils import ModuleUtilsMixin
 from transformers.models.whisper import WhisperFeatureExtractor
-from transformers.models.whisper.modeling_whisper import WhisperEncoder
+from transformers.models.whisper.modeling_whisper import (
+    WhisperEncoder,
+    WhisperEncoderLayer,
+)

 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -282,7 +288,7 @@ class StackAudioFrames(nn.Module):
         return audio_embeds


-class UltravoxProjector(nn.Module):
+class UltravoxFeedForwardProjector(nn.Module):
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
@@ -310,7 +316,9 @@ class UltravoxProjector(nn.Module):
             self.ln_mid = nn.Identity()
             self.ln_post = RMSNorm(dim_out)

-    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
@@ -321,6 +329,70 @@ class UltravoxProjector(nn.Module):
         return hidden_states


+class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
+    def __init__(self, config: UltravoxConfig):
+        super().__init__()
+        self.config = SimpleNamespace(is_decoder=False)
+
+        self._pad_and_stack = StackAudioFrames(config.stack_factor)
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+
+        projector_audio_config = copy.deepcopy(config.audio_config)
+
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)
+
+        self.embed_positions = nn.Embedding(
+            projector_audio_config.max_source_positions,
+            projector_audio_config.d_model,
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                WhisperEncoderLayer(projector_audio_config)
+                for _ in range(config.num_projector_layers)
+            ]
+        )
+
+        self.ln_post = RMSNorm(projector_audio_config.d_model)
+        self.linear_out = nn.Linear(
+            projector_audio_config.d_model, config.text_config.hidden_size
+        )
+
+    def forward(
+        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
+    ) -> torch.Tensor:
+        audio_features = self._pad_and_stack(audio_features)
+
+        max_len_stacked = audio_features.shape[1]
+        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
+            None, :
+        ].lt(audio_token_len[:, None])
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, attention_mask.shape, audio_features.dtype
+        )
+
+        hidden_states = self.ln_pre(audio_features)
+        hidden_states = self.linear_in(hidden_states)
+
+        positions = self.embed_positions(
+            torch.arange(hidden_states.size(1), device=hidden_states.device)
+        )
+        hidden_states = hidden_states + positions
+
+        for layer in self.layers:
+            layer_outputs = layer(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                layer_head_mask=None,
+            )
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.ln_post(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
+        return hidden_states
+
+
 class ModifiedWhisperEncoder(WhisperEncoder):
     """
     Encoder portion of OpenAI's Whisper model.
@@ -464,7 +536,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
                 prefix="audio_tower.",
             )
         )
-        self.multi_modal_projector = UltravoxProjector(config)
+        if config.num_projector_layers > 0:
+            self.multi_modal_projector = UltravoxTransformerProjector(config)
+        else:
+            self.multi_modal_projector = UltravoxFeedForwardProjector(config)
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.wrapped_model_config,
@@ -496,7 +571,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         )

     def _audio_features_to_embeddings(
-        self, input_features: torch.Tensor, audio_lens: torch.Tensor
+        self,
+        input_features: torch.Tensor,
+        audio_lens: torch.Tensor,
+        audio_token_len: torch.Tensor,
     ) -> torch.Tensor:
         audio_features = input_features.to(self.audio_tower.dtype)
         batch_size = audio_features.size(0)
@@ -512,7 +590,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             batch_features = batch_features.to(self.audio_tower.dtype)

             # Process through projector
-            batch_embeddings = self.multi_modal_projector(batch_features)
+            batch_embeddings = self.multi_modal_projector(
+                batch_features, audio_token_len[start:end]
+            )
             audio_embeddings.append(batch_embeddings)

         # Concatenate results
@@ -559,7 +639,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         audio_lens = audio_input["lens"]
         audio_token_len = audio_input["token_len"]

-        embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
+        embeddings = self._audio_features_to_embeddings(
+            audio_features, audio_lens, audio_token_len
+        )

         # We should flatten and concatenate embeddings based on token lengths
         # For example, with token_len = [4, 2, 3], flattened_embeddings will be
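Because the stacked audio features are padded to a common length per batch, the new projector needs to know how many stacked frames in each sample are real. Its forward pass derives a boolean key-padding mask from `audio_token_len` and then expands it with `ModuleUtilsMixin.get_extended_attention_mask` before feeding the `WhisperEncoderLayer` stack. Below is a minimal standalone illustration of the mask construction; the lengths are chosen arbitrarily and are not taken from the commit.

```python
import torch

# Per-sample count of valid stacked frames, and the padded length of the batch.
audio_token_len = torch.tensor([4, 2, 3])
max_len_stacked = 5

# Same construction as the projector: position < length -> attendable.
mask = torch.arange(max_len_stacked)[None, :].lt(audio_token_len[:, None])
print(mask)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False]])
```

The remaining hunks belong to the `UltravoxConfig` definition and add the `num_projector_layers` field that gates all of this.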
@@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         norm_init: float = 0.4,
         projector_act: str = "swiglu",
         projector_ln_mid: bool = False,
+        num_projector_layers: int = 0,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         self.norm_init = norm_init
         self.projector_act = projector_act
         self.projector_ln_mid = projector_ln_mid
+        self.num_projector_layers = num_projector_layers

         # N.B. May set the wrapped_model_config below.
         self.text_model_id = text_model_id
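Note that `num_projector_layers` defaults to 0, so configs from checkpoints that predate this change keep the feed-forward projector unchanged; only a config that explicitly sets a positive value gets the transformer projector. A hypothetical sketch of that fallback behaviour with a stand-in config object (not vLLM's actual config loading code):

```python
from types import SimpleNamespace

legacy_cfg = SimpleNamespace()                      # pre-v0.7 config: field absent
v07_cfg = SimpleNamespace(num_projector_layers=4)   # illustrative v0.7-style value

for name, cfg in (("legacy", legacy_cfg), ("v0.7", v07_cfg)):
    n = getattr(cfg, "num_projector_layers", 0)     # same default as UltravoxConfig
    kind = "transformer" if n > 0 else "feed-forward"
    print(f"{name}: {kind} projector")
# legacy: feed-forward projector
# v0.7: transformer projector
```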