diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index a61b8ca8f7ae7..5c64c81547e65 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -47,7 +47,7 @@ from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioInputs, Qwen2AudioProcessingInfo, + Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -534,7 +534,7 @@ class Qwen2_5OmniConditionalGenerationMixin: return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]: input_audio_features = kwargs.pop('input_audio_features', None) audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) @@ -548,9 +548,10 @@ class Qwen2_5OmniConditionalGenerationMixin: if not isinstance(input_audio_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio input features. " f"Got type: {type(input_audio_features)}") - return Qwen2AudioInputs(input_features=input_audio_features, - audio_feature_lengths=audio_feature_lengths, - feature_attention_mask=feature_attention_mask) + return Qwen2AudioFeatureInputs( + input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask) def _parse_and_validate_image_input( self, @@ -630,7 +631,7 @@ class Qwen2_5OmniConditionalGenerationMixin: def _process_audio_input( self, - audio_input: Qwen2AudioInputs, + audio_input: Qwen2AudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 86c567ca36174..86b4a9a018c76 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -36,9 +36,11 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, +from vllm.multimodal.inputs import (AudioItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, +from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, + ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -52,7 +54,8 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model, # # === Audio Inputs === # -class Qwen2AudioInputs(TypedDict): +class Qwen2AudioFeatureInputs(TypedDict): + type: Literal["audio_features"] input_features: torch.Tensor """Shape: `(num_audios, num_mel_bins, 3000)`""" @@ -60,6 +63,16 @@ class Qwen2AudioInputs(TypedDict): """Shape: `(num_audios, 3000)`""" +class Qwen2AudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + audio_embeds: list[torch.Tensor] + """Shape: `(num_audio_features, hidden_size)` + `hidden_size` must match the hidden size of language model backbone. + """ + + +Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] + # === Audio Encoder === # @@ -128,12 +141,38 @@ class Qwen2AudioDummyInputsBuilder( } +def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]): + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + ) + + +class Qwen2AudioMultiModalDataParser(MultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_qwen2audio_field_config, + ) + + return super()._parse_audio_data(data) + + class Qwen2AudioMultiModalProcessor( BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + return Qwen2AudioMultiModalDataParser( + target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, @@ -173,10 +212,7 @@ class Qwen2AudioMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - input_features=MultiModalFieldConfig.batched("audio"), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - ) + return _qwen2audio_field_config(hf_inputs) def _get_prompt_updates( self, @@ -184,6 +220,7 @@ class Qwen2AudioMultiModalProcessor( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -211,7 +248,15 @@ class Qwen2AudioMultiModalProcessor( audio_output_lengths = audio_output_lens.tolist() def get_replacement_qwen2_audio(item_idx: int): - num_features = audio_output_lengths[item_idx] + + if audio_output_lengths: + num_features = audio_output_lengths[item_idx] + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + assert len(audio_embeds.shape + ) == 2, "audio_embeds must be a 2D tensor" + num_features = audio_embeds.shape[0] + if num_features == 0: audios = mm_items.get_items("audio", AudioProcessorItems) audio_len = audios.get_audio_length(item_idx) @@ -286,21 +331,39 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Qwen2AudioInputs]: input_features = kwargs.pop('input_features', None) + audio_embeds = kwargs.pop('audio_embeds', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) - if input_features is None: - return None - input_features = self._validate_and_reshape_mm_tensor( - input_features, 'input_features') - feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') - if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_features)}") - return Qwen2AudioInputs(input_features=input_features, - feature_attention_mask=feature_attention_mask) - def _process_audio_input(self, - audio_input: Qwen2AudioInputs) -> torch.Tensor: + if input_features is None and audio_embeds is None: + return None + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + audio_embeds = self._validate_and_reshape_mm_tensor( + audio_embeds, "audio_embeds") + return Qwen2AudioEmbeddingInputs(type="audio_embeds", + audio_embeds=audio_embeds) + + if input_features is not None: + input_features = self._validate_and_reshape_mm_tensor( + input_features, 'input_features') + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + return Qwen2AudioFeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=feature_attention_mask) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: Qwen2AudioInputs + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if audio_input["type"] == "audio_embeds": + audio_embeds = audio_input["audio_embeds"] + return tuple(audio_embeds) input_features = audio_input["input_features"] feature_attention_mask = audio_input["feature_attention_mask"]