[model] support qwen2audio embedding input (#23625)

Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yuekai Zhang 2025-08-26 23:48:08 +08:00 committed by GitHub
parent 513298f1b4
commit 9d4183dd2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 93 additions and 29 deletions

View File

@ -47,7 +47,7 @@ from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
from vllm.model_executor.models.qwen2_audio import ( from vllm.model_executor.models.qwen2_audio import (
Qwen2AudioInputs, Qwen2AudioProcessingInfo, Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
_get_feat_extract_output_lengths) _get_feat_extract_output_lengths)
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -534,7 +534,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
return torch.concat(mm_input, dim=dim) return torch.concat(mm_input, dim=dim)
def _parse_and_validate_audio_input( def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Qwen2AudioInputs]: self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
input_audio_features = kwargs.pop('input_audio_features', None) input_audio_features = kwargs.pop('input_audio_features', None)
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
feature_attention_mask = kwargs.pop('feature_attention_mask', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None)
@ -548,9 +548,10 @@ class Qwen2_5OmniConditionalGenerationMixin:
if not isinstance(input_audio_features, (torch.Tensor, list)): if not isinstance(input_audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio input features. " raise ValueError("Incorrect type of audio input features. "
f"Got type: {type(input_audio_features)}") f"Got type: {type(input_audio_features)}")
return Qwen2AudioInputs(input_features=input_audio_features, return Qwen2AudioFeatureInputs(
audio_feature_lengths=audio_feature_lengths, input_features=input_audio_features,
feature_attention_mask=feature_attention_mask) audio_feature_lengths=audio_feature_lengths,
feature_attention_mask=feature_attention_mask)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, self,
@ -630,7 +631,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _process_audio_input( def _process_audio_input(
self, self,
audio_input: Qwen2AudioInputs, audio_input: Qwen2AudioFeatureInputs,
audio_hashes: list[str] = None, audio_hashes: list[str] = None,
cached_audio_features: torch.Tensor = None, cached_audio_features: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:

View File

@ -23,7 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, TypedDict, Union from typing import Any, Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -36,9 +36,11 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (AudioItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
@ -52,7 +54,8 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
# # === Audio Inputs === # # # === Audio Inputs === #
class Qwen2AudioInputs(TypedDict): class Qwen2AudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
input_features: torch.Tensor input_features: torch.Tensor
"""Shape: `(num_audios, num_mel_bins, 3000)`""" """Shape: `(num_audios, num_mel_bins, 3000)`"""
@ -60,6 +63,16 @@ class Qwen2AudioInputs(TypedDict):
"""Shape: `(num_audios, 3000)`""" """Shape: `(num_audios, 3000)`"""
class Qwen2AudioEmbeddingInputs(TypedDict):
    """Audio input carried as embeddings instead of audio features.

    The `type` tag is the value `_process_audio_input` compares against
    to decide whether the embeddings can be returned as-is.
    """
    # Discriminator tag; always the literal string "audio_embeds".
    type: Literal["audio_embeds"]
    # Annotated as a list of tensors; the shape note below presumably
    # applies to each list element -- TODO confirm against the caller.
    audio_embeds: list[torch.Tensor]
    """Shape: `(num_audio_features, hidden_size)`
    `hidden_size` must match the hidden size of language model backbone.
    """
# Either input variant; each TypedDict carries a literal `type` key that
# identifies which of the two it is.
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
# === Audio Encoder === # # === Audio Encoder === #
@ -128,12 +141,38 @@ class Qwen2AudioDummyInputsBuilder(
} }
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the field configuration shared by all Qwen2-Audio inputs.

    Every known field is declared as batched over the "audio" modality.
    `hf_inputs` is accepted for signature compatibility but is not
    inspected; the result is the same for every call.
    """
    field_names = (
        "audio_embeds",
        "input_features",
        "feature_attention_mask",
    )
    return {
        name: MultiModalFieldConfig.batched("audio")
        for name in field_names
    }
class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts audio data given as a dict."""

    def _parse_audio_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse audio data, routing dict inputs to `DictEmbeddingItems`.

        A dict must contain the `audio_embeds` field (declared as the only
        required field). Anything that is not a dict is handed to the
        parent parser unchanged.
        """
        # Non-dict data follows the regular audio parsing path.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )
class Qwen2AudioMultiModalProcessor( class Qwen2AudioMultiModalProcessor(
BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
    """Build the data parser used by this processor.

    The parser's target sampling rate is read from the feature extractor
    returned by `self.info`.
    """
    sampling_rate = self.info.get_feature_extractor().sampling_rate
    return Qwen2AudioMultiModalDataParser(target_sr=sampling_rate)
def _call_hf_processor( def _call_hf_processor(
self, self,
@ -173,10 +212,7 @@ class Qwen2AudioMultiModalProcessor(
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return _qwen2audio_field_config(hf_inputs)
input_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
@ -184,6 +220,7 @@ class Qwen2AudioMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
@ -211,7 +248,15 @@ class Qwen2AudioMultiModalProcessor(
audio_output_lengths = audio_output_lens.tolist() audio_output_lengths = audio_output_lens.tolist()
def get_replacement_qwen2_audio(item_idx: int): def get_replacement_qwen2_audio(item_idx: int):
num_features = audio_output_lengths[item_idx]
if audio_output_lengths:
num_features = audio_output_lengths[item_idx]
else:
audio_embeds = out_mm_data["audio_embeds"][item_idx]
assert len(audio_embeds.shape
) == 2, "audio_embeds must be a 2D tensor"
num_features = audio_embeds.shape[0]
if num_features == 0: if num_features == 0:
audios = mm_items.get_items("audio", AudioProcessorItems) audios = mm_items.get_items("audio", AudioProcessorItems)
audio_len = audios.get_audio_length(item_idx) audio_len = audios.get_audio_length(item_idx)
@ -286,21 +331,39 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_audio_input(
        self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
    """Extract the audio-related keyword arguments and package them.

    Pops `input_features`, `audio_embeds` and `feature_attention_mask`
    from `kwargs`. When `audio_embeds` is present it is used and the
    feature inputs are ignored.

    Returns:
        `None` if neither features nor embeddings were supplied;
        otherwise a `Qwen2AudioEmbeddingInputs` or a
        `Qwen2AudioFeatureInputs` dict.

    Raises:
        ValueError: if `audio_embeds` is neither a tensor nor a list.
    """
    features = kwargs.pop('input_features', None)
    embeds = kwargs.pop('audio_embeds', None)
    attention_mask = kwargs.pop('feature_attention_mask', None)

    has_embeds = embeds is not None
    has_features = features is not None
    if not (has_embeds or has_features):
        return None

    if has_embeds:
        # The type check runs before the reshape helper is called.
        if not isinstance(embeds, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio embeds. "
                             f"Got type: {type(embeds)}")
        reshaped_embeds = self._validate_and_reshape_mm_tensor(
            embeds, "audio_embeds")
        return Qwen2AudioEmbeddingInputs(type="audio_embeds",
                                         audio_embeds=reshaped_embeds)

    if has_features:
        # Features are reshaped first, then the attention mask.
        reshaped_features = self._validate_and_reshape_mm_tensor(
            features, 'input_features')
        reshaped_mask = self._validate_and_reshape_mm_tensor(
            attention_mask, 'feature_attention_mask')
        return Qwen2AudioFeatureInputs(
            type="audio_features",
            input_features=reshaped_features,
            feature_attention_mask=reshaped_mask)

    raise AssertionError("This line should be unreachable.")
def _process_audio_input(
self, audio_input: Qwen2AudioInputs
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if audio_input["type"] == "audio_embeds":
audio_embeds = audio_input["audio_embeds"]
return tuple(audio_embeds)
input_features = audio_input["input_features"] input_features = audio_input["input_features"]
feature_attention_mask = audio_input["feature_attention_mask"] feature_attention_mask = audio_input["feature_attention_mask"]