mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:45:34 +08:00
[Bugfix] Fix qwen3-omni audio truncation issue (#26815)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
7cfa420f49
commit
8c851f6d04
@ -30,7 +30,9 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from packaging.version import Version
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||||
from transformers.feature_extraction_utils import BatchFeature
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
|
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
|
||||||
Qwen3OmniMoeConfig,
|
Qwen3OmniMoeConfig,
|
||||||
@ -711,11 +713,12 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
|
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
|
||||||
|
feature_extractor = self.info.get_feature_extractor()
|
||||||
|
hop_length = feature_extractor.hop_length
|
||||||
if audios:
|
if audios:
|
||||||
# NOTE: Qwen3-Omni processor accept "audio"
|
# NOTE: Qwen3-Omni processor accept "audio"
|
||||||
# To make sure the cache works with padding=True, we pre-padded
|
# To make sure the cache works with padding=True, we pre-padded
|
||||||
# the audio to multiple of hop_length.
|
# the audio to multiple of hop_length.
|
||||||
hop_length = self.info.get_feature_extractor().hop_length
|
|
||||||
mm_data["audio"] = [
|
mm_data["audio"] = [
|
||||||
pad_to_hop_length(audio, hop_length)
|
pad_to_hop_length(audio, hop_length)
|
||||||
if isinstance(audio, np.ndarray)
|
if isinstance(audio, np.ndarray)
|
||||||
@ -725,6 +728,14 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
|||||||
mm_kwargs = dict(
|
mm_kwargs = dict(
|
||||||
**mm_kwargs,
|
**mm_kwargs,
|
||||||
)
|
)
|
||||||
|
# TODO(Isotr0py): Remove this patch after upstream fix PR
|
||||||
|
# released and Transformers version update:
|
||||||
|
# https://github.com/huggingface/transformers/pull/41473
|
||||||
|
if (
|
||||||
|
Version(TRANSFORMERS_VERSION) < Version("4.58.0")
|
||||||
|
and "truncation" not in mm_kwargs
|
||||||
|
):
|
||||||
|
mm_kwargs["truncation"] = False
|
||||||
|
|
||||||
hf_inputs = super()._call_hf_processor(
|
hf_inputs = super()._call_hf_processor(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -738,7 +749,6 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
|||||||
and "feature_attention_mask" in hf_inputs
|
and "feature_attention_mask" in hf_inputs
|
||||||
and (audios := mm_data.get("audio", []))
|
and (audios := mm_data.get("audio", []))
|
||||||
):
|
):
|
||||||
hop_length = self.info.get_feature_extractor().hop_length
|
|
||||||
audio_num_frames = []
|
audio_num_frames = []
|
||||||
for _, audio in enumerate(audios):
|
for _, audio in enumerate(audios):
|
||||||
audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio)
|
audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio)
|
||||||
@ -747,6 +757,10 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
|||||||
if audio_length % hop_length == 0
|
if audio_length % hop_length == 0
|
||||||
else (audio_length // hop_length - 1)
|
else (audio_length // hop_length - 1)
|
||||||
)
|
)
|
||||||
|
if mm_kwargs.get("truncation", False):
|
||||||
|
num_frame = min(
|
||||||
|
num_frame, feature_extractor.n_samples // hop_length
|
||||||
|
)
|
||||||
audio_num_frames.append(num_frame)
|
audio_num_frames.append(num_frame)
|
||||||
hf_inputs["feature_attention_mask"] = [
|
hf_inputs["feature_attention_mask"] = [
|
||||||
torch.ones(num_frame) for num_frame in audio_num_frames
|
torch.ones(num_frame) for num_frame in audio_num_frames
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user