diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index d368c145d55f9..cb1e143838496 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -5,7 +5,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.utils.checkpoint
@@ -36,7 +36,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP, SupportsV0Only)
+                         SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings,
@@ -50,14 +50,14 @@ _MAX_ENCODER_BATCH_SIZE = 16
 
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: NestedTensors
+    data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
     """Shape: `(batch_size, num_chunks, 80, M)`"""
-    lens: NestedTensors
+    lens: Union[torch.Tensor, list[torch.Tensor]]
    """
     Length of the audio frames. Used for attention mask in WhisperEncoder.
     Shape: `(batch_size, num_chunks)`
     """
-    token_len: NestedTensors
+    token_len: Union[torch.Tensor, list[torch.Tensor]]
     """
     Length of the audio tokens. Used for flattening the audio features.
     Shape: `(batch_size, num_chunks)`
@@ -405,8 +405,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
     UltravoxMultiModalProcessor,
     info=UltravoxProcessingInfo,
     dummy_inputs=UltravoxDummyInputsBuilder)
-class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
-                    SupportsV0Only):
+class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -506,6 +505,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
             if not isinstance(audio_features, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of audio features. "
                                  f"Got type: {type(audio_features)}")
+            if not isinstance(audio_lens, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_lens. "
+                                 f"Got type: {type(audio_lens)}")
+            if not isinstance(audio_token_len, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_token_len. "
+                                 f"Got type: {type(audio_token_len)}")
 
             return UltravoxAudioFeatureInputs(type="audio_features",
                                               data=audio_features,
@@ -523,7 +528,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         raise AssertionError("This line should be unreachable.")
 
     def _process_audio_input(
-            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
+        self,
+        audio_input: UltravoxAudioInputs,
+    ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]
 
@@ -531,13 +538,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])
 
-        if isinstance(audio_input['lens'], list):
-            # [B1, B2] -> [B1+B2]
-            audio_lens = torch.cat(audio_input['lens'])
-            audio_token_len = torch.cat(audio_input['token_len'])
-        else:
-            audio_lens = flatten_bn(audio_input['lens'])
-            audio_token_len = flatten_bn(audio_input['token_len'])
+        # [B1, B2] -> [B1+B2]
+        audio_lens = flatten_bn(audio_input['lens'], concat=True)
+        audio_token_len = flatten_bn(audio_input['token_len'], concat=True)
 
         embeddings = self._audio_features_to_embeddings(
             audio_features, audio_lens)
@@ -554,7 +557,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
 
         # Apply mask and flatten
         flattened_embeddings = embeddings[mask]
-        return flattened_embeddings
+        # Return one tensor per input audio
+        embed_lens = [
+            token_len_item.sum().item()
+            for token_len_item in audio_input['token_len']
+        ]
+        return flattened_embeddings.split(embed_lens)
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -646,7 +654,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
 
 
 def pad_and_concat_to_dim3(
-    features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]
+    features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
 ) -> torch.Tensor:
     """
     Pad and concatenate a list of tensors.