mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 21:24:32 +08:00
[Bugfix] Fix Ultravox on V1 (#14929)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
b4ad56c1bd
commit
868a8c5b2c
@@ -5,7 +5,7 @@
|
|||||||
import math
|
import math
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
|
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -36,7 +36,7 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP, SupportsV0Only)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, maybe_prefix,
|
||||||
merge_multimodal_embeddings,
|
merge_multimodal_embeddings,
|
||||||
@@ -50,14 +50,14 @@ _MAX_ENCODER_BATCH_SIZE = 16
|
|||||||
|
|
||||||
class UltravoxAudioFeatureInputs(TypedDict):
|
class UltravoxAudioFeatureInputs(TypedDict):
|
||||||
type: Literal["audio_features"]
|
type: Literal["audio_features"]
|
||||||
data: NestedTensors
|
data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
|
||||||
"""Shape: `(batch_size, num_chunks, 80, M)`"""
|
"""Shape: `(batch_size, num_chunks, 80, M)`"""
|
||||||
lens: NestedTensors
|
lens: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Length of the audio frames. Used for attention mask in WhisperEncoder.
|
Length of the audio frames. Used for attention mask in WhisperEncoder.
|
||||||
Shape: `(batch_size, num_chunks)`
|
Shape: `(batch_size, num_chunks)`
|
||||||
"""
|
"""
|
||||||
token_len: NestedTensors
|
token_len: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Length of the audio tokens. Used for flattening the audio features.
|
Length of the audio tokens. Used for flattening the audio features.
|
||||||
Shape: `(batch_size, num_chunks)`
|
Shape: `(batch_size, num_chunks)`
|
||||||
@@ -405,8 +405,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
|
|||||||
UltravoxMultiModalProcessor,
|
UltravoxMultiModalProcessor,
|
||||||
info=UltravoxProcessingInfo,
|
info=UltravoxProcessingInfo,
|
||||||
dummy_inputs=UltravoxDummyInputsBuilder)
|
dummy_inputs=UltravoxDummyInputsBuilder)
|
||||||
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||||
SupportsV0Only):
|
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
@@ -506,6 +505,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of audio features. "
|
raise ValueError("Incorrect type of audio features. "
|
||||||
f"Got type: {type(audio_features)}")
|
f"Got type: {type(audio_features)}")
|
||||||
|
if not isinstance(audio_lens, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of audio_lens. "
|
||||||
|
f"Got type: {type(audio_features)}")
|
||||||
|
if not isinstance(audio_token_len, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of audio_token_len. "
|
||||||
|
f"Got type: {type(audio_features)}")
|
||||||
|
|
||||||
return UltravoxAudioFeatureInputs(type="audio_features",
|
return UltravoxAudioFeatureInputs(type="audio_features",
|
||||||
data=audio_features,
|
data=audio_features,
|
||||||
@@ -523,7 +528,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
raise AssertionError("This line should be unreachable.")
|
raise AssertionError("This line should be unreachable.")
|
||||||
|
|
||||||
def _process_audio_input(
|
def _process_audio_input(
|
||||||
self, audio_input: UltravoxAudioInputs) -> NestedTensors:
|
self,
|
||||||
|
audio_input: UltravoxAudioInputs,
|
||||||
|
) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
|
||||||
if audio_input["type"] == "audio_embeds":
|
if audio_input["type"] == "audio_embeds":
|
||||||
return audio_input["data"]
|
return audio_input["data"]
|
||||||
|
|
||||||
@@ -531,13 +538,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
|
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
|
||||||
audio_features = pad_and_concat_to_dim3(audio_input["data"])
|
audio_features = pad_and_concat_to_dim3(audio_input["data"])
|
||||||
|
|
||||||
if isinstance(audio_input['lens'], list):
|
# [B1, B2] -> [B1+B2]
|
||||||
# [B1, B2] -> [B1+B2]
|
audio_lens = flatten_bn(audio_input['lens'], concat=True)
|
||||||
audio_lens = torch.cat(audio_input['lens'])
|
audio_token_len = flatten_bn(audio_input['token_len'], concat=True)
|
||||||
audio_token_len = torch.cat(audio_input['token_len'])
|
|
||||||
else:
|
|
||||||
audio_lens = flatten_bn(audio_input['lens'])
|
|
||||||
audio_token_len = flatten_bn(audio_input['token_len'])
|
|
||||||
|
|
||||||
embeddings = self._audio_features_to_embeddings(
|
embeddings = self._audio_features_to_embeddings(
|
||||||
audio_features, audio_lens)
|
audio_features, audio_lens)
|
||||||
@@ -554,7 +557,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
# Apply mask and flatten
|
# Apply mask and flatten
|
||||||
flattened_embeddings = embeddings[mask]
|
flattened_embeddings = embeddings[mask]
|
||||||
|
|
||||||
return flattened_embeddings
|
# Return one tensor per input audio
|
||||||
|
embed_lens = [
|
||||||
|
token_len_item.sum().item()
|
||||||
|
for token_len_item in audio_input['token_len']
|
||||||
|
]
|
||||||
|
return flattened_embeddings.split(embed_lens)
|
||||||
|
|
||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(
|
||||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||||
@@ -646,7 +654,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
|
|
||||||
|
|
||||||
def pad_and_concat_to_dim3(
|
def pad_and_concat_to_dim3(
|
||||||
features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]
|
features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Pad and concatenate a list of tensors.
|
Pad and concatenate a list of tensors.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user