[Model] Add Ultravox support for multiple audio chunks (#7963)
parent e16fa99a6a
commit 2be8ec6e71
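This change lets the Ultravox path accept more than one audio chunk per prompt: the prompt carries one `<|reserved_special_token_0|>` placeholder per chunk, `multi_modal_data["audio"]` takes a list of `(waveform, sample_rate)` tuples, and `limit_mm_per_prompt` caps the per-prompt audio count. A condensed sketch assembled from the example diff below (engine arguments beyond the audio limit are illustrative, not part of the commit):

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
audio_count = len(audio_assets)

# One placeholder per audio chunk, followed by the question.
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
    "role": "user",
    "content": "<|reserved_special_token_0|>\n" * audio_count +
    "What sport and what nursery rhyme are referenced?"
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name,
          max_model_len=8192,
          limit_mm_per_prompt={"audio": audio_count})

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": [asset.audio_and_sample_rate for asset in audio_assets]
        },
    },
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64))
print(outputs[0].outputs[0].text)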
@@ -11,25 +11,33 @@ from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.utils import FlexibleArgumentParser
 
-# Input audio and question
-audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
-question = "What is recited in the audio?"
+audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
+question_per_audio_count = [
+    "What is recited in the audio?",
+    "What sport and what nursery rhyme are referenced?"
+]
 
 
 # Ultravox 0.3
-def run_ultravox(question):
+def run_ultravox(question, audio_count):
     model_name = "fixie-ai/ultravox-v0_3"
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
-        'role': 'user',
-        'content': f"<|reserved_special_token_0|>\n{question}"
+        'role':
+        'user',
+        'content':
+        "<|reserved_special_token_0|>\n" * audio_count + question
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
 
-    llm = LLM(model=model_name)
+    llm = LLM(model=model_name,
+              enforce_eager=True,
+              enable_chunked_prefill=False,
+              max_model_len=8192,
+              limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids
 
@@ -44,7 +52,9 @@ def main(args):
     if model not in model_example_map:
         raise ValueError(f"Model type {model} is not supported.")
 
-    llm, prompt, stop_token_ids = model_example_map[model](question)
+    audio_count = args.num_audios
+    llm, prompt, stop_token_ids = model_example_map[model](
+        question_per_audio_count[audio_count - 1], audio_count)
 
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
@@ -53,23 +63,18 @@ def main(args):
                                      stop_token_ids=stop_token_ids)
 
     assert args.num_prompts > 0
-    if args.num_prompts == 1:
-        # Single inference
-        inputs = {
-            "prompt": prompt,
-            "multi_modal_data": {
-                "audio": audio_and_sample_rate
-            },
-        }
-    else:
+    inputs = {
+        "prompt": prompt,
+        "multi_modal_data": {
+            "audio": [
+                asset.audio_and_sample_rate
+                for asset in audio_assets[:audio_count]
+            ]
+        },
+    }
+    if args.num_prompts > 1:
         # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                "audio": audio_and_sample_rate
-            },
-        } for _ in range(args.num_prompts)]
+        inputs = [inputs] * args.num_prompts
 
     outputs = llm.generate(inputs, sampling_params=sampling_params)
 
@@ -92,6 +97,11 @@ if __name__ == "__main__":
                         type=int,
                         default=1,
                         help='Number of prompts to run.')
+    parser.add_argument("--num-audios",
+                        type=int,
+                        default=1,
+                        choices=[1, 2],
+                        help="Number of audio items per prompt.")
 
     args = parser.parse_args()
     main(args)
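The hunks that follow appear to belong to the Ultravox test module. Its new `_get_prompt` helper simply repeats a placeholder once per chunk before the question; stripped of the chat-template wrapper, that amounts to (the `content_for` helper below is hypothetical, for illustration only):

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"


def content_for(audio_count: int, question: str) -> str:
    # One placeholder line per audio chunk, then the question.
    return f"{VLLM_PLACEHOLDER}\n" * audio_count + question


print(content_for(2, "Describe each of the audios above."))
# <|reserved_special_token_0|>
# <|reserved_special_token_0|>
# Describe each of the audios above.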
@@ -16,37 +16,32 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
 
 AudioTuple = Tuple[np.ndarray, int]
 
+VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
+HF_PLACEHOLDER = "<|audio|>"
+
 
 @pytest.fixture(scope="session")
-def audio_and_sample_rate():
+def audio_assets():
     from vllm.assets.audio import AudioAsset
-    return AudioAsset("mary_had_lamb").audio_and_sample_rate
+    return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
 
 
-@pytest.fixture
-def prompts_and_audios(audio_and_sample_rate):
+@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
+def audio(request):
+    from vllm.assets.audio import AudioAsset
+    return AudioAsset(request.param)
+
+
+def _get_prompt(audio_count, question, placeholder):
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    placeholder = f"{placeholder}\n" * audio_count
 
-    vllm_placeholder = "<|reserved_special_token_0|>"
-    hf_placeholder = "<|audio|>"
-    question = "What's in the audio?"
-    vllm_prompt = tokenizer.apply_chat_template(
-        [{
-            'role': 'user',
-            'content': f"{vllm_placeholder}\n{question}"
-        }],
-        tokenize=False,
-        add_generation_prompt=True)
-    hf_prompt = tokenizer.apply_chat_template(
-        [{
-            'role': 'user',
-            'content': f"{hf_placeholder}\n{question}"
-        }],
-        tokenize=False,
-        add_generation_prompt=True)
-
-    return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
+    return tokenizer.apply_chat_template([{
+        'role': 'user',
+        'content': f"{placeholder}{question}"
+    }],
+                                         tokenize=False,
+                                         add_generation_prompt=True)
 
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -134,15 +129,71 @@ def run_test(
     )
 
 
+def run_multi_audio_test(
+    vllm_runner: Type[VllmRunner],
+    prompts_and_audios: List[Tuple[str, List[AudioTuple]]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True,
+                     limit_mm_per_prompt={
+                         "audio":
+                         max((len(audio) for _, audio in prompts_and_audios))
+                     }) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            [prompt for prompt, _ in prompts_and_audios],
+            max_tokens,
+            num_logprobs=num_logprobs,
+            audios=[audios for _, audios in prompts_and_audios])
+
+    # The HuggingFace model doesn't support multiple audios yet, so
+    # just assert that some tokens were generated.
+    assert all(tokens for tokens, *_ in vllm_outputs)
+
+
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
-                max_tokens: int, num_logprobs: int) -> None:
+def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
+                num_logprobs: int) -> None:
+
+    vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
+    hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
     run_test(
         hf_runner,
         vllm_runner,
-        prompts_and_audios,
+        [(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)],
+        MODEL_NAME,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
+                                     max_tokens: int,
+                                     num_logprobs: int) -> None:
+
+    vllm_prompt = _get_prompt(len(audio_assets),
+                              "Describe each of the audios above.",
+                              VLLM_PLACEHOLDER)
+    run_multi_audio_test(
+        vllm_runner,
+        [(vllm_prompt, [audio.audio_and_sample_rate
+                        for audio in audio_assets])],
         MODEL_NAME,
         dtype=dtype,
         max_tokens=max_tokens,
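`run_multi_audio_test` sizes `limit_mm_per_prompt` from the largest chunk count in the batch and only checks that some tokens come back, since the HuggingFace reference model cannot consume multiple audios yet. The inputs it drives through `vllm_runner` look roughly like this (waveforms illustrative). The remaining hunks below are in the Ultravox model implementation itself.

import numpy as np

# Each entry pairs one prompt with a list of (waveform, sample_rate) tuples.
prompts_and_audios = [
    ("<chat-formatted prompt containing two placeholders>",
     [(np.zeros(16000, dtype=np.float32), 16000),
      (np.zeros(32000, dtype=np.float32), 16000)]),
]

# The audio limit passed to the engine mirrors the expression in the test:
audio_limit = max((len(audios) for _, audios in prompts_and_audios))
print(audio_limit)  # 2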
@@ -29,12 +29,12 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights,
+from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
                                               init_vllm_registered_model,
                                               merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal.base import MultiModalInputs, NestedTensors
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@@ -48,13 +48,14 @@ logger = init_logger(__name__)
 
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size * num_audios, 80, M)"""
+    data: NestedTensors
+    """Shape: `(batch_size, num_audios, 80, M)"""
 
 
 class UltravoxAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    data: torch.Tensor
+    data: NestedTensors
+    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
 
 
 UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
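Typing both inputs as `NestedTensors` lets the same TypedDicts describe a homogeneous batch (one stacked tensor) and a ragged batch (nested lists of per-chunk tensors). Roughly, the two layouts now admitted for `audio_features` are (shapes illustrative):

import torch

# Homogeneous batch: (batch_size, num_audios, 80, M)
uniform = torch.randn(2, 2, 80, 100)

# Ragged batch: per-request lists of per-chunk (80, M_i) tensors
ragged = [
    [torch.randn(80, 100), torch.randn(80, 160)],
    [torch.randn(80, 120)],
]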
@@ -85,24 +86,33 @@ def dummy_data_for_ultravox(
 
     audio_count = mm_counts["audio"]
 
-    audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [
-        _AUDIO_PLACEHOLDER_TOKEN
-    ]) * get_ultravox_max_audio_tokens(ctx) * audio_count
+    audio_placeholder = array(
+        VLLM_TOKEN_ID_ARRAY_TYPE,
+        [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
+
+    # Add a separator between each chunk.
+    audio_token_ids = (audio_placeholder +
+                       array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
     other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                             [0]) * (seq_len - len(audio_token_ids))
 
     audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-    mm_dict = {
-        "audio":
-        audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
-    }
+    mm_dict = {"audio": [audio_and_sr] * audio_count}
 
     return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
 
 
 def input_mapper_for_ultravox(ctx: InputContext, data: object):
-    if isinstance(data, tuple):
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
+    if not isinstance(data, list):
+        data = [data]
+
+    audio_features = []
+    for audio_input in data:
+        if not isinstance(audio_input, tuple):
+            raise NotImplementedError(
+                f"Unsupported data type: {type(audio_input)}")
+
+        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
         feature_extractor = whisper_feature_extractor(ctx)
 
         if sr != feature_extractor.sampling_rate:
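The dummy-data change at the top of this hunk inserts a separator token id between the placeholder runs of consecutive chunks, so profiling sees `audio_count` distinct placeholder spans rather than one long run. A standalone sketch of that layout (the placeholder id and the counts are illustrative, not the model's actual values):

from array import array

TOKEN_ARRAY_TYPE = "l"  # stand-in for VLLM_TOKEN_ID_ARRAY_TYPE
PLACEHOLDER = 128002    # illustrative placeholder token id
max_audio_tokens, audio_count, seq_len = 4, 2, 16

audio_placeholder = array(TOKEN_ARRAY_TYPE, [PLACEHOLDER]) * max_audio_tokens
# Add a separator (0) between each chunk.
audio_token_ids = (audio_placeholder +
                   array(TOKEN_ARRAY_TYPE, [0])) * audio_count
other_token_ids = array(TOKEN_ARRAY_TYPE,
                        [0]) * (seq_len - len(audio_token_ids))

print(list(audio_token_ids))
# [128002, 128002, 128002, 128002, 0, 128002, 128002, 128002, 128002, 0]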
@@ -121,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
             # Not enough audio; pad it.
             audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
 
-        return MultiModalInputs({
-            "audio_features":
-            feature_extractor(audio,
-                              sampling_rate=sr,
-                              padding="longest",
-                              return_tensors="pt")["input_features"]
-        })
+        single_audio_features = feature_extractor(
+            audio, sampling_rate=sr, padding="longest",
+            return_tensors="pt")["input_features"]
 
-    raise NotImplementedError(f"Unsupported data type: {type(data)}")
+        # Remove the batch dimension because we're wrapping it in a list.
+        audio_features.append(single_audio_features.squeeze(0))
+
+    return MultiModalInputs({"audio_features": audio_features})
 
 
 def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
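In the reworked mapper each chunk is featurized independently and the leading batch dimension is dropped, so the multimodal input becomes a list of `(80, M_i)` mel tensors whose lengths can differ. A standalone sketch with a plain `WhisperFeatureExtractor` standing in for the vLLM-internal `whisper_feature_extractor(ctx)`:

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor()  # 80 mel bins, 16 kHz defaults
audios = [np.zeros(16000, dtype=np.float32),   # 1 s chunk
          np.zeros(48000, dtype=np.float32)]   # 3 s chunk

audio_features = []
for audio in audios:
    features = feature_extractor(audio,
                                 sampling_rate=16000,
                                 padding="longest",
                                 return_tensors="pt")["input_features"]
    # (1, 80, M) -> (80, M); chunks stay separate instead of being stacked.
    audio_features.append(features.squeeze(0))

print([tuple(f.shape) for f in audio_features])  # e.g. [(80, 100), (80, 300)]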
@@ -138,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
 
     feature_extractor = whisper_feature_extractor(ctx)
-    audio_data, sample_rate = multi_modal_data["audio"]
+    audios = multi_modal_data["audio"]
+    if not isinstance(audios, list):
+        audios = [audios]
+
+    audio_token_counts = []
+    for audio_data, sample_rate in audios:
+        audio_length = audio_data.shape[0]
+        if sample_rate != feature_extractor.sampling_rate:
+            # Account for resampling.
+            adjustment = feature_extractor.sampling_rate / sample_rate
+            audio_length = math.ceil(adjustment * audio_length)
 
-    audio_length = audio_data.shape[0]
-    if sample_rate != feature_extractor.sampling_rate:
-        # Account for resampling.
-        adjustment = feature_extractor.sampling_rate / sample_rate
-        audio_length = math.ceil(adjustment * audio_length)
-
-    feature_extractor_output_length = math.ceil(
-        (audio_length -
-         (feature_extractor.hop_length - 1)) / feature_extractor.hop_length)
+        feature_extractor_output_length = math.ceil(
+            (audio_length - (feature_extractor.hop_length - 1)) /
+            feature_extractor.hop_length)
 
-    uv_config = ctx.get_hf_config(UltravoxConfig)
-    audio_num_tokens = min(
-        max(
-            1,
-            math.ceil(feature_extractor_output_length /
-                      (uv_config.stack_factor * 2))),
-        get_ultravox_max_audio_tokens(ctx))
+        uv_config = ctx.get_hf_config(UltravoxConfig)
+        audio_num_tokens = min(
+            max(
+                1,
+                math.ceil(feature_extractor_output_length /
+                          (uv_config.stack_factor * 2))),
+            get_ultravox_max_audio_tokens(ctx))
+        audio_token_counts.append(audio_num_tokens)
 
     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
 
     new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
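Because placeholder repetition is now computed per chunk, `repeat_and_pad_placeholder_tokens` receives a list of counts (one per placeholder) rather than a single integer. A worked sketch of the per-chunk arithmetic, assuming Whisper defaults (16 kHz, hop_length=160) and a stack_factor of 8; both are assumptions here, since the real values are read from the feature extractor and `UltravoxConfig` at runtime, and the per-chunk cap below is illustrative:

import math


def tokens_for_chunk(num_samples: int,
                     sample_rate: int,
                     hop_length: int = 160,
                     stack_factor: int = 8,
                     max_audio_tokens: int = 188) -> int:
    # Hypothetical helper mirroring the loop body in
    # input_processor_for_ultravox above.
    audio_length = math.ceil(16000 / sample_rate * num_samples)
    frames = math.ceil((audio_length - (hop_length - 1)) / hop_length)
    return min(max(1, math.ceil(frames / (stack_factor * 2))),
               max_audio_tokens)


print(tokens_for_chunk(1 * 16000, 16000))  # 7
print(tokens_for_chunk(5 * 16000, 16000))  # 32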
@@ -164,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
         placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
-        repeat_count=audio_num_tokens,
+        repeat_count=audio_token_counts,
     )
 
     # NOTE: Create a defensive copy of the original inputs
@@ -338,45 +353,52 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of audio features. "
                                  f"Got type: {type(audio_features)}")
 
-            # Remove the N dimension until multiple audios are supported.
-            if isinstance(audio_features, torch.Tensor):
-                audio_features = audio_features.squeeze(1)
-            else:
-                audio_features = [t.squeeze(0) for t in audio_features]
-
             return UltravoxAudioFeatureInputs(type="audio_features",
                                               data=audio_features)
 
         if audio_embeds is not None:
-            if not isinstance(audio_embeds, torch.Tensor):
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of audio embeds. "
                                  f"Got type: {type(audio_embeds)}")
 
-            # Remove the N dimension until multiple audios are supported.
-            audio_embeds = audio_embeds.squeeze(1)
-
             return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                 data=audio_embeds)
 
         raise AssertionError("This line should be unreachable.")
 
     def _process_audio_input(
-            self, audio_input: UltravoxAudioInputs
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]
 
         audio_features = audio_input["data"]
-        if isinstance(audio_features, list):
-            # TODO: Batch these through the encoder/projector instead of
-            # serializing them.
-            return [
-                self._audio_features_to_embeddings(
-                    features.unsqueeze(0)).squeeze(0)
-                for features in audio_features
-            ]
-        else:
-            return self._audio_features_to_embeddings(audio_features)
+        if isinstance(audio_features, torch.Tensor):
+            # Combine the B and N dimensions for the encoder/projector
+            flattened = flatten_bn(audio_features)
+            flattened_embeddings = self._audio_features_to_embeddings(
+                flattened)
+
+            # Restore the original dimensions
+            embeddings = flattened_embeddings.unflatten(
+                0, audio_features.shape[:2])
+            return embeddings
+
+        result = []
+        # TODO: Batch heterogeneous tensors through the encoder/projector
+        for audio_features_item in audio_features:
+            if isinstance(audio_features_item, torch.Tensor):
+                result.append(
+                    self._audio_features_to_embeddings(audio_features_item))
+            else:
+                embeddings = [
+                    # Add a batch dimension to embed it, then remove it.
+                    self._audio_features_to_embeddings(tensor.unsqueeze(0)
+                                                       ).squeeze(0)
+                    for tensor in audio_features_item
+                ]
+                result.append(embeddings)
+
+        return result
 
     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
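For the homogeneous path the batch and chunk dimensions are folded together before the encoder/projector and restored afterwards. The shape bookkeeping, as a standalone sketch (embedding sizes illustrative, with a random tensor standing in for the projector output):

import torch

B, N, mels, frames = 2, 3, 80, 100
audio_features = torch.randn(B, N, mels, frames)

# flatten_bn: (B, N, 80, M) -> (B * N, 80, M)
flattened = audio_features.flatten(0, 1)

# Stand-in for self._audio_features_to_embeddings(flattened):
flattened_embeddings = torch.randn(B * N, 7, 4096)

# Restore the original (B, N, ...) leading dimensions.
embeddings = flattened_embeddings.unflatten(0, audio_features.shape[:2])
print(embeddings.shape)  # torch.Size([2, 3, 7, 4096])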
@@ -393,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
         with the `input_ids`.
 
         Args:
-            input_features: A batch of audio inputs, [1, 80, M].
+            audio_features: A batch of audio inputs [B, N, 80, M].
         """
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is not None: