mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 17:15:01 +08:00
Migrate ultravox inputs to TensorSchema (#23503)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
712b273f65
commit
cb55ad86fe
@ -4,7 +4,7 @@
|
|||||||
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
|
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
|
||||||
"""PyTorch Ultravox model."""
|
"""PyTorch Ultravox model."""
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import Any, Literal, Optional, TypedDict, Union
|
from typing import Annotated, Any, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
|
|||||||
_MAX_ENCODER_BATCH_SIZE = 16
|
_MAX_ENCODER_BATCH_SIZE = 16
|
||||||
|
|
||||||
|
|
||||||
class UltravoxAudioFeatureInputs(TypedDict):
|
class UltravoxAudioFeatureInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- b: batch size
|
||||||
|
- n: number of chunks
|
||||||
|
- t: Time frames (M)
|
||||||
|
- nmb: Number of mel bins
|
||||||
|
"""
|
||||||
type: Literal["audio_features"]
|
type: Literal["audio_features"]
|
||||||
data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
|
data: Annotated[Union[torch.Tensor, list[torch.Tensor],
|
||||||
"""Shape: `(batch_size, num_chunks, 80, M)`"""
|
list[list[torch.Tensor]]],
|
||||||
lens: Union[torch.Tensor, list[torch.Tensor]]
|
TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})]
|
||||||
"""
|
lens: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||||
Length of the audio frames. Used for attention mask in WhisperEncoder.
|
TensorShape("b", "n", dynamic_dims={"n"})]
|
||||||
Shape: `(batch_size, num_chunks)`
|
"""Length of the audio frames. Used for attention mask in WhisperEncoder."""
|
||||||
"""
|
token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||||
token_len: Union[torch.Tensor, list[torch.Tensor]]
|
TensorShape("b", "n", dynamic_dims={"n"})]
|
||||||
"""
|
"""Length of the audio tokens. Used for flattening the audio features."""
|
||||||
Length of the audio tokens. Used for flattening the audio features.
|
|
||||||
Shape: `(batch_size, num_chunks)`
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class UltravoxAudioEmbeddingInputs(TypedDict):
|
class UltravoxAudioEmbeddingInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- b: batch size
|
||||||
|
- na: number of audios
|
||||||
|
- afs: audio feature size
|
||||||
|
- hs: hidden size
|
||||||
|
"""
|
||||||
type: Literal["audio_embeds"]
|
type: Literal["audio_embeds"]
|
||||||
data: NestedTensors
|
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||||
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
|
TensorShape("b", "na", "afs", "hs")]
|
||||||
|
|
||||||
|
|
||||||
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
|
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
|
||||||
@ -484,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if audio_features is not None:
|
if audio_features is not None:
|
||||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of audio features. "
|
|
||||||
f"Got type: {type(audio_features)}")
|
|
||||||
if not isinstance(audio_lens, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of audio_lens. "
|
|
||||||
f"Got type: {type(audio_features)}")
|
|
||||||
if not isinstance(audio_token_len, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of audio_token_len. "
|
|
||||||
f"Got type: {type(audio_features)}")
|
|
||||||
|
|
||||||
return UltravoxAudioFeatureInputs(type="audio_features",
|
return UltravoxAudioFeatureInputs(type="audio_features",
|
||||||
data=audio_features,
|
data=audio_features,
|
||||||
lens=audio_lens,
|
lens=audio_lens,
|
||||||
token_len=audio_token_len)
|
token_len=audio_token_len)
|
||||||
|
|
||||||
if audio_embeds is not None:
|
if audio_embeds is not None:
|
||||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of audio embeds. "
|
|
||||||
f"Got type: {type(audio_embeds)}")
|
|
||||||
|
|
||||||
return UltravoxAudioEmbeddingInputs(type="audio_embeds",
|
return UltravoxAudioEmbeddingInputs(type="audio_embeds",
|
||||||
data=audio_embeds)
|
data=audio_embeds)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user