[Multimodal][Qwen3 Omni] Make Qwen3 Omni work with audio-in-video inputs in V1 engine. (#27721)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Chenheli Hua 2025-11-24 11:24:37 -08:00 committed by GitHub
parent 8f066146c3
commit 839c6b7b72
4 changed files with 467 additions and 59 deletions
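In user-facing terms, the path exercised by this change is the `use_audio_in_video` flag passed through `mm_processor_kwargs`, with the audio track supplied alongside the video. A minimal sketch distilled from the new example file below (prompt elided; see that file for the full chat-formatted version):

from vllm import LLM, SamplingParams
from vllm.assets.video import VideoAsset

asset = VideoAsset(name="baby_reading", num_frames=16)
llm = LLM(
    model="Qwen/Qwen3-Omni-30B-A3B-Instruct",
    limit_mm_per_prompt={"audio": 1, "video": 1},
)
outputs = llm.generate(
    {
        "prompt": "...",  # chat-formatted prompt with a <|video_pad|> placeholder (see the example below)
        "multi_modal_data": {
            "video": asset.np_ndarrays,
            # Audio track extracted from the same video asset.
            "audio": asset.get_audio(sampling_rate=16000),
        },
        # The flag this commit wires up end-to-end in the V1 engine.
        "mm_processor_kwargs": {"use_audio_in_video": True},
    },
    sampling_params=SamplingParams(max_tokens=256),
)
print(outputs[0].outputs[0].text)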


@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen3-Omni (thinker only).
"""
from typing import NamedTuple
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
def get_mixed_modalities_query() -> QueryResult:
question = (
"What is recited in the audio? "
"What is the content of this image? Why is this video funny?"
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image": convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB"
),
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
def get_use_audio_in_video_query() -> QueryResult:
question = (
"Describe the content of the video in details, then convert what the "
"baby say into text."
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={"audio": 1, "video": 1},
)
def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|audio_start|><|audio_pad|><|audio_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
}
def main(args):
model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
query_result = query_map[args.query_type]()
llm = LLM(
model=model_name,
max_model_len=12800,
max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed,
)
# We set temperature to 0.2 so that outputs can differ even when all
# prompts are identical during batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=256)
outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using vLLM for offline inference with "
"audio language models"
)
parser.add_argument(
"--query-type",
"-q",
type=str,
default="mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)


@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import Mock
import pytest
from transformers import PretrainedConfig
from vllm.multimodal.processing import InputProcessingContext
# Helper function to print input IDs with coalesced audio/video tokens.
def print_input_ids(input_ids):
"""
Print input IDs, compressing consecutive special tokens.
- 151675: <|audio_pad|>
- 151656: <|video_pad|>
"""
if not input_ids:
print("[]")
return
result = []
i = 0
while i < len(input_ids):
current_id = input_ids[i]
# Check if it's a special token that should be compressed
if current_id in [151675, 151656]:
# Count consecutive occurrences
count = 1
while i + count < len(input_ids) and input_ids[i + count] == current_id:
count += 1
# Add compressed representation
token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
result.append(f"{token_name} * {count}")
i += count
else:
# Regular token, just add it
result.append(str(current_id))
i += 1
print(", ".join(result))
@pytest.fixture
def mock_qwen3_omni_config():
"""Create a mock Qwen3OmniMoeThinker config."""
config = Mock(spec=PretrainedConfig)
# Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
config.audio_token_id = 151675 # <|audio_pad|>
config.video_token_id = 151656 # <|video_pad|>
config.image_token_id = 151655 # <|image_pad|>
config.audio_start_token_id = 151669 # <|audio_start|>
config.audio_end_token_id = 151670 # <|audio_end|>
config.vision_start_token_id = 151652 # <|vision_start|>
config.position_id_per_seconds = 12.5
# Vision config
vision_config = Mock()
vision_config.spatial_merge_size = 2
config.vision_config = vision_config
return config
@pytest.fixture
def mock_processor():
"""Create a mock HF processor."""
from transformers.models.whisper import WhisperFeatureExtractor
processor = Mock()
processor.audio_token = "<|audio_pad|>"
processor.image_token = "<|image_pad|>"
processor.video_token = "<|video_pad|>"
# Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
feature_extractor = WhisperFeatureExtractor()
processor.feature_extractor = feature_extractor
return processor
@pytest.fixture
def mock_tokenizer():
"""Create a mock tokenizer."""
tokenizer = Mock()
# Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
tokenizer.get_vocab = Mock(
return_value={
"<|audio_pad|>": 151675,
"<|video_pad|>": 151656,
"<|image_pad|>": 151655,
"<|audio_start|>": 151669,
"<|audio_end|>": 151670,
"<|vision_start|>": 151652,
"<|vision_end|>": 151653,
}
)
tokenizer.encode = Mock(
side_effect=lambda x: {
"<|vision_start|>": [151652],
"<|vision_end|>": [151653],
"<|audio_start|>": [151669],
"<|audio_end|>": [151670],
"<|audio_pad|>": [151675],
"<|image_pad|>": [151655],
"<|video_pad|>": [151656],
}.get(x, [0])
)
tokenizer.vision_bos_token = "<|vision_start|>"
tokenizer.vision_eos_token = "<|vision_end|>"
tokenizer.audio_bos_token = "<|audio_start|>"
tokenizer.audio_eos_token = "<|audio_end|>"
return tokenizer
@pytest.fixture
def mock_image_processor():
"""Create a mock image processor."""
image_processor = Mock()
image_processor.merge_size = 2
return image_processor
def test_qwen3_omni_get_updates_use_audio_in_video(
mock_qwen3_omni_config,
mock_processor,
mock_tokenizer,
mock_image_processor,
):
"""Test the get_updates_use_audio_in_video method directly."""
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
Qwen3OmniMoeThinkerMultiModalProcessor,
Qwen3OmniMoeThinkerProcessingInfo,
)
# Create a mock context
mock_ctx = Mock(spec=InputProcessingContext)
# Create processing info
info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
info.get_hf_processor = Mock(return_value=mock_processor)
info.get_tokenizer = Mock(return_value=mock_tokenizer)
info.get_image_processor = Mock(return_value=mock_image_processor)
# Create a mock dummy_inputs builder
mock_dummy_inputs = Mock()
# Create the processor
processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)
# Test parameters from reference video
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
audio_len = 85
video_grid_thw = [6, 36, 64]
video_second_per_grid_t = 2.0
# Call the method
updates = processor.get_updates_use_audio_in_video(
thinker_config=mock_qwen3_omni_config,
audio_len=audio_len,
video_grid_thw=video_grid_thw,
video_second_per_grid_t=video_second_per_grid_t,
)
# The updated input IDs should align with the HF implementation:
# 151669,
# <|video_pad|> * 576, <|audio_pad|> * 25,
# <|video_pad|> * 576, <|audio_pad|> * 25,
# <|video_pad|> * 576, <|audio_pad|> * 25,
# <|video_pad|> * 576, <|audio_pad|> * 10,
# <|video_pad|> * 1152,
# 151670
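# For these inputs the counts work out as follows (matching the asserts below):
#   video tokens = T * (H // merge) * (W // merge) = 6 * (36 // 2) * (64 // 2) = 3456
#   audio tokens = audio_len = 85
#   total length = 1 (<|audio_start|>) + 3456 + 85 + 1 (<|audio_end|>) = 3543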
print_input_ids(updates)
# Verify structure
assert isinstance(updates, list)
assert len(updates) > 0
# Verify start and end tokens
audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id
assert updates[0] == audio_start_token_id
assert updates[-1] == audio_end_token_id
# Verify both audio and video tokens are present
audio_token_id = mock_qwen3_omni_config.audio_token_id
video_token_id = mock_qwen3_omni_config.video_token_id
audio_count = updates.count(audio_token_id)
video_count = updates.count(video_token_id)
assert audio_count == audio_len, (
f"Expected {audio_len} audio tokens, got {audio_count}"
)
# Calculate expected video token count
spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
height = video_grid_thw[1] // spatial_merge_size
width = video_grid_thw[2] // spatial_merge_size
expected_video_count = video_grid_thw[0] * height * width
assert video_count == expected_video_count, (
f"Expected {expected_video_count} video tokens, got {video_count}"
)
# Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
expected_total = 1 + audio_len + expected_video_count + 1
assert len(updates) == expected_total, (
f"Expected {expected_total} total tokens, got {len(updates)}"
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])


@@ -23,7 +23,6 @@
"""Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Callable, Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import Annotated, Any, Literal
@@ -387,15 +386,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
use_audio_in_video = False
if "video" in mm_kwargs:
video_items = [item for item in mm_kwargs["video"] if item is not None]
# only check video items (if there are any)
if video_items:
use_audio_in_video = all(
item["use_audio_in_video"].data for item in video_items
)
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
@@ -404,7 +394,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
use_audio_in_video=use_audio_in_video,
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
@@ -414,7 +403,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
use_audio_in_video=use_audio_in_video,
)
return prompt_ids, mm_placeholders
@@ -640,19 +628,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
return mm_processed_data
def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
mm_item_counts: Mapping[str, int],
use_audio_in_video: bool = False,
) -> None:
if use_audio_in_video:
mm_item_counts = copy(mm_item_counts)
if "video" in mm_item_counts:
assert "audio" in mm_item_counts
mm_item_counts["audio"] -= mm_item_counts["video"]
super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)
class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_audio_input(


@@ -68,11 +68,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
@@ -87,7 +87,6 @@ from .qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin,
Qwen2_5OmniThinkerDummyInputsBuilder,
Qwen2_5OmniThinkerMultiModalProcessor,
Qwen2_5OmniThinkerProcessingInfo,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
@@ -807,24 +806,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
else:
use_audio_in_video = False
if use_audio_in_video and "video" in mm_item_counts:
assert "audio" in mm_item_counts
mm_item_counts["audio"] -= mm_item_counts["video"]
# Special case with `use_audio_in_video=True`
if use_audio_in_video:
if is_update_applied:
prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
(
prompt_ids,
mm_placeholders,
) = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
# normal case with `use_audio_in_video=False`
elif is_update_applied:
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
mm_prompt_updates,
@@ -834,10 +817,24 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts,
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
if use_audio_in_video and "audio" in mm_prompt_updates:
filtered_updates = {
k: v for k, v in mm_prompt_updates.items() if k != "audio"
}
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
filtered_updates,
)
# Derive audio placeholders from video placeholders
mm_placeholders = self._derive_audio_from_video_placeholders(
mm_placeholders, mm_prompt_updates
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
@@ -962,7 +959,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
def get_replacement_qwen2_use_audio_in_video(item_idx: int):
nonlocal audio_in_video_item_idx
audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
audio_num_features = audio_output_lengths[
audio_in_video_item_idx + item_idx
]
video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
audio_in_video_item_idx += 1
@@ -971,14 +970,17 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[item_idx]
else:
video_second_per_grid_t = 1.0
video_second_per_grid_t = 2.0
return self.get_updates_use_audio_in_video(
placeholder = self.get_updates_use_audio_in_video(
thinker_config=thinker_config,
audio_len=audio_num_features,
video_grid_thw=video_grid_thw,
video_second_per_grid_t=video_second_per_grid_t,
)
return PromptUpdateDetails.select_token_id(
placeholder, embed_token_id=video_token_id
)
video_replacement_fn = (
get_replacement_qwen2_use_audio_in_video
@@ -1004,14 +1006,50 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
),
]
def _validate_mm_placeholders(
def _derive_audio_from_video_placeholders(
self,
mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
mm_item_counts: Mapping[str, int],
) -> None:
BaseMultiModalProcessor[
Qwen2_5OmniThinkerProcessingInfo
]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
"""
Helper to derive audio placeholders from video placeholders when
use_audio_in_video=True.
"""
if "video" not in placeholders:
return placeholders
# Validate audio and video counts match
num_videos = len(placeholders["video"])
num_audios = len(mm_prompt_updates.get("audio", []))
if num_audios != num_videos:
raise ValueError(
f"use_audio_in_video requires equal number of audio and video items, "
f"got {num_audios=}, {num_videos=}"
)
tokenizer = self.info.get_tokenizer()
processor = self.info.get_hf_processor()
audio_token_id = tokenizer.get_vocab()[processor.audio_token]
result_placeholders = dict(placeholders)
audio_placeholders = []
# Each video is paired with one audio
for video_idx, video_placeholder in enumerate(placeholders["video"]):
# Create is_embed mask selecting only audio tokens
audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
audio_placeholder = PlaceholderFeaturesInfo(
modality="audio",
item_idx=video_idx,
start_idx=video_placeholder.start_idx,
tokens=video_placeholder.tokens,
is_embed=audio_is_embed,
)
audio_placeholders.append(audio_placeholder)
result_placeholders["audio"] = audio_placeholders
return result_placeholders
def _get_raw_input_ids(
self,
@@ -1454,7 +1492,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
if not len(second_per_grid_ts) and len(video_grid_thw):
second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
second_per_grid_ts = 2.0
second_per_grids = (
torch.ones(len(video_grid_thw), dtype=torch.float32)
* second_per_grid_ts
)
else:
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
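As an aside on the mechanism introduced above: with `use_audio_in_video=True`, the audio placeholder is no longer matched separately. It reuses the video placeholder's interleaved token span (same start index and tokens), and only its `is_embed` mask differs, selecting the audio pad positions. A small standalone sketch of that masking step (the toy token sequence is illustrative; only the pad-token IDs are real):

import torch

AUDIO_PAD = 151675  # <|audio_pad|>
VIDEO_PAD = 151656  # <|video_pad|>

# Toy interleaved span of the kind get_updates_use_audio_in_video produces:
# <|audio_start|>, video pads, audio pads, video pads, <|audio_end|>.
tokens = [151669] + [VIDEO_PAD] * 4 + [AUDIO_PAD] * 2 + [VIDEO_PAD] * 4 + [151670]

# The derived audio placeholder keeps the video placeholder's start index and
# token span; only this mask differs, mirroring
# `torch.tensor(video_placeholder.tokens) == audio_token_id` in the diff above.
audio_is_embed = torch.tensor(tokens) == AUDIO_PAD
print(audio_is_embed.int().tolist())  # [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0]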