[model] support qwen2audio embedding input (#23625)

Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Yuekai Zhang
Date: 2025-08-26 23:48:08 +08:00 (committed by GitHub)
Parent: 513298f1b4
Commit: 9d4183dd2e
2 changed files with 93 additions and 29 deletions


@@ -47,7 +47,7 @@ from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
     Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
 from vllm.model_executor.models.qwen2_audio import (
-    Qwen2AudioInputs, Qwen2AudioProcessingInfo,
+    Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
     _get_feat_extract_output_lengths)
 from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -534,7 +534,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
         return torch.concat(mm_input, dim=dim)
 
     def _parse_and_validate_audio_input(
-            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
+            self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
         input_audio_features = kwargs.pop('input_audio_features', None)
         audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
         feature_attention_mask = kwargs.pop('feature_attention_mask', None)
@@ -548,9 +548,10 @@ class Qwen2_5OmniConditionalGenerationMixin:
         if not isinstance(input_audio_features, (torch.Tensor, list)):
             raise ValueError("Incorrect type of audio input features. "
                              f"Got type: {type(input_audio_features)}")
-        return Qwen2AudioInputs(input_features=input_audio_features,
-                                audio_feature_lengths=audio_feature_lengths,
-                                feature_attention_mask=feature_attention_mask)
+        return Qwen2AudioFeatureInputs(
+            input_features=input_audio_features,
+            audio_feature_lengths=audio_feature_lengths,
+            feature_attention_mask=feature_attention_mask)
 
     def _parse_and_validate_image_input(
         self,
@@ -630,7 +631,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
 
     def _process_audio_input(
         self,
-        audio_input: Qwen2AudioInputs,
+        audio_input: Qwen2AudioFeatureInputs,
         audio_hashes: list[str] = None,
         cached_audio_features: torch.Tensor = None,
     ) -> torch.Tensor:


@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, Literal, Optional, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -36,9 +36,11 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from vllm.config import VllmConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+from vllm.multimodal.inputs import (AudioItem, ModalityData,
+                                    MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
-from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
+from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
+                                   ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
@@ -52,7 +54,8 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
 
 # # === Audio Inputs === #
-class Qwen2AudioInputs(TypedDict):
+class Qwen2AudioFeatureInputs(TypedDict):
+    type: Literal["audio_features"]
     input_features: torch.Tensor
     """Shape: `(num_audios, num_mel_bins, 3000)`"""
@@ -60,6 +63,16 @@ class Qwen2AudioInputs(TypedDict):
     """Shape: `(num_audios, 3000)`"""
 
 
+class Qwen2AudioEmbeddingInputs(TypedDict):
+    type: Literal["audio_embeds"]
+    audio_embeds: list[torch.Tensor]
+    """Shape: `(num_audio_features, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
+
 
 # === Audio Encoder === #
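
Note on usage: with the new Qwen2AudioEmbeddingInputs path, callers can skip the
feature extractor and hand the model precomputed audio embeddings. Below is a
minimal sketch of what that could look like from the offline LLM API; the model
name, prompt string, and tensor sizes are illustrative assumptions rather than
part of this diff, which only establishes the "audio_embeds" key and the 2D
(num_audio_features, hidden_size) layout per audio item:

    # Hypothetical usage sketch; sizes and prompt format are assumptions.
    import torch
    from vllm import LLM

    llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct")
    # One 2D tensor per audio; hidden_size must match the LM backbone.
    audio_embeds = torch.rand(750, 4096)
    outputs = llm.generate({
        "prompt": "<|audio_bos|><|AUDIO|><|audio_eos|>Describe the audio.",
        "multi_modal_data": {"audio": {"audio_embeds": audio_embeds}},
    })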
@@ -128,12 +141,38 @@ class Qwen2AudioDummyInputsBuilder(
         }
 
 
+def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    return dict(
+        audio_embeds=MultiModalFieldConfig.batched("audio"),
+        input_features=MultiModalFieldConfig.batched("audio"),
+        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+    )
+
+
+class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
+
+    def _parse_audio_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
+    ) -> Optional[ModalityDataItems[Any, Any]]:
+        if isinstance(data, dict):
+            return DictEmbeddingItems(
+                data,
+                modality="audio",
+                required_fields={"audio_embeds"},
+                fields_factory=_qwen2audio_field_config,
+            )
+
+        return super()._parse_audio_data(data)
+
+
 class Qwen2AudioMultiModalProcessor(
         BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
 
     def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
-        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
+        return Qwen2AudioMultiModalDataParser(
+            target_sr=feature_extractor.sampling_rate)
 
     def _call_hf_processor(
         self,
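
For context, the parser swap above means the "audio" entry of multi_modal_data
can now arrive in two forms: a raw waveform, which still goes through the
parent MultiModalDataParser and on to the Whisper feature extractor, or a dict
of precomputed embeddings, which DictEmbeddingItems picks up via the required
"audio_embeds" field. A rough sketch of the two shapes, with the sampling rate
and tensor sizes as assumptions:

    # Hypothetical input shapes; exact sizes are assumptions.
    import numpy as np
    import torch

    # Form 1: raw waveform plus sampling rate; the parser resamples to target_sr.
    audio_waveform = (np.zeros(16000 * 10, dtype=np.float32), 16000)
    # Form 2: precomputed embeddings, already projected to the LM hidden size.
    audio_embedding = {"audio_embeds": torch.rand(750, 4096)}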
@@ -173,10 +212,7 @@ class Qwen2AudioMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            input_features=MultiModalFieldConfig.batched("audio"),
-            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
-        )
+        return _qwen2audio_field_config(hf_inputs)
 
     def _get_prompt_updates(
         self,
@@ -184,6 +220,7 @@ class Qwen2AudioMultiModalProcessor(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         tokenizer = self.info.get_tokenizer()
         vocab = tokenizer.get_vocab()
@@ -211,7 +248,15 @@ class Qwen2AudioMultiModalProcessor(
             audio_output_lengths = audio_output_lens.tolist()
 
         def get_replacement_qwen2_audio(item_idx: int):
-            num_features = audio_output_lengths[item_idx]
+            if audio_output_lengths:
+                num_features = audio_output_lengths[item_idx]
+            else:
+                audio_embeds = out_mm_data["audio_embeds"][item_idx]
+                assert len(audio_embeds.shape
+                           ) == 2, "audio_embeds must be a 2D tensor"
+                num_features = audio_embeds.shape[0]
+
             if num_features == 0:
                 audios = mm_items.get_items("audio", AudioProcessorItems)
                 audio_len = audios.get_audio_length(item_idx)
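
The replacement logic above decides how many <|AUDIO|> placeholder tokens each
audio expands to: from the feature-extractor output lengths when features are
given, and directly from audio_embeds.shape[0] when embeddings are given, so
the placeholder count simply equals the number of embedding rows supplied. A
back-of-the-envelope check, assuming the usual _get_feat_extract_output_lengths
downsampling (a stride-2 conv followed by stride-2 pooling); worth verifying
against the helper itself:

    # Hedged arithmetic sketch; the downsampling factors are assumptions.
    mel_frames = 3000                        # a full 30 s clip of mel features
    conv_out = (mel_frames - 1) // 2 + 1     # -> 1500
    num_features = (conv_out - 2) // 2 + 1   # -> 750 placeholder tokens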
@@ -286,21 +331,39 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
         input_features = kwargs.pop('input_features', None)
+        audio_embeds = kwargs.pop('audio_embeds', None)
         feature_attention_mask = kwargs.pop('feature_attention_mask', None)
-        if input_features is None:
-            return None
-        input_features = self._validate_and_reshape_mm_tensor(
-            input_features, 'input_features')
-        feature_attention_mask = self._validate_and_reshape_mm_tensor(
-            feature_attention_mask, 'feature_attention_mask')
-        if not isinstance(input_features, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of audio input features. "
-                             f"Got type: {type(input_features)}")
-        return Qwen2AudioInputs(input_features=input_features,
-                                feature_attention_mask=feature_attention_mask)
-
-    def _process_audio_input(self,
-                             audio_input: Qwen2AudioInputs) -> torch.Tensor:
+        if input_features is None and audio_embeds is None:
+            return None
+
+        if audio_embeds is not None:
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio embeds. "
+                                 f"Got type: {type(audio_embeds)}")
+            audio_embeds = self._validate_and_reshape_mm_tensor(
+                audio_embeds, "audio_embeds")
+            return Qwen2AudioEmbeddingInputs(type="audio_embeds",
+                                             audio_embeds=audio_embeds)
+
+        if input_features is not None:
+            input_features = self._validate_and_reshape_mm_tensor(
+                input_features, 'input_features')
+            feature_attention_mask = self._validate_and_reshape_mm_tensor(
+                feature_attention_mask, 'feature_attention_mask')
+            return Qwen2AudioFeatureInputs(
+                type="audio_features",
+                input_features=input_features,
+                feature_attention_mask=feature_attention_mask)
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_audio_input(
+        self, audio_input: Qwen2AudioInputs
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        if audio_input["type"] == "audio_embeds":
+            audio_embeds = audio_input["audio_embeds"]
+            return tuple(audio_embeds)
+
         input_features = audio_input["input_features"]
         feature_attention_mask = audio_input["feature_attention_mask"]
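
Design note on the dispatch above: the "audio_embeds" branch bypasses the audio
tower entirely and returns the per-audio tensors as a tuple, since each audio
item can contribute a different number of feature rows, while the
"audio_features" branch still feeds the padded features through the audio
encoder. A minimal sketch of constructing the two input types defined in this
diff; tensor contents and sizes are placeholders, not values from the PR:

    # Illustrative only; mel-bin count and hidden size are assumptions.
    import torch

    feature_input = Qwen2AudioFeatureInputs(
        type="audio_features",
        input_features=torch.rand(1, 128, 3000),
        feature_attention_mask=torch.ones(1, 3000, dtype=torch.long),
    )
    embedding_input = Qwen2AudioEmbeddingInputs(
        type="audio_embeds",
        audio_embeds=[torch.rand(750, 4096)],
    )
    # _process_audio_input returns the embeddings unchanged for the second form.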