diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index 0eaef49e2395c..9f8415c0c390c 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -111,4 +111,5 @@ async def test_online_serving(client, audio_assets: AudioTestAssets): assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] + assert choice.message.content == "In the first audio clip, you hear a brief" assert choice.finish_reason == "length" diff --git a/tests/models/registry.py b/tests/models/registry.py index 01f7fe64aa850..2922414cdaa6a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -860,6 +860,11 @@ _MULTIMODAL_EXAMPLE_MODELS = { # disable this temporarily until we support HF format is_available_online=False, ), + "VoxtralStreamingGeneration": _HfExamplesInfo( + "", + # disable this temporarily until we support HF format + is_available_online=False, + ), # [Encoder-decoder] "WhisperForConditionalGeneration": _HfExamplesInfo( "openai/whisper-large-v3-turbo", diff --git a/vllm/config/model.py b/vllm/config/model.py index f2f39ac6af022..6e199adbf3ee6 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1542,6 +1542,10 @@ class ModelConfig: def is_multimodal_raw_input_only_model(self) -> bool: return self._model_info.supports_multimodal_raw_input_only + @property + def requires_raw_input_tokens(self) -> bool: + return self._model_info.requires_raw_input_tokens + @property def is_cross_encoder(self) -> bool: return ( diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 67c65a44dcf7f..f8288b92ebfae 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -94,6 +94,12 @@ class SupportsMultiModal(Protocol): `multimodal_config.mm_encoder_tp_mode="data"`. """ + requires_raw_input_tokens: ClassVar[bool] = False + """ + A flag that indicates this model processes input id tokens + in their raw form and not input embeddings. 
+ """ + merge_by_field_config: ClassVar[bool | None] = None """ [DEPRECATED] A flag that indicates which implementation of @@ -306,6 +312,10 @@ def supports_multimodal_raw_input_only(model: type[object] | object) -> bool: return getattr(model, "supports_multimodal_raw_input_only", False) +def requires_raw_input_tokens(model: type[object] | object) -> bool: + return getattr(model, "requires_raw_input_tokens", False) + + def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool: return getattr(model, "supports_encoder_tp_data", False) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 89af625b685f5..fd39afe259ae3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -46,6 +46,7 @@ from .interfaces import ( has_noops, is_attention_free, is_hybrid, + requires_raw_input_tokens, supports_cross_encoding, supports_mamba_prefix_caching, supports_multimodal, @@ -422,6 +423,7 @@ _MULTIMODAL_MODELS = { ), "UltravoxModel": ("ultravox", "UltravoxModel"), "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 + "VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"), # noqa: E501 # [Encoder-decoder] "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } @@ -539,6 +541,7 @@ class _ModelInfo: supports_cross_encoding: bool supports_multimodal: bool supports_multimodal_raw_input_only: bool + requires_raw_input_tokens: bool supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool @@ -562,6 +565,7 @@ class _ModelInfo: supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( model ), + requires_raw_input_tokens=requires_raw_input_tokens(model), supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data( model ), diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 331f0c54ecfbc..cbba1af89190c 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect import math from collections.abc import Iterable, Mapping, Sequence from functools import cached_property @@ -116,10 +117,7 @@ class VoxtralProcessorAdapter: self, audio_length: int, ) -> int: - pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames( - audio_length, self.sampling_rate - ) - return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate)) + return ceil(audio_length / (self.sampling_rate // self.frame_rate)) def __call__( self, @@ -158,7 +156,14 @@ class VoxtralProcessorAdapter: assert audio.ndim == 1 # pad if necessary - audio = self._audio_processor.pad(audio, self.sampling_rate) + # TODO(Patrick) - remove once mistral-common is bumped + sig = inspect.signature(self._audio_processor.pad) + if "is_online_streaming" in sig.parameters: + audio = self._audio_processor.pad( + audio, self.sampling_rate, is_online_streaming=False + ) + else: + audio = self._audio_processor.pad(audio, self.sampling_rate) audio_tokens = [self.begin_audio_token_id] + [ self.audio_token_id @@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: remapping_rules = [ + (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"), (r"mm_whisper_embeddings\.(.*)", r"\1"), 
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"), ( @@ -535,13 +541,16 @@ class VoxtralForConditionalGeneration( def llm_weights_generator(): nonlocal loaded_weights for name, w in weights: - is_encoder = ( - name.startswith("mm_whisper_embeddings") - and not name.startswith("mm_whisper_embeddings.tok_embeddings") - and not name.startswith( - "mm_whisper_embeddings.audio_language_projection" + is_encoder = False + for k in [ + "mm_whisper_embeddings", + "mm_streams_embeddings.embedding_module", + ]: + is_encoder |= ( + name.startswith(k) + and not name.startswith(f"{k}.tok_embeddings") + and not name.startswith(f"{k}.audio_language_projection") ) - ) for pattern, repl in remapping_rules: if re.fullmatch(pattern, name): @@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module): packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} mistral_remapping = [ + (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"), ( r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1", @@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module): r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1", ), + ( + r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)", + r"whisper_encoder.conv1.\1", + ), # noqa: E501 + ( + r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)", + r"whisper_encoder.conv2.\1", + ), # noqa: E501 ( r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501 r"whisper_encoder.layers.\1.self_attn.\2_proj.\3", diff --git a/vllm/model_executor/models/voxtral_streaming.py b/vllm/model_executor/models/voxtral_streaming.py new file mode 100644 index 0000000000000..2e79e24e6f194 --- /dev/null +++ b/vllm/model_executor/models/voxtral_streaming.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Mapping + +import torch + +from vllm.config.vllm import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import MultiModalEmbeddings +from vllm.model_executor.models.voxtral import ( + VoxtralDummyInputsBuilder, + VoxtralForConditionalGeneration, + VoxtralMultiModalProcessor, + VoxtralProcessingInfo, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache +from vllm.multimodal.inputs import ( + MultiModalKwargsOptionalItems, +) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import ( + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .utils import ( + _flatten_embeddings, +) + +logger = init_logger(__name__) + + +class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor): + def __init__( + self, + info: _I, + dummy_inputs: BaseDummyInputsBuilder[_I], + *, + cache: BaseMultiModalProcessorCache | None = None, + ) -> None: + # streaming can't make use of a cache yet + super().__init__(info, dummy_inputs, cache=None) + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + prompt_ids: list[int], + mm_kwargs: MultiModalKwargsOptionalItems, + mm_prompt_updates: MultiModalPromptUpdates, + is_update_applied: bool, + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: + # there are no placeholder audio tokens for streaming + # so we need to build the place 
placeholder positions manually + + # in streaming there is always only one audio input + audios = mm_kwargs.get("audio", []) + assert len(audios) == 1, ( + f"Expected only one audio input for streaming, got {mm_kwargs=}" + ) + tokenizer = self.info.get_tokenizer() + audio_config = tokenizer.instruct.audio_encoder.audio_config + + num_audio_samples = audios[0]["audio_arrays"].data.shape[0] + length = audio_config.num_audio_tokens(num_audio_samples) + + features_info = PlaceholderFeaturesInfo( + modality="audio", + item_idx=0, + start_idx=0, + tokens=length + * [0], # only used for length computation, so we can take dummy inputs + is_embed=None, + ) + return prompt_ids, {"audio": [features_info]} + + +class TimeEmbedding(torch.nn.Module): + """Sinusoidal Embedding for encoding time""" + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = torch.exp( + -math.log(self.theta) + * torch.arange(self.dim // 2).float() + / (self.dim // 2) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, t: torch.Tensor) -> torch.Tensor: + t = t[..., None] # (B,) -> (B, 1) or (B, T) -> (B, T, 1) + inv_freq = self.inv_freq.to(device=t.device, dtype=t.dtype) + emb = ( + t * inv_freq + ) # (B, 1) x (D/2,) -> (B, D/2) or (B, T, 1) x (D/2,) -> (B, T, D/2) + return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D) + + +@MULTIMODAL_REGISTRY.register_processor( + VoxtralStreamingMultiModalProcessor, + info=VoxtralProcessingInfo, + dummy_inputs=VoxtralDummyInputsBuilder, +) +class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): + requires_raw_input_tokens = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.time_embedding: TimeEmbedding = TimeEmbedding( + dim=self.config.text_config.hidden_size + ) + + audio_config = self.tokenizer.instruct.audio_encoder.audio_config + _n_delay_tokens = ( + audio_config.frame_rate * audio_config.transcription_delay_ms / 1000 + ) + assert _n_delay_tokens.is_integer(), ( + f"n_delay_tokens must be integer, got {_n_delay_tokens}" + ) + + self.n_delay_tokens = int(_n_delay_tokens) + + @property + def audio_config(self): + return self.tokenizer.instruct.audio_encoder.audio_config + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, + ) -> torch.Tensor: + """Pass post-conv embeddings directly as input""" + # for streaming we simply flatten the multimodal embeddings + # to be in tensor format, we treat the input ids later + assert multimodal_embeddings is not None + assert len(multimodal_embeddings) > 0, ( + "For streaming you must provide a multimodal_embedding at every step." 
+ ) + mm_embeds_flat = _flatten_embeddings(multimodal_embeddings) + return mm_embeds_flat + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + assert inputs_embeds is not None + assert input_ids is not None + + pool_size = self.config.audio_config.block_pool_size + inputs_embeds = inputs_embeds.view( + inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size + ) + + audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers( + inputs_embeds + ) + + num_tokens, audio_hidden_size = audio_hidden_states.shape + assert num_tokens % self.downsample_factor == 0 + audio_hidden_states = audio_hidden_states.reshape( + num_tokens // self.downsample_factor, + audio_hidden_size * self.downsample_factor, + ) + audio_text_embeds = self.audio_language_adapter(audio_hidden_states) + + text_embeds = self.language_model.embed_input_ids(input_ids) + + # sum pool text and audio embeddings + inputs_embeds = audio_text_embeds + text_embeds + + time_tensor = torch.tensor( + [self.n_delay_tokens], + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) + inputs_embeds = inputs_embeds + self.time_embedding(time_tensor) + + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + + return hidden_states + + def embed_multimodal( + self, **kwargs + ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None: + """Transform audio waveforms -> initial whisper post-conv embeddings""" + audio_inputs = self._parse_and_validate_audio_arrays(**kwargs) + + assert audio_inputs is not None, ( + "For streaming you must provide an audio input at every step." + ) + + multiple_of = self.audio_config.raw_audio_length_per_tok + assert all( + (this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs + ), ( + f"Every input audio waveform has to be a multiple of {multiple_of}, but" + f" one is {this_audio} with {(this_audio / multiple_of)=}." + ) + + mel_features = [ + self.whisper_encoder.compute_whisper_melspec(audio).to( + self.whisper_encoder.dtype + ) + for audio in audio_inputs + ] + seq_lens = [mel.shape[1] for mel in mel_features] + # [total_num_20ms_frames, hidden_size] + audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv( + mel_features + )[0] + conv_stride = self.whisper_encoder.whisper_encoder.total_stride + audio_embeddings_per_sample = audio_embeddings.split( + [s // conv_stride for s in seq_lens], dim=0 + ) + + # audio_embeddings per sample need to be divisible by 4 + pool_size = self.config.audio_config.block_pool_size + assert all( + (this_shape := sample.shape[0]) % pool_size == 0 + for sample in audio_embeddings_per_sample + ), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}." 
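[Editor's sketch, not part of the patch] The view() re-layout that follows folds `block_pool_size` consecutive post-conv frames into the hidden dimension, so each scheduler-visible token carries one block of frames; forward() later unfolds them before running the encoder layers. A minimal round trip with hypothetical sizes:

    import torch

    pool_size, hidden_size = 4, 8
    frames = torch.randn(12, hidden_size)                            # (T, D), with T % pool_size == 0
    pooled = frames.view(12 // pool_size, hidden_size * pool_size)   # (T/4, 4*D), handed over as "embeddings"
    unfolded = pooled.view(12, hidden_size)                          # what forward() recovers
    assert torch.equal(frames, unfolded)                             # the pooling is a pure re-layout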
+ + audio_embeddings_per_sample = [ + e.view(e.shape[0] // pool_size, e.shape[1] * pool_size) + for e in audio_embeddings_per_sample + ] + return audio_embeddings_per_sample diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index f5a1e75d99617..f1bae28debad2 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext +from functools import partial from typing import Annotated, Literal, cast import numpy as np @@ -16,7 +18,10 @@ from transformers import ( ) from transformers.models.whisper.modeling_whisper import sinusoids -from vllm.attention.layer import Attention, AttentionType +from vllm.attention.backends.abstract import ( + AttentionType, +) +from vllm.attention.layer import Attention from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig @@ -34,6 +39,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.whisper_utils import ( + ISO639_1_SUPPORTED_LANGS, + WhisperAttentionWithBlockPooling, + WhisperCausalConv1d, +) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, @@ -64,67 +74,11 @@ from .utils import ( logger = init_logger(__name__) -# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages -ISO639_1_SUPPORTED_LANGS = { - "af": "Afrikaans", - "ar": "Arabic", - "hy": "Armenian", - "az": "Azerbaijani", - "be": "Belarusian", - "bs": "Bosnian", - "bg": "Bulgarian", - "ca": "Catalan", - "zh": "Chinese", - "hr": "Croatian", - "cs": "Czech", - "da": "Danish", - "nl": "Dutch", - "en": "English", - "et": "Estonian", - "fi": "Finnish", - "fr": "French", - "gl": "Galician", - "de": "German", - "el": "Greek", - "he": "Hebrew", - "hi": "Hindi", - "hu": "Hungarian", - "is": "Icelandic", - "id": "Indonesian", - "it": "Italian", - "ja": "Japanese", - "kn": "Kannada", - "kk": "Kazakh", - "ko": "Korean", - "lv": "Latvian", - "lt": "Lithuanian", - "mk": "Macedonian", - "ms": "Malay", - "mr": "Marathi", - "mi": "Maori", - "ne": "Nepali", - "no": "Norwegian", - "fa": "Persian", - "pl": "Polish", - "pt": "Portuguese", - "ro": "Romanian", - "ru": "Russian", - "sr": "Serbian", - "sk": "Slovak", - "sl": "Slovenian", - "es": "Spanish", - "sw": "Swahili", - "sv": "Swedish", - "tl": "Tagalog", - "ta": "Tamil", - "th": "Thai", - "tr": "Turkish", - "uk": "Ukrainian", - "ur": "Urdu", - "vi": "Vietnamese", - "cy": "Welsh", -} +class WhisperPosEmbedType(enum.Enum): + SINUSOIDAL = "sinusoidal" + NOPE = "nope" + LEARNED = "learned" class WhisperAudioInputs(TensorSchema): @@ -184,6 +138,8 @@ class WhisperAttention(nn.Module): num_heads: int, bias: bool = True, attn_type: AttentionType = AttentionType.DECODER, + per_layer_sliding_window: int | None = None, + block_pool_size: int = 1, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", 
@@ -242,7 +198,14 @@ class WhisperAttention(nn.Module): attn_type=self.attn_type, ) else: # AttentionType.DECODER (regular decoder self-attention) - self.attn = Attention( + if block_pool_size > 1: + attn_cls = partial( + WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size + ) + else: + attn_cls = Attention + + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, @@ -251,6 +214,7 @@ class WhisperAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.attn", attn_type=self.attn_type, + per_layer_sliding_window=per_layer_sliding_window, ) def _init_qkv( @@ -386,6 +350,9 @@ class WhisperEncoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + is_causal = getattr(config, "is_causal", False) + sliding_window = getattr(config, "sliding_window", None) + block_pool_size = getattr(config, "block_pool_size", 1) cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -393,7 +360,9 @@ class WhisperEncoderLayer(nn.Module): self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, - attn_type=AttentionType.ENCODER, + attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER, + block_pool_size=block_pool_size, + per_layer_sliding_window=sliding_window, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -492,12 +461,21 @@ class WhisperEncoder(nn.Module): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model + + self.pos_embed_type = WhisperPosEmbedType( + getattr(config, "pos_embed", "sinusoidal") + ) self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) - self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + is_causal = getattr(config, "is_causal", False) + Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1) + + self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3) + self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3) + + self.total_stride = self.conv1.stride[0] * self.conv2.stride[0] self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperEncoderLayer( @@ -507,29 +485,54 @@ class WhisperEncoder(nn.Module): ) self.layer_norm = nn.LayerNorm(config.d_model) - maybe_fp32_init_ctx = ( - set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext() - ) - - with ( - torch.no_grad(), - maybe_fp32_init_ctx, + if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE: + raise ValueError( + "Only NOPE position embeddings are supported " + f"for causal models, but got {self.pos_embed_type}" + ) + elif self.pos_embed_type in ( + WhisperPosEmbedType.SINUSOIDAL, + WhisperPosEmbedType.LEARNED, ): - self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) - self.embed_positions.weight.copy_( - sinusoids(*self.embed_positions.weight.shape) + maybe_fp32_init_ctx = ( + set_default_torch_dtype(torch.float32) + if init_in_fp32 + else nullcontext() ) - def forward(self, input_features: torch.Tensor | list[torch.Tensor]): + with ( + torch.no_grad(), + maybe_fp32_init_ctx, + ): + self.embed_positions = nn.Embedding( + self.max_source_positions, embed_dim + ) + 
self.embed_positions.weight.copy_( + sinusoids(*self.embed_positions.weight.shape) + ) + + def forward_conv( + self, input_features: torch.Tensor | list[torch.Tensor] + ) -> torch.Tensor: hidden_states = [] input_is_batched = False for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv2(embeds)) - embeds = embeds.transpose(-1, -2) - embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to( - embeds.dtype - ) + + if self.pos_embed_type in ( + WhisperPosEmbedType.SINUSOIDAL, + WhisperPosEmbedType.LEARNED, + ): + embeds = embeds.transpose(-1, -2) + embeds = ( + embeds + self.embed_positions.weight[: embeds.size(-2), :] + ).to(embeds.dtype) + elif self.pos_embed_type == WhisperPosEmbedType.NOPE: + embeds = embeds.transpose(-1, -2).to(embeds.dtype) + else: + raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}") + hidden_states.append(embeds) input_is_batched = embeds.ndim > 2 # Input to MHA must be B x T x D @@ -539,12 +542,19 @@ class WhisperEncoder(nn.Module): else: hidden_states = torch.stack(hidden_states, dim=0) + return hidden_states + + def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor: for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) hidden_states = self.layer_norm(hidden_states) return hidden_states + def forward(self, input_features: torch.Tensor | list[torch.Tensor]): + hidden_states = self.forward_conv(input_features) + return self.forward_layers(hidden_states) + class WhisperDecoder(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/whisper_utils.py b/vllm/model_executor/models/whisper_utils.py new file mode 100644 index 0000000000000..077b4aff6fec9 --- /dev/null +++ b/vllm/model_executor/models/whisper_utils.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +import functools +import math +from dataclasses import replace + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + subclass_attention_backend_with_overrides, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": 
"Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh", +} + + +def _pad1d( + x: torch.Tensor, + paddings: tuple[int, int], + mode: str = "constant", + value: float = 0.0, +) -> torch.Tensor: + """Tiny wrapper around F.pad, just to allow for + reflect padding on small input. + If this is the case, we insert extra 0 padding + to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class WhisperCausalConv1d(nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + self._stride = self.stride[0] + self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1 + self._padding_total = self._effective_kernel_size - self._stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_frames = ( + x.shape[-1] - self._effective_kernel_size + self._padding_total + ) / self._stride + 1 + target_length = (math.ceil(n_frames) - 1) * self._stride + ( + self._effective_kernel_size - self._padding_total + ) + extra_padding = target_length - x.shape[-1] + x = _pad1d(x, (self._padding_total, extra_padding), mode="constant") + return super().forward(x) + + +@functools.lru_cache +def create_whisper_attention_backend_with_block_pooling( + underlying_attn_backend: AttentionBackend, block_pool_size: int +) -> type[AttentionBackend]: + prefix = "WhisperAttentionWithBlockPooling_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class WhisperAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + assert kv_cache_spec.num_kv_heads % block_pool_size == 0 + kv_cache_spec = replace( + kv_cache_spec, + block_size=kv_cache_spec.block_size * block_pool_size, + num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size, + ) + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + new_common_attn_metadata = copy.deepcopy(common_attn_metadata) + new_common_attn_metadata.query_start_loc *= block_pool_size + new_common_attn_metadata.query_start_loc_cpu *= block_pool_size + new_common_attn_metadata.seq_lens *= block_pool_size + new_common_attn_metadata._seq_lens_cpu *= block_pool_size + new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size + new_common_attn_metadata.num_actual_tokens *= block_pool_size + new_common_attn_metadata.max_query_len *= block_pool_size + 
new_common_attn_metadata.max_seq_len *= block_pool_size + original_slot_mapping = common_attn_metadata.slot_mapping + common_prefix_len *= block_pool_size + new_common_attn_metadata.slot_mapping = ( + ( + original_slot_mapping.unsqueeze(1) * block_pool_size + + torch.arange(block_pool_size, device=original_slot_mapping.device) + ) + .flatten() + .clamp(min=-1) + ) + return super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) + + if not issubclass(underlying_attn_backend, FlashAttentionBackend): + raise NotImplementedError( + f"{underlying_attn_backend} is not yet supported." + "Contributions to support more backends are much " + "appreciated." + ) + + attn_backend = subclass_attention_backend_with_overrides( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + overrides={ + "get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder, + "get_kv_cache_shape": lambda num_blocks, + block_size, + num_kv_heads, + head_size, + cache_dtype_str: ( + 2, + num_blocks, + # we stretch each block by `block_pool_size` + block_size * block_pool_size, + num_kv_heads // block_pool_size, + head_size, + ), # TODO: generalize to other backends + }, + ) + + return attn_backend + + +class WhisperAttentionWithBlockPooling(Attention): + """Attention layer with block pooling.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + logits_soft_cap: float | None = None, + per_layer_sliding_window: int | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + block_pool_size: int = 1, + attn_backend: type[AttentionBackend] | None = None, + **extra_impl_args, + ) -> None: + self.block_pool_size = block_pool_size + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + underlying_attn_backend = get_attn_backend( + head_size, + dtype, + kv_cache_dtype, + block_size, + attn_type=attn_type, + ) + attn_backend = create_whisper_attention_backend_with_block_pooling( + underlying_attn_backend, block_pool_size + ) + + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=logits_soft_cap, + per_layer_sliding_window=per_layer_sliding_window, + prefix=prefix, + attn_type=attn_type, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + attn_backend=attn_backend, + **extra_impl_args, + ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig): + kv_cache_spec = super().get_kv_cache_spec(vllm_config) + assert isinstance(kv_cache_spec, AttentionSpec) + kv_cache_spec = replace( + kv_cache_spec, + num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads, + ) + return kv_cache_spec diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index d59169d95f0c9..4776c892eb722 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -184,18 +184,42 @@ def _remap_mistral_audio_args(config: dict) -> dict: whisper_args = config["multimodal"].pop("whisper_model_args") encoder_args = whisper_args["encoder_args"] 
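[Editor's sketch, not part of the patch] The `get_kv_cache_shape` override above stretches each cache block by `block_pool_size` while dividing the KV-head dimension by the same factor, so the element count per allocated block is unchanged; the config remap that follows then ties `block_pool_size` to the audio downsample factor for causal encoders. A quick arithmetic check with hypothetical sizes:

    block_size, num_kv_heads, head_size, block_pool_size = 16, 8, 64, 4
    stretched_block = block_size * block_pool_size   # 64 slots per block
    pooled_heads = num_kv_heads // block_pool_size   # 2 "packed" KV heads
    assert stretched_block * pooled_heads * head_size == block_size * num_kv_heads * head_size
    # i.e. stretching the block and shrinking the head count cancel out, so pooling
    # several encoder frames into one token slot does not change the bytes per block.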
downsample_args = whisper_args["downsample_args"] + downsample_factor = downsample_args["downsample_factor"] + + # make sure that k/v blocks can be allocated with + # unified k/v cache class and pool whisper k/v cache blocks + # with downsample_factor:1 ratio + if encoder_args.get("causal"): + block_pool_size = downsample_factor + config["projection_size"] = downsample_factor * encoder_args["dim"] + else: + block_pool_size = 1 + + _maybe_sliding_window = encoder_args.get("ragged_attention", None) + if _maybe_sliding_window is None: + sliding_window = None + elif _maybe_sliding_window.isdigit(): + sliding_window = int(_maybe_sliding_window) + else: + raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}") + + architecture = ( + "VoxtralStreamingGeneration" + if encoder_args.get("causal") + else "VoxtralForConditionalGeneration" + ) quant_config = config.get("quantization_config") config = { - "model_type": "whixtral", - "architectures": ["VoxtralForConditionalGeneration"], + "model_type": "voxtral", + "architectures": [architecture], "text_config": PretrainedConfig.from_dict(config), "audio_config": WhisperConfig( num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"], window_size=encoder_args["audio_encoding_args"]["window_size"], sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"], hop_length=encoder_args["audio_encoding_args"]["hop_length"], - downsample_factor=downsample_args["downsample_factor"], + downsample_factor=downsample_factor, d_model=encoder_args["dim"], encoder_layers=encoder_args["n_layers"], encoder_ffn_dim=encoder_args["hidden_dim"], @@ -203,6 +227,10 @@ def _remap_mistral_audio_args(config: dict) -> dict: vocab_size=encoder_args["vocab_size"], max_source_positions=encoder_args["max_source_positions"], is_encoder_decoder=False, # Override WhisperConfig default + is_causal=encoder_args.get("causal", False), + sliding_window=sliding_window, + block_pool_size=block_pool_size, + pos_embed=encoder_args.get("pos_embed", "sinusoidal"), ), } if quant_config: diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 56763f4b52539..6b94f786a26b2 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -835,6 +835,15 @@ def subclass_attention_backend( ) +def subclass_attention_backend_with_overrides( + name_prefix: str, + attention_backend_cls: type[AttentionBackend], + overrides: dict[str, Any], +) -> type[AttentionBackend]: + name: str = name_prefix + attention_backend_cls.__name__ # type: ignore + return type(name, (attention_backend_cls,), overrides) + + def split_decodes_prefills_and_extends( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 455406394d3ec..00c585aaaacbb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2457,6 +2457,17 @@ class GPUModelRunner( return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens + def _prepare_mm_inputs( + self, num_tokens: int + ) -> tuple[torch.Tensor | None, torch.Tensor]: + if self.model.requires_raw_input_tokens: + input_ids = self.input_ids.gpu[:num_tokens] + else: + input_ids = None + + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + return input_ids, inputs_embeds + def _preprocess( self, scheduler_output: "SchedulerOutput", @@ -2499,8 +2510,7 @@ class GPUModelRunner( # TODO(woosuk): Avoid the copy. Optimize. 
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens) model_kwargs = { **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), @@ -4220,8 +4230,8 @@ class GPUModelRunner( assert num_tokens_padded <= self.max_num_tokens model_kwargs = self._init_model_kwargs(num_tokens_padded) if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + model_kwargs = { **model_kwargs, **self._dummy_mm_kwargs(num_reqs),
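[Editor's sketch, not part of the patch] Putting the pieces together: with `requires_raw_input_tokens` set, `_prepare_mm_inputs` now passes both `input_ids` and `inputs_embeds`, and the streaming model sums the per-frame audio features with the text-token embeddings plus a sinusoidal embedding of the fixed transcription delay. A toy version of that decode step, with all sizes hypothetical:

    import math
    import torch

    def time_embedding(t: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
        # sinusoidal embedding of the delay (in tokens), mirroring TimeEmbedding above
        inv_freq = torch.exp(-math.log(theta) * torch.arange(dim // 2).float() / (dim // 2))
        emb = t[..., None] * inv_freq
        return torch.cat((emb.cos(), emb.sin()), dim=-1)

    hidden_size, n_frames = 16, 5                      # hypothetical
    audio_feats = torch.randn(n_frames, hidden_size)   # adapter output, one row per audio frame
    text_embeds = torch.randn(n_frames, hidden_size)   # embeddings of the raw input token ids
    n_delay_tokens = torch.tensor([6.0])               # frame_rate * transcription_delay_ms / 1000
    inputs_embeds = audio_feats + text_embeds + time_embedding(n_delay_tokens, hidden_size)
    assert inputs_embeds.shape == (n_frames, hidden_size)

In the patch, `audio_feats` corresponds to the `audio_language_adapter` output inside `VoxtralStreamingGeneration.forward()` and the delay is fixed once at model init from the audio config.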