From 80e78d02aceaf55528926c800e89568d1358e1dc Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 11 Mar 2025 19:27:10 -0700 Subject: [PATCH] [Model] Extend Ultravox to accept audio longer than 30s (#13631) Signed-off-by: Farzad Abdolhosseini --- .../audio_language/test_ultravox.py | 2 +- .../multimodal/processing/test_common.py | 57 +++- tests/models/registry.py | 3 +- vllm/model_executor/models/ultravox.py | 248 +++++++++++++----- 4 files changed, 231 insertions(+), 79 deletions(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 13433b042258c..f8770bca4e913 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_4" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" AudioTuple = tuple[np.ndarray, int] diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 629d1012d18e0..e64b703cc5201 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +import copy from functools import partial +from typing import Optional import numpy as np import pytest @@ -21,6 +23,7 @@ def _test_processing_correctness( hit_rate: float, num_batches: int, simplify_rate: float, + ignore_mm_keys: Optional[list[str]] = None, ): model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") @@ -123,8 +126,10 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert baseline_result == cached_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_mm_kwargs_keys( + baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( + cached_result, ignore_mm_keys), ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") baseline_tokenized_result = baseline_processor.apply( tokenizer.encode(prompt, **tokenizer_encode_kwargs), @@ -132,8 +137,10 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert baseline_result == baseline_tokenized_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_mm_kwargs_keys( + baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( + baseline_tokenized_result, ignore_mm_keys), ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") cached_tokenized_result = cached_processor.apply( tokenizer.encode(prompt, **tokenizer_encode_kwargs), @@ -141,8 +148,10 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert cached_result == cached_tokenized_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_mm_kwargs_keys( + cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys( + cached_tokenized_result, ignore_mm_keys), ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") # yapf: disable @@ -173,7 +182,7 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_4", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", @@ -188,11 +197,19 @@ def test_processing_correctness( num_batches: int, simplify_rate: float, ): + ignore_mm_keys = None + if 'ultravox' 
in model_id: + # In Ultravox, the audio_features can be different depending on padding + # The slight difference should not be a problem though, since + # attention_mask lets us ignore the difference. + ignore_mm_keys = ['audio_features'] + _test_processing_correctness( model_id, hit_rate=hit_rate, num_batches=num_batches, simplify_rate=simplify_rate, + ignore_mm_keys=ignore_mm_keys, ) @@ -221,3 +238,29 @@ def test_processing_correctness_phi3v( num_batches=num_batches, simplify_rate=simplify_rate, ) + + +def _drop_mm_kwargs_keys(result: dict, + ignore_mm_keys: Optional[list[str]] = None) -> dict: + """Drop specified keys from result['mm_kwargs']. + + This is mainly to avoid doing exact match of audio_features in ultravox. + + Args: + result: Result to drop keys from + ignore_mm_keys: List of keys to ignore, e.g. ['audio_features'] + """ + if not ignore_mm_keys: + return result + + if 'mm_kwargs' in result: + result = copy.deepcopy(result) + mm_kwargs = result['mm_kwargs'] + for key in ignore_mm_keys: + mm_kwargs.pop(key, None) + for items in mm_kwargs._items_by_modality.values(): + for item in items: + for key in ignore_mm_keys: + item.pop(key, None) + + return result diff --git a/tests/models/registry.py b/tests/models/registry.py index 3c3247eaf3e99..a7a88d1990479 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -284,8 +284,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 min_transformers_version="4.49"), # noqa: E501 - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_4", - extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501 + "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 trust_remote_code=True), # [Encoder-decoder] # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 90a833a83b66f..f639b8d8f9bed 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.utils.checkpoint @@ -44,12 +44,23 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, _AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 +_MAX_ENCODER_BATCH_SIZE = 16 class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: NestedTensors - """Shape: `(batch_size, num_audios, 80, M)`""" + """Shape: `(batch_size, num_chunks, 80, M)`""" + lens: NestedTensors + """ + Length of the audio frames. Used for attention mask in WhisperEncoder. + Shape: `(batch_size, num_chunks)` + """ + token_len: NestedTensors + """ + Length of the audio tokens. Used for flattening the audio features. + Shape: `(batch_size, num_chunks)` + """ class UltravoxAudioEmbeddingInputs(TypedDict): @@ -78,6 +89,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo): # token, thus we override placeholder with a reserved special # token. 
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE + hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN return hf_processor def get_feature_extractor( @@ -104,7 +116,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo): max_audio_tokens = math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) - return {"audio": max_audio_tokens} + return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE} class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] @@ -118,7 +130,8 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate - audio_len = feature_extractor.chunk_length * sampling_rate + audio_len = (feature_extractor.chunk_length * sampling_rate * + _MAX_ENCODER_BATCH_SIZE) num_audios = mm_counts.get("audio", 0) mm_data = { @@ -160,41 +173,38 @@ class UltravoxMultiModalProcessor( mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, + include_audio_num_chunks=True, ) - # Ultravox processor doesn't support multiple inputs, - # therefore we need to input text and audio one by one - audio_features, audio_token_len = [], [] - shared_outputs = {} - for audio in audios: - # NOTE: Ultravox processor accepts "audio" instead of "audios" - item_processor_data = dict(**mm_data, audio=audio) + item_processor_data = dict(**mm_data, audios=audios) - item_outputs = super()._call_hf_processor( - prompt=prompt, - mm_data=item_processor_data, - mm_kwargs=mm_kwargs, - ) - - audio_features.append(item_outputs.pop("audio_values")[0]) - audio_token_len.append(item_outputs.pop("audio_token_len").item()) - shared_outputs = item_outputs - - combined_outputs = dict( - **shared_outputs, - audio_features=audio_features, - audio_token_len=audio_token_len, + output = super()._call_hf_processor( + prompt=prompt, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, ) - return BatchFeature(combined_outputs) + output['audio_features'] = output.pop('audio_values') + + return output def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: + num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0)) return dict( - audio_features=MultiModalFieldConfig.batched("audio"), - audio_token_len=MultiModalFieldConfig.batched("audio"), + # to handle longer than 30s audio, each audio might be split + # into multiple chunks as such, their batch dimension can be + # higher than the number of audio samples + audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + audio_token_len=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + audio_lens=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + # num_chunks can convert audio_chunked to audio batch dimension + audio_num_chunks=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"), ) @@ -205,14 +215,23 @@ class UltravoxMultiModalProcessor( out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - tokenizer = self.info.get_tokenizer() - vocab = tokenizer.get_vocab() - replacement_id = vocab[ - hf_processor.audio_token_replacement] # type: ignore + replacement_id = hf_processor.audio_replacement_token_id # type: ignore + + # Each audio can be split into multiple chunks. 
+ # chunks_start_idx[i] indicates the start index of the chunks + # belonging to the i-th audio. + num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0)) + chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, + dim=0, + dtype=torch.int32) + chunks_start_idx = torch.cat( + [torch.tensor([0], dtype=torch.int32), chunks_start_idx]) def get_replacement_ultravox(item_idx: int): - audio_token_len = out_mm_kwargs["audio_token_len"][item_idx] + start = chunks_start_idx[item_idx] + end = chunks_start_idx[item_idx + 1] + audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum() return [replacement_id] * int(audio_token_len) # type: ignore return [ @@ -304,12 +323,49 @@ class ModifiedWhisperEncoder(WhisperEncoder): base_model_prefix = "model.encoder" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.config.is_decoder = False + + @property + def max_context_length(self): + return (self.config.max_source_positions * self.conv1.stride[0] * + self.conv2.stride[0]) + + def get_attention_mask_by_audio_len(self, + audio_lens: Optional[torch.Tensor], + hidden_states: torch.Tensor): + """ + Create attention mask based on audio lengths to mask out padding tokens + For each sample in batch: + - Convert raw audio length to feature length after convolutions + - Create bool mask: True for valid positions and False for padding + - Convert to attention mask format expected by transformer layers + (1.0 for positions to attend to, large negative for positions to ignore) + This masking ensures consistent behavior between training and inference + by preventing the model from attending to padding tokens in both cases + """ + if audio_lens is None: + return None + + audio_feature_len = self._get_feat_extract_output_lengths(audio_lens) + max_seq_len = hidden_states.shape[1] + attention_mask = torch.arange(max_seq_len, + device=hidden_states.device)[None, :].lt( + audio_feature_len.view(-1, 1)) + attention_mask = self.get_extended_attention_mask( + attention_mask, + None, + dtype=hidden_states.dtype, + ) + return attention_mask + def forward( self, - input_features, + input_features: torch.Tensor, + audio_lens: Optional[torch.Tensor] = None, ): - expected_seq_length = (self.config.max_source_positions * - self.conv1.stride[0] * self.conv2.stride[0]) + expected_seq_length = self.max_context_length if input_features.shape[-1] > expected_seq_length: raise ValueError( f"Whisper expects the mel input features to be of length " @@ -328,10 +384,13 @@ class ModifiedWhisperEncoder(WhisperEncoder): p=self.dropout, training=self.training) + attention_mask = self.get_attention_mask_by_audio_len( + audio_lens, hidden_states) + for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, - None, + attention_mask, layer_head_mask=None, ) @@ -409,17 +468,34 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ) def _audio_features_to_embeddings( - self, input_features: torch.Tensor) -> torch.Tensor: - audio_input = input_features.to(self.audio_tower.dtype) - audio_features = self.audio_tower(audio_input) - audio_features = audio_features.to(self.audio_tower.dtype) - audio_embeddings = self.multi_modal_projector(audio_features) + self, input_features: torch.Tensor, + audio_lens: torch.Tensor) -> torch.Tensor: + audio_features = input_features.to(self.audio_tower.dtype) + batch_size = audio_features.size(0) + audio_embeddings = [] + + # Process audio features in batches to keep memory usage predictable + for start in range(0, batch_size, 
_MAX_ENCODER_BATCH_SIZE): + end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size) + # Process through audio tower + batch_features = self.audio_tower(audio_features[start:end], + audio_lens[start:end]) + batch_features = batch_features.to(self.audio_tower.dtype) + + # Process through projector + batch_embeddings = self.multi_modal_projector(batch_features) + audio_embeddings.append(batch_embeddings) + + # Concatenate results + audio_embeddings = torch.cat(audio_embeddings, dim=0) return audio_embeddings def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[UltravoxAudioInputs]: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) + audio_lens = kwargs.pop("audio_lens", None) + audio_token_len = kwargs.pop("audio_token_len", None) if audio_features is None and audio_embeds is None: return None @@ -430,7 +506,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): f"Got type: {type(audio_features)}") return UltravoxAudioFeatureInputs(type="audio_features", - data=audio_features) + data=audio_features, + lens=audio_lens, + token_len=audio_token_len) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): @@ -447,34 +525,34 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): if audio_input["type"] == "audio_embeds": return audio_input["data"] - audio_features = audio_input["data"] - if isinstance(audio_features, torch.Tensor): - # Combine the B and N dimensions for the encoder/projector - flattened = flatten_bn(audio_features) - flattened_embeddings = self._audio_features_to_embeddings( - flattened) + # Pad and concatenate audio features + # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] + audio_features = pad_and_concat_to_dim3(audio_input["data"]) - # Restore the original dimensions - embeddings = flattened_embeddings.unflatten( - 0, audio_features.shape[:2]) - return embeddings + if isinstance(audio_input['lens'], list): + # [B1, B2] -> [B1+B2] + audio_lens = torch.cat(audio_input['lens']) + audio_token_len = torch.cat(audio_input['token_len']) + else: + audio_lens = flatten_bn(audio_input['lens']) + audio_token_len = flatten_bn(audio_input['token_len']) - result = [] - # TODO: Batch heterogeneous tensors through the encoder/projector - for audio_features_item in audio_features: - if isinstance(audio_features_item, torch.Tensor): - result.append( - self._audio_features_to_embeddings(audio_features_item)) - else: - embeddings = [ - # Add a batch dimension to embed it, then remove it. - self._audio_features_to_embeddings(tensor.unsqueeze(0) - ).squeeze(0) - for tensor in audio_features_item - ] - result.append(embeddings) + embeddings = self._audio_features_to_embeddings( + audio_features, audio_lens) - return result + # We should flatten and concatenate embeddings based on token lengths + # For example, with token_len = [4, 2, 3], flattened_embeddings will be + # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3]) + + # Create a mask of valid indices based on token lengths + max_len = embeddings.shape[1] + indices = torch.arange(max_len, device=embeddings.device).expand( + embeddings.shape[0], -1) + mask = indices < audio_token_len[:, None] + # Apply mask and flatten + flattened_embeddings = embeddings[mask] + + return flattened_embeddings def get_multimodal_embeddings( self, **kwargs @@ -521,7 +599,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): with the `input_ids`. 
Args:
-            audio_features: A batch of audio inputs [B, N, 80, M].
+            audio_features: A batch of audio input chunks [B, N, 80, M].
+            audio_lens: Length of audio frames for each audio chunk [B].
+            audio_token_len: Length of audio tokens for each audio chunk [B'].
+                Note: batch dim is different from batch dim in audio chunks.
+
         """
         if intermediate_tensors is not None:
@@ -560,3 +642,31 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["audio_tower."])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+
+def pad_and_concat_to_dim3(
+    features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]
+) -> torch.Tensor:
+    """
+    Pad and concatenate a list of tensors.
+
+    output:
+        Tensor of shape [B, C, M] where M is the maximum length of the input
+        tensors, B is the sum of the batch sizes of the input tensors.
+        C must be the same for all input tensors.
+    """
+    if isinstance(features, torch.Tensor):
+        if features.ndim > 3:
+            # Flatten [B, N, 80, M] -> [B * N, 80, M]
+            features = flatten_bn(features)
+        return features
+
+    features = [pad_and_concat_to_dim3(f) for f in features]
+
+    max_len = max(f.shape[-1] for f in features)
+    # Ensure all features have dim=3
+    features = [f.view(-1, *f.shape[-2:]) for f in features]
+    # Pad and concatenate:
+    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
+    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
+    return torch.cat(features)
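
As a quick sanity check of the chunking path introduced by this patch, the standalone sketch below (not part of the patch; the `pad_and_concat` helper, chunk counts, frame lengths, token counts, and hidden size are all made-up illustrations) shows how variable-length mel chunks are padded to a common length and how per-chunk embeddings are then flattened by `audio_token_len`, mirroring `pad_and_concat_to_dim3` and the masking in `_process_audio_input`:

import torch
import torch.nn.functional as F


def pad_and_concat(chunks: list) -> torch.Tensor:
    # Simplified stand-in for pad_and_concat_to_dim3:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    max_len = max(c.shape[-1] for c in chunks)
    return torch.cat([F.pad(c, (0, max_len - c.shape[-1])) for c in chunks])


# Two audios: the first split into 2 full-length chunks, the second a single
# shorter chunk. Frame counts here are toy numbers, not real Whisper limits.
chunks_audio_a = torch.randn(2, 80, 3000)
chunks_audio_b = torch.randn(1, 80, 1200)
features = pad_and_concat([chunks_audio_a, chunks_audio_b])
assert features.shape == (3, 80, 3000)

# Pretend the encoder + projector produced [num_chunks, max_tokens, hidden];
# 188 / 75 tokens and hidden size 4096 are illustrative values only.
embeddings = torch.randn(3, 188, 4096)
audio_token_len = torch.tensor([188, 188, 75])

# Keep only the valid positions of each chunk and flatten them, mirroring the
# token-length mask applied in _process_audio_input.
indices = torch.arange(embeddings.shape[1]).expand(embeddings.shape[0], -1)
flattened = embeddings[indices < audio_token_len[:, None]]
assert flattened.shape == (188 + 188 + 75, 4096)

Flattening by token length is what lets chunks from different audios share a single batch dimension while each prompt still receives exactly `audio_token_len` placeholder embeddings per chunk.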