[V1] Support audio language models on V1 (#11733)
Signed-off-by: Roger Wang <ywang@roblox.com>
Parent: 869e829b85
Commit: 2de197bdd4
````diff
@@ -710,7 +710,7 @@ See [this page](#generative-models) for more information on how to use generativ
   - `Qwen/Qwen2-Audio-7B-Instruct`
   -
   - ✅︎
-  -
+  - ✅︎
 * - `Qwen2VLForConditionalGeneration`
   - Qwen2-VL
   - T + I<sup>E+</sup> + V<sup>E+</sup>
@@ -724,7 +724,7 @@ See [this page](#generative-models) for more information on how to use generativ
   - `fixie-ai/ultravox-v0_3`
   -
   - ✅︎
-  -
+  - ✅︎
 ```
 
 <sup>E</sup> Pre-computed embeddings can be inputted for this modality.
````
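With the V1 column now checked for Qwen2-Audio and Ultravox, these checkpoints can be exercised on the V1 engine. Below is a minimal offline-inference sketch, not part of this commit: the `VLLM_USE_V1` flag and the `(audio, sampling_rate)` form of `multi_modal_data` follow vLLM's multimodal API, while the prompt string, the `audio.wav` path, and the use of `librosa` are placeholder assumptions; consult the audio-language example in the repository for each model's exact chat template.

```python
import os

os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine (can also be set in the shell)

import librosa  # assumed available; used only to load a sample clip

from vllm import LLM, SamplingParams

# Placeholder input: any mono clip resampled to 16 kHz works for illustration.
audio, sr = librosa.load("audio.wav", sr=16000)

llm = LLM(model="fixie-ai/ultravox-v0_3")
outputs = llm.generate(
    {
        # The audio placeholder and chat template are model-specific; this
        # prompt is schematic, not the exact Ultravox template.
        "prompt": "<|audio|>\nWhat is being said in this clip?",
        "multi_modal_data": {"audio": (audio, sr)},
    },
    SamplingParams(max_tokens=64, temperature=0),
)
print(outputs[0].outputs[0].text)
```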
```diff
@@ -335,13 +335,16 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         selected_audio_feature = audio_outputs.last_hidden_state
         audio_features = self.multi_modal_projector(selected_audio_feature)
         num_audios, max_audio_tokens, embed_dim = audio_features.shape
+        audio_output_lengths = audio_output_lengths.unsqueeze(1)
         audio_features_mask = torch.arange(max_audio_tokens).expand(
-            num_audios, max_audio_tokens
-        ).to(audio_output_lengths.device) < audio_output_lengths.unsqueeze(1)
+            num_audios, max_audio_tokens).to(
+                audio_output_lengths.device) < audio_output_lengths
         masked_audio_features = audio_features[audio_features_mask].view(
             -1, embed_dim)
 
-        return masked_audio_features
+        # Split to tuple of embeddings for individual audio input.
+        return torch.split(masked_audio_features,
+                           audio_output_lengths.flatten().tolist())
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
```
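The `Qwen2AudioForConditionalGeneration` change above stops returning one flattened tensor and instead splits the masked features into a tuple with one tensor per audio clip, so each item can be tracked individually (as the per-item multimodal embedding handling on V1 expects). A toy illustration of the shape bookkeeping, with made-up dimensions and independent of vLLM:

```python
# Padded (num_audios, max_audio_tokens, embed_dim) features are reduced to the
# valid tokens only, then split into one tensor per audio clip.
import torch

num_audios, max_audio_tokens, embed_dim = 2, 5, 4
audio_features = torch.randn(num_audios, max_audio_tokens, embed_dim)
audio_output_lengths = torch.tensor([3, 5]).unsqueeze(1)  # valid tokens per clip

audio_features_mask = torch.arange(max_audio_tokens).expand(
    num_audios, max_audio_tokens) < audio_output_lengths
masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

# One embedding tensor per audio input, e.g. shapes (3, 4) and (5, 4).
per_audio = torch.split(masked_audio_features,
                        audio_output_lengths.flatten().tolist())
print([t.shape for t in per_audio])
```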
```diff
@@ -1,6 +1,5 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
-
 import math
 from functools import cached_property
 from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
@@ -14,6 +13,7 @@ from transformers import BatchFeature, ProcessorMixin
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
+from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
@@ -35,8 +35,11 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings,
                     merge_multimodal_embeddings_from_map)
 
+_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
+_AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25
 
 
@@ -64,7 +67,14 @@ class UltravoxProcessingMixin(ProcessingMixin):
         # Ignored in initialization
         sampling_rate: Optional[int] = None,
     ) -> ProcessorMixin:
-        return self.ctx.get_hf_processor()
+        hf_processor = self.ctx.get_hf_processor()
+
+        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
+        # placeholder that will cause confusion with the actual end of turn
+        # token, thus we override placeholder with a reserved special
+        # token.
+        hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
+        return hf_processor
 
     def _get_feature_extractor(
         self,
```
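The processor hook above swaps the audio placeholder from `'<|eot_id|>'` to the reserved token `<|reserved_special_token_0|>`, so expanded placeholders can no longer collide with a real end-of-turn token, and the model-side constant `_AUDIO_PLACEHOLDER_TOKEN = 128002` is the id of that reserved token. A quick sanity-check sketch, assuming the Ultravox checkpoint ships a Llama-3-style tokenizer and that you can download it:

```python
# Confirm that the override string maps to the single reserved id 128002 that
# the model code uses as _AUDIO_PLACEHOLDER_TOKEN.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("fixie-ai/ultravox-v0_3")
token_id = tokenizer.convert_tokens_to_ids("<|reserved_special_token_0|>")
print(token_id)  # expected: 128002
```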
```diff
@@ -465,11 +475,15 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
 
-            # TODO(ywang96): use merge_multimodal_embeddings after
-            # v0 is deprecated
-            merge_multimodal_embeddings_from_map(
-                inputs_embeds, multimodal_embeddings,
-                attn_metadata.multi_modal_placeholder_index_maps["audio"])
+            # TODO(ywang96): remove this block after v0 is deprecated.
+            if not envs.VLLM_USE_V1:
+                merge_multimodal_embeddings_from_map(
+                    inputs_embeds, multimodal_embeddings,
+                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
+            else:
+                inputs_embeds = merge_multimodal_embeddings(
+                    input_ids, inputs_embeds, multimodal_embeddings,
+                    _AUDIO_PLACEHOLDER_TOKEN)
         return inputs_embeds
 
     def forward(self,
```
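In the final hunk, the V0 path keeps using the attention-metadata placeholder index maps, while the V1 path locates placeholder positions directly from `input_ids` and overwrites them via `merge_multimodal_embeddings`. The snippet below is a toy re-implementation of that merge semantics for illustration only, not the actual `merge_multimodal_embeddings` utility:

```python
# Rows of `inputs_embeds` whose token id equals the placeholder id are
# overwritten, in order, by the flattened audio embeddings.
import torch

PLACEHOLDER_ID = 128002  # mirrors _AUDIO_PLACEHOLDER_TOKEN above
hidden = 8

input_ids = torch.tensor([1, PLACEHOLDER_ID, PLACEHOLDER_ID, 2, PLACEHOLDER_ID])
inputs_embeds = torch.zeros(input_ids.shape[0], hidden)
audio_embeds = torch.randn(3, hidden)  # one row per placeholder position

def merge(input_ids, inputs_embeds, mm_embeds, placeholder_id):
    mask = input_ids == placeholder_id
    merged = inputs_embeds.clone()
    merged[mask] = mm_embeds  # number of rows must equal mask.sum()
    return merged

merged = merge(input_ids, inputs_embeds, audio_embeds, PLACEHOLDER_ID)
print(merged[1].equal(audio_embeds[0]), merged[4].equal(audio_embeds[2]))  # True True
```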