adapt voxtral (#31095)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Patrick von Platen 2025-12-23 14:31:55 +01:00 committed by GitHub
parent b10d47e0e0
commit 3faa8bee57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 739 additions and 98 deletions

View File

@ -111,4 +111,5 @@ async def test_online_serving(client, audio_assets: AudioTestAssets):
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.message.content == "In the first audio clip, you hear a brief"
assert choice.finish_reason == "length"

View File

@ -860,6 +860,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# disable this temporarily until we support HF format
is_available_online=False,
),
"VoxtralStreamingGeneration": _HfExamplesInfo(
"<place-holder>",
# disable this temporarily until we support HF format
is_available_online=False,
),
# [Encoder-decoder]
"WhisperForConditionalGeneration": _HfExamplesInfo(
"openai/whisper-large-v3-turbo",

View File

@ -1542,6 +1542,10 @@ class ModelConfig:
def is_multimodal_raw_input_only_model(self) -> bool:
return self._model_info.supports_multimodal_raw_input_only
@property
def requires_raw_input_tokens(self) -> bool:
return self._model_info.requires_raw_input_tokens
@property
def is_cross_encoder(self) -> bool:
return (

View File

@ -94,6 +94,12 @@ class SupportsMultiModal(Protocol):
`multimodal_config.mm_encoder_tp_mode="data"`.
"""
requires_raw_input_tokens: ClassVar[bool] = False
"""
A flag that indicates this model processes input id tokens
in their raw form and not input embeddings.
"""
merge_by_field_config: ClassVar[bool | None] = None
"""
[DEPRECATED] A flag that indicates which implementation of
@ -306,6 +312,10 @@ def supports_multimodal_raw_input_only(model: type[object] | object) -> bool:
return getattr(model, "supports_multimodal_raw_input_only", False)
def requires_raw_input_tokens(model: type[object] | object) -> bool:
return getattr(model, "requires_raw_input_tokens", False)
def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
return getattr(model, "supports_encoder_tp_data", False)

View File

@ -46,6 +46,7 @@ from .interfaces import (
has_noops,
is_attention_free,
is_hybrid,
requires_raw_input_tokens,
supports_cross_encoding,
supports_mamba_prefix_caching,
supports_multimodal,
@ -422,6 +423,7 @@ _MULTIMODAL_MODELS = {
),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
"VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"), # noqa: E501
# [Encoder-decoder]
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
@ -539,6 +541,7 @@ class _ModelInfo:
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input_only: bool
requires_raw_input_tokens: bool
supports_multimodal_encoder_tp_data: bool
supports_pp: bool
has_inner_state: bool
@ -562,6 +565,7 @@ class _ModelInfo:
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
model
),
requires_raw_input_tokens=requires_raw_input_tokens(model),
supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
model
),

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
@ -116,10 +117,7 @@ class VoxtralProcessorAdapter:
self,
audio_length: int,
) -> int:
pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
audio_length, self.sampling_rate
)
return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
def __call__(
self,
@ -158,7 +156,14 @@ class VoxtralProcessorAdapter:
assert audio.ndim == 1
# pad if necessary
audio = self._audio_processor.pad(audio, self.sampling_rate)
# TODO(Patrick) - remove once mistral-common is bumped
sig = inspect.signature(self._audio_processor.pad)
if "is_online_streaming" in sig.parameters:
audio = self._audio_processor.pad(
audio, self.sampling_rate, is_online_streaming=False
)
else:
audio = self._audio_processor.pad(audio, self.sampling_rate)
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id
@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
remapping_rules = [
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
(r"mm_whisper_embeddings\.(.*)", r"\1"),
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
(
@ -535,13 +541,16 @@ class VoxtralForConditionalGeneration(
def llm_weights_generator():
nonlocal loaded_weights
for name, w in weights:
is_encoder = (
name.startswith("mm_whisper_embeddings")
and not name.startswith("mm_whisper_embeddings.tok_embeddings")
and not name.startswith(
"mm_whisper_embeddings.audio_language_projection"
is_encoder = False
for k in [
"mm_whisper_embeddings",
"mm_streams_embeddings.embedding_module",
]:
is_encoder |= (
name.startswith(k)
and not name.startswith(f"{k}.tok_embeddings")
and not name.startswith(f"{k}.audio_language_projection")
)
)
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
mistral_remapping = [
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
(
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
r"whisper_encoder.conv1.\1",
@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module):
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
r"whisper_encoder.conv2.\1",
),
(
r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
r"whisper_encoder.conv1.\1",
), # noqa: E501
(
r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
r"whisper_encoder.conv2.\1",
), # noqa: E501
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",

View File

@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Mapping
import torch
from vllm.config.vllm import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.voxtral import (
VoxtralDummyInputsBuilder,
VoxtralForConditionalGeneration,
VoxtralMultiModalProcessor,
VoxtralProcessingInfo,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalKwargsOptionalItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .utils import (
_flatten_embeddings,
)
logger = init_logger(__name__)
class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
def __init__(
self,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> None:
# streaming can't make use of a cache yet
super().__init__(info, dummy_inputs, cache=None)
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsOptionalItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
# there are no placeholder audio tokens for streaming
# so we need to build the place placeholder positions manually
# in streaming there is always only one audio input
audios = mm_kwargs.get("audio", [])
assert len(audios) == 1, (
f"Expected only one audio input for streaming, got {mm_kwargs=}"
)
tokenizer = self.info.get_tokenizer()
audio_config = tokenizer.instruct.audio_encoder.audio_config
num_audio_samples = audios[0]["audio_arrays"].data.shape[0]
length = audio_config.num_audio_tokens(num_audio_samples)
features_info = PlaceholderFeaturesInfo(
modality="audio",
item_idx=0,
start_idx=0,
tokens=length
* [0], # only used for length computation, so we can take dummy inputs
is_embed=None,
)
return prompt_ids, {"audio": [features_info]}
class TimeEmbedding(torch.nn.Module):
"""Sinusoidal Embedding for encoding time"""
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = torch.exp(
-math.log(self.theta)
* torch.arange(self.dim // 2).float()
/ (self.dim // 2)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, t: torch.Tensor) -> torch.Tensor:
t = t[..., None] # (B,) -> (B, 1) or (B, T) -> (B, T, 1)
inv_freq = self.inv_freq.to(device=t.device, dtype=t.dtype)
emb = (
t * inv_freq
) # (B, 1) x (D/2,) -> (B, D/2) or (B, T, 1) x (D/2,) -> (B, T, D/2)
return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D)
@MULTIMODAL_REGISTRY.register_processor(
VoxtralStreamingMultiModalProcessor,
info=VoxtralProcessingInfo,
dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
requires_raw_input_tokens = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.time_embedding: TimeEmbedding = TimeEmbedding(
dim=self.config.text_config.hidden_size
)
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
_n_delay_tokens = (
audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
)
assert _n_delay_tokens.is_integer(), (
f"n_delay_tokens must be integer, got {_n_delay_tokens}"
)
self.n_delay_tokens = int(_n_delay_tokens)
@property
def audio_config(self):
return self.tokenizer.instruct.audio_encoder.audio_config
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Pass post-conv embeddings directly as input"""
# for streaming we simply flatten the multimodal embeddings
# to be in tensor format, we treat the input ids later
assert multimodal_embeddings is not None
assert len(multimodal_embeddings) > 0, (
"For streaming you must provide a multimodal_embedding at every step."
)
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
return mm_embeds_flat
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
assert inputs_embeds is not None
assert input_ids is not None
pool_size = self.config.audio_config.block_pool_size
inputs_embeds = inputs_embeds.view(
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
)
audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers(
inputs_embeds
)
num_tokens, audio_hidden_size = audio_hidden_states.shape
assert num_tokens % self.downsample_factor == 0
audio_hidden_states = audio_hidden_states.reshape(
num_tokens // self.downsample_factor,
audio_hidden_size * self.downsample_factor,
)
audio_text_embeds = self.audio_language_adapter(audio_hidden_states)
text_embeds = self.language_model.embed_input_ids(input_ids)
# sum pool text and audio embeddings
inputs_embeds = audio_text_embeds + text_embeds
time_tensor = torch.tensor(
[self.n_delay_tokens],
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
inputs_embeds = inputs_embeds + self.time_embedding(time_tensor)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def embed_multimodal(
self, **kwargs
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
"""Transform audio waveforms -> initial whisper post-conv embeddings"""
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
assert audio_inputs is not None, (
"For streaming you must provide an audio input at every step."
)
multiple_of = self.audio_config.raw_audio_length_per_tok
assert all(
(this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
), (
f"Every input audio waveform has to be a multiple of {multiple_of}, but"
f" one is {this_audio} with {(this_audio / multiple_of)=}."
)
mel_features = [
self.whisper_encoder.compute_whisper_melspec(audio).to(
self.whisper_encoder.dtype
)
for audio in audio_inputs
]
seq_lens = [mel.shape[1] for mel in mel_features]
# [total_num_20ms_frames, hidden_size]
audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
mel_features
)[0]
conv_stride = self.whisper_encoder.whisper_encoder.total_stride
audio_embeddings_per_sample = audio_embeddings.split(
[s // conv_stride for s in seq_lens], dim=0
)
# audio_embeddings per sample need to be divisible by 4
pool_size = self.config.audio_config.block_pool_size
assert all(
(this_shape := sample.shape[0]) % pool_size == 0
for sample in audio_embeddings_per_sample
), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
audio_embeddings_per_sample = [
e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
for e in audio_embeddings_per_sample
]
return audio_embeddings_per_sample

View File

@ -1,9 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from functools import partial
from typing import Annotated, Literal, cast
import numpy as np
@ -16,7 +18,10 @@ from transformers import (
)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.layer import Attention, AttentionType
from vllm.attention.backends.abstract import (
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
@ -34,6 +39,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.whisper_utils import (
ISO639_1_SUPPORTED_LANGS,
WhisperAttentionWithBlockPooling,
WhisperCausalConv1d,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
@ -64,67 +74,11 @@ from .utils import (
logger = init_logger(__name__)
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
"af": "Afrikaans",
"ar": "Arabic",
"hy": "Armenian",
"az": "Azerbaijani",
"be": "Belarusian",
"bs": "Bosnian",
"bg": "Bulgarian",
"ca": "Catalan",
"zh": "Chinese",
"hr": "Croatian",
"cs": "Czech",
"da": "Danish",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"gl": "Galician",
"de": "German",
"el": "Greek",
"he": "Hebrew",
"hi": "Hindi",
"hu": "Hungarian",
"is": "Icelandic",
"id": "Indonesian",
"it": "Italian",
"ja": "Japanese",
"kn": "Kannada",
"kk": "Kazakh",
"ko": "Korean",
"lv": "Latvian",
"lt": "Lithuanian",
"mk": "Macedonian",
"ms": "Malay",
"mr": "Marathi",
"mi": "Maori",
"ne": "Nepali",
"no": "Norwegian",
"fa": "Persian",
"pl": "Polish",
"pt": "Portuguese",
"ro": "Romanian",
"ru": "Russian",
"sr": "Serbian",
"sk": "Slovak",
"sl": "Slovenian",
"es": "Spanish",
"sw": "Swahili",
"sv": "Swedish",
"tl": "Tagalog",
"ta": "Tamil",
"th": "Thai",
"tr": "Turkish",
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh",
}
class WhisperPosEmbedType(enum.Enum):
SINUSOIDAL = "sinusoidal"
NOPE = "nope"
LEARNED = "learned"
class WhisperAudioInputs(TensorSchema):
@ -184,6 +138,8 @@ class WhisperAttention(nn.Module):
num_heads: int,
bias: bool = True,
attn_type: AttentionType = AttentionType.DECODER,
per_layer_sliding_window: int | None = None,
block_pool_size: int = 1,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
@ -242,7 +198,14 @@ class WhisperAttention(nn.Module):
attn_type=self.attn_type,
)
else: # AttentionType.DECODER (regular decoder self-attention)
self.attn = Attention(
if block_pool_size > 1:
attn_cls = partial(
WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
)
else:
attn_cls = Attention
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,
@ -251,6 +214,7 @@ class WhisperAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
per_layer_sliding_window=per_layer_sliding_window,
)
def _init_qkv(
@ -386,6 +350,9 @@ class WhisperEncoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
is_causal = getattr(config, "is_causal", False)
sliding_window = getattr(config, "sliding_window", None)
block_pool_size = getattr(config, "block_pool_size", 1)
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
@ -393,7 +360,9 @@ class WhisperEncoderLayer(nn.Module):
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
attn_type=AttentionType.ENCODER,
attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
block_pool_size=block_pool_size,
per_layer_sliding_window=sliding_window,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
@ -492,12 +461,21 @@ class WhisperEncoder(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.pos_embed_type = WhisperPosEmbedType(
getattr(config, "pos_embed", "sinusoidal")
)
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
is_causal = getattr(config, "is_causal", False)
Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(
@ -507,29 +485,54 @@ class WhisperEncoder(nn.Module):
)
self.layer_norm = nn.LayerNorm(config.d_model)
maybe_fp32_init_ctx = (
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
)
with (
torch.no_grad(),
maybe_fp32_init_ctx,
if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
raise ValueError(
"Only NOPE position embeddings are supported "
f"for causal models, but got {self.pos_embed_type}"
)
elif self.pos_embed_type in (
WhisperPosEmbedType.SINUSOIDAL,
WhisperPosEmbedType.LEARNED,
):
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)
maybe_fp32_init_ctx = (
set_default_torch_dtype(torch.float32)
if init_in_fp32
else nullcontext()
)
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
with (
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(
self.max_source_positions, embed_dim
)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)
)
def forward_conv(
self, input_features: torch.Tensor | list[torch.Tensor]
) -> torch.Tensor:
hidden_states = []
input_is_batched = False
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.transpose(-1, -2)
embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
embeds.dtype
)
if self.pos_embed_type in (
WhisperPosEmbedType.SINUSOIDAL,
WhisperPosEmbedType.LEARNED,
):
embeds = embeds.transpose(-1, -2)
embeds = (
embeds + self.embed_positions.weight[: embeds.size(-2), :]
).to(embeds.dtype)
elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
embeds = embeds.transpose(-1, -2).to(embeds.dtype)
else:
raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
hidden_states.append(embeds)
input_is_batched = embeds.ndim > 2
# Input to MHA must be B x T x D
@ -539,12 +542,19 @@ class WhisperEncoder(nn.Module):
else:
hidden_states = torch.stack(hidden_states, dim=0)
return hidden_states
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
hidden_states = self.forward_conv(input_features)
return self.forward_layers(hidden_states)
class WhisperDecoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import functools
import math
from dataclasses import replace
import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend_with_overrides,
)
from vllm.v1.kv_cache_interface import AttentionSpec
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
"af": "Afrikaans",
"ar": "Arabic",
"hy": "Armenian",
"az": "Azerbaijani",
"be": "Belarusian",
"bs": "Bosnian",
"bg": "Bulgarian",
"ca": "Catalan",
"zh": "Chinese",
"hr": "Croatian",
"cs": "Czech",
"da": "Danish",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"gl": "Galician",
"de": "German",
"el": "Greek",
"he": "Hebrew",
"hi": "Hindi",
"hu": "Hungarian",
"is": "Icelandic",
"id": "Indonesian",
"it": "Italian",
"ja": "Japanese",
"kn": "Kannada",
"kk": "Kazakh",
"ko": "Korean",
"lv": "Latvian",
"lt": "Lithuanian",
"mk": "Macedonian",
"ms": "Malay",
"mr": "Marathi",
"mi": "Maori",
"ne": "Nepali",
"no": "Norwegian",
"fa": "Persian",
"pl": "Polish",
"pt": "Portuguese",
"ro": "Romanian",
"ru": "Russian",
"sr": "Serbian",
"sk": "Slovak",
"sl": "Slovenian",
"es": "Spanish",
"sw": "Swahili",
"sv": "Swedish",
"tl": "Tagalog",
"ta": "Tamil",
"th": "Thai",
"tr": "Turkish",
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh",
}
def _pad1d(
x: torch.Tensor,
paddings: tuple[int, int],
mode: str = "constant",
value: float = 0.0,
) -> torch.Tensor:
"""Tiny wrapper around F.pad, just to allow for
reflect padding on small input.
If this is the case, we insert extra 0 padding
to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
class WhisperCausalConv1d(nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = True,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
self._stride = self.stride[0]
self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1
self._padding_total = self._effective_kernel_size - self._stride
def forward(self, x: torch.Tensor) -> torch.Tensor:
n_frames = (
x.shape[-1] - self._effective_kernel_size + self._padding_total
) / self._stride + 1
target_length = (math.ceil(n_frames) - 1) * self._stride + (
self._effective_kernel_size - self._padding_total
)
extra_padding = target_length - x.shape[-1]
x = _pad1d(x, (self._padding_total, extra_padding), mode="constant")
return super().forward(x)
@functools.lru_cache
def create_whisper_attention_backend_with_block_pooling(
underlying_attn_backend: AttentionBackend, block_pool_size: int
) -> type[AttentionBackend]:
prefix = "WhisperAttentionWithBlockPooling_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class WhisperAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert kv_cache_spec.num_kv_heads % block_pool_size == 0
kv_cache_spec = replace(
kv_cache_spec,
block_size=kv_cache_spec.block_size * block_pool_size,
num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
)
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
new_common_attn_metadata.query_start_loc *= block_pool_size
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
new_common_attn_metadata.seq_lens *= block_pool_size
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
new_common_attn_metadata.num_actual_tokens *= block_pool_size
new_common_attn_metadata.max_query_len *= block_pool_size
new_common_attn_metadata.max_seq_len *= block_pool_size
original_slot_mapping = common_attn_metadata.slot_mapping
common_prefix_len *= block_pool_size
new_common_attn_metadata.slot_mapping = (
(
original_slot_mapping.unsqueeze(1) * block_pool_size
+ torch.arange(block_pool_size, device=original_slot_mapping.device)
)
.flatten()
.clamp(min=-1)
)
return super().build(
common_prefix_len, new_common_attn_metadata, fast_build
)
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
raise NotImplementedError(
f"{underlying_attn_backend} is not yet supported."
"Contributions to support more backends are much "
"appreciated."
)
attn_backend = subclass_attention_backend_with_overrides(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
overrides={
"get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder,
"get_kv_cache_shape": lambda num_blocks,
block_size,
num_kv_heads,
head_size,
cache_dtype_str: (
2,
num_blocks,
# we stretch each block by `block_pool_size`
block_size * block_pool_size,
num_kv_heads // block_pool_size,
head_size,
), # TODO: generalize to other backends
},
)
return attn_backend
class WhisperAttentionWithBlockPooling(Attention):
"""Attention layer with block pooling."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
logits_soft_cap: float | None = None,
per_layer_sliding_window: int | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
block_pool_size: int = 1,
attn_backend: type[AttentionBackend] | None = None,
**extra_impl_args,
) -> None:
self.block_pool_size = block_pool_size
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=attn_type,
)
attn_backend = create_whisper_attention_backend_with_block_pooling(
underlying_attn_backend, block_pool_size
)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=logits_soft_cap,
per_layer_sliding_window=per_layer_sliding_window,
prefix=prefix,
attn_type=attn_type,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend,
**extra_impl_args,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig):
kv_cache_spec = super().get_kv_cache_spec(vllm_config)
assert isinstance(kv_cache_spec, AttentionSpec)
kv_cache_spec = replace(
kv_cache_spec,
num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads,
)
return kv_cache_spec

View File

@ -184,18 +184,42 @@ def _remap_mistral_audio_args(config: dict) -> dict:
whisper_args = config["multimodal"].pop("whisper_model_args")
encoder_args = whisper_args["encoder_args"]
downsample_args = whisper_args["downsample_args"]
downsample_factor = downsample_args["downsample_factor"]
# make sure that k/v blocks can be allocated with
# unified k/v cache class and pool whisper k/v cache blocks
# with downsample_factor:1 ratio
if encoder_args.get("causal"):
block_pool_size = downsample_factor
config["projection_size"] = downsample_factor * encoder_args["dim"]
else:
block_pool_size = 1
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
if _maybe_sliding_window is None:
sliding_window = None
elif _maybe_sliding_window.isdigit():
sliding_window = int(_maybe_sliding_window)
else:
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
architecture = (
"VoxtralStreamingGeneration"
if encoder_args.get("causal")
else "VoxtralForConditionalGeneration"
)
quant_config = config.get("quantization_config")
config = {
"model_type": "whixtral",
"architectures": ["VoxtralForConditionalGeneration"],
"model_type": "voxtral",
"architectures": [architecture],
"text_config": PretrainedConfig.from_dict(config),
"audio_config": WhisperConfig(
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
window_size=encoder_args["audio_encoding_args"]["window_size"],
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
downsample_factor=downsample_args["downsample_factor"],
downsample_factor=downsample_factor,
d_model=encoder_args["dim"],
encoder_layers=encoder_args["n_layers"],
encoder_ffn_dim=encoder_args["hidden_dim"],
@ -203,6 +227,10 @@ def _remap_mistral_audio_args(config: dict) -> dict:
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
is_causal=encoder_args.get("causal", False),
sliding_window=sliding_window,
block_pool_size=block_pool_size,
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
),
}
if quant_config:

View File

@ -835,6 +835,15 @@ def subclass_attention_backend(
)
def subclass_attention_backend_with_overrides(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],
overrides: dict[str, Any],
) -> type[AttentionBackend]:
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(name, (attention_backend_cls,), overrides)
def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,

View File

@ -2457,6 +2457,17 @@ class GPUModelRunner(
return round_up(num_scheduled_tokens, tp_size)
return num_scheduled_tokens
def _prepare_mm_inputs(
self, num_tokens: int
) -> tuple[torch.Tensor | None, torch.Tensor]:
if self.model.requires_raw_input_tokens:
input_ids = self.input_ids.gpu[:num_tokens]
else:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
return input_ids, inputs_embeds
def _preprocess(
self,
scheduler_output: "SchedulerOutput",
@ -2499,8 +2510,7 @@ class GPUModelRunner(
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens)
model_kwargs = {
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
@ -4220,8 +4230,8 @@ class GPUModelRunner(
assert num_tokens_padded <= self.max_num_tokens
model_kwargs = self._init_model_kwargs(num_tokens_padded)
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded)
model_kwargs = {
**model_kwargs,
**self._dummy_mm_kwargs(num_reqs),