Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Multimodal][Qwen3 Omni] Make Qwen3 Omni work with audio-in-video inputs in V1 engine. (#27721)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>

Parent: 8f066146c3
Commit: 839c6b7b72
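
For a quick sense of what this change enables, here is a condensed sketch of an audio-in-video request, mirroring get_use_audio_in_video_query() from the example added below. It is illustrative only: it assumes the Qwen/Qwen3-Omni-30B-A3B-Instruct weights are available and that the GPU can hold the thinker, and it drops the system turn used in the full example.

# Condensed sketch of an audio-in-video request (see the full example below).
from vllm import LLM, SamplingParams
from vllm.assets.video import VideoAsset

asset = VideoAsset(name="baby_reading", num_frames=16)
llm = LLM(
    model="Qwen/Qwen3-Omni-30B-A3B-Instruct",
    max_model_len=12800,
    limit_mm_per_prompt={"audio": 1, "video": 1},
)
outputs = llm.generate(
    {
        "prompt": "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
        "What is the baby saying?<|im_end|>\n<|im_start|>assistant\n",
        "multi_modal_data": {
            "video": asset.np_ndarrays,
            # The audio track is extracted from the same video asset.
            "audio": asset.get_audio(sampling_rate=16000),
        },
        # This flag is what the commit wires up end-to-end in the V1 engine.
        "mm_processor_kwargs": {"use_audio_in_video": True},
    },
    sampling_params=SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)
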
examples/offline_inference/qwen3_omni/only_thinker.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen3-Omni (thinker only).
"""

from typing import NamedTuple

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser


class QueryResult(NamedTuple):
    inputs: dict
    limit_mm_per_prompt: dict[str, int]


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

default_system = (
    "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
    "Group, capable of perceiving auditory and visual inputs, as well as "
    "generating text and speech."
)


def get_mixed_modalities_query() -> QueryResult:
    question = (
        "What is recited in the audio? "
        "What is the content of this image? Why is this video funny?"
    )
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
        "<|vision_start|><|image_pad|><|vision_end|>"
        "<|vision_start|><|video_pad|><|vision_end|>"
        f"{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
                "image": convert_image_mode(
                    ImageAsset("cherry_blossom").pil_image, "RGB"
                ),
                "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
            },
        },
        limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
    )


def get_use_audio_in_video_query() -> QueryResult:
    question = (
        "Describe the content of the video in detail, then convert what the "
        "baby says into text."
    )
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
        f"{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    asset = VideoAsset(name="baby_reading", num_frames=16)
    audio = asset.get_audio(sampling_rate=16000)
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "video": asset.np_ndarrays,
                "audio": audio,
            },
            "mm_processor_kwargs": {
                "use_audio_in_video": True,
            },
        },
        limit_mm_per_prompt={"audio": 1, "video": 1},
    )


def get_multi_audios_query() -> QueryResult:
    question = "Are these two audio clips the same?"
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
        "<|audio_start|><|audio_pad|><|audio_end|>"
        f"{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "audio": [
                    AudioAsset("winning_call").audio_and_sample_rate,
                    AudioAsset("mary_had_lamb").audio_and_sample_rate,
                ],
            },
        },
        limit_mm_per_prompt={
            "audio": 2,
        },
    )


query_map = {
    "mixed_modalities": get_mixed_modalities_query,
    "use_audio_in_video": get_use_audio_in_video_query,
    "multi_audios": get_multi_audios_query,
}


def main(args):
    model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
    query_result = query_map[args.query_type]()

    llm = LLM(
        model=model_name,
        max_model_len=12800,
        max_num_seqs=5,
        limit_mm_per_prompt=query_result.limit_mm_per_prompt,
        seed=args.seed,
    )

    # Use temperature 0.2 so that outputs can differ even when all prompts
    # are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2, max_tokens=256)

    outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )
    parser.add_argument(
        "--query-type",
        "-q",
        type=str,
        default="mixed_modalities",
        choices=query_map.keys(),
        help="Query type.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)
tests/model_executor/test_qwen3_omni.py (new file, 221 lines)
@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock

import pytest
from transformers import PretrainedConfig

from vllm.multimodal.processing import InputProcessingContext


# Helper function to print input IDs with coalesced audio/video tokens.
def print_input_ids(input_ids):
    """
    Print input IDs, compressing consecutive special tokens.
    - 151675: <|audio_pad|>
    - 151656: <|video_pad|>
    """
    if not input_ids:
        print("[]")
        return

    result = []
    i = 0

    while i < len(input_ids):
        current_id = input_ids[i]

        # Check if it's a special token that should be compressed
        if current_id in [151675, 151656]:
            # Count consecutive occurrences
            count = 1
            while i + count < len(input_ids) and input_ids[i + count] == current_id:
                count += 1

            # Add compressed representation
            token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
            result.append(f"{token_name} * {count}")
            i += count
        else:
            # Regular token, just add it
            result.append(str(current_id))
            i += 1

    print(", ".join(result))


@pytest.fixture
def mock_qwen3_omni_config():
    """Create a mock Qwen3OmniMoeThinker config."""
    config = Mock(spec=PretrainedConfig)
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    config.audio_token_id = 151675  # <|audio_pad|>
    config.video_token_id = 151656  # <|video_pad|>
    config.image_token_id = 151655  # <|image_pad|>
    config.audio_start_token_id = 151669  # <|audio_start|>
    config.audio_end_token_id = 151670  # <|audio_end|>
    config.vision_start_token_id = 151652  # <|vision_start|>
    config.position_id_per_seconds = 12.5

    # Vision config
    vision_config = Mock()
    vision_config.spatial_merge_size = 2
    config.vision_config = vision_config

    return config


@pytest.fixture
def mock_processor():
    """Create a mock HF processor."""
    from transformers.models.whisper import WhisperFeatureExtractor

    processor = Mock()
    processor.audio_token = "<|audio_pad|>"
    processor.image_token = "<|image_pad|>"
    processor.video_token = "<|video_pad|>"

    # Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
    feature_extractor = WhisperFeatureExtractor()
    processor.feature_extractor = feature_extractor

    return processor


@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer."""
    tokenizer = Mock()
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    tokenizer.get_vocab = Mock(
        return_value={
            "<|audio_pad|>": 151675,
            "<|video_pad|>": 151656,
            "<|image_pad|>": 151655,
            "<|audio_start|>": 151669,
            "<|audio_end|>": 151670,
            "<|vision_start|>": 151652,
            "<|vision_end|>": 151653,
        }
    )
    tokenizer.encode = Mock(
        side_effect=lambda x: {
            "<|vision_start|>": [151652],
            "<|vision_end|>": [151653],
            "<|audio_start|>": [151669],
            "<|audio_end|>": [151670],
            "<|audio_pad|>": [151675],
            "<|image_pad|>": [151655],
            "<|video_pad|>": [151656],
        }.get(x, [0])
    )
    tokenizer.vision_bos_token = "<|vision_start|>"
    tokenizer.vision_eos_token = "<|vision_end|>"
    tokenizer.audio_bos_token = "<|audio_start|>"
    tokenizer.audio_eos_token = "<|audio_end|>"
    return tokenizer


@pytest.fixture
def mock_image_processor():
    """Create a mock image processor."""
    image_processor = Mock()
    image_processor.merge_size = 2
    return image_processor


def test_qwen3_omni_get_updates_use_audio_in_video(
    mock_qwen3_omni_config,
    mock_processor,
    mock_tokenizer,
    mock_image_processor,
):
    """Test the get_updates_use_audio_in_video method directly."""

    from vllm.model_executor.models.qwen3_omni_moe_thinker import (
        Qwen3OmniMoeThinkerMultiModalProcessor,
        Qwen3OmniMoeThinkerProcessingInfo,
    )

    # Create a mock context
    mock_ctx = Mock(spec=InputProcessingContext)

    # Create processing info
    info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
    info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
    info.get_hf_processor = Mock(return_value=mock_processor)
    info.get_tokenizer = Mock(return_value=mock_tokenizer)
    info.get_image_processor = Mock(return_value=mock_image_processor)

    # Create a mock dummy_inputs builder
    mock_dummy_inputs = Mock()

    # Create the processor
    processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)

    # Test parameters from reference video
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
    audio_len = 85
    video_grid_thw = [6, 36, 64]
    video_second_per_grid_t = 2.0

    # Call the method
    updates = processor.get_updates_use_audio_in_video(
        thinker_config=mock_qwen3_omni_config,
        audio_len=audio_len,
        video_grid_thw=video_grid_thw,
        video_second_per_grid_t=video_second_per_grid_t,
    )

    # Updated input ids should align with the HF implementation:
    # 151669,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 10,
    # <|video_pad|> * 1152,
    # 151670
    print_input_ids(updates)

    # Verify structure
    assert isinstance(updates, list)
    assert len(updates) > 0

    # Verify start and end tokens
    audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
    audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id

    assert updates[0] == audio_start_token_id
    assert updates[-1] == audio_end_token_id

    # Verify both audio and video tokens are present
    audio_token_id = mock_qwen3_omni_config.audio_token_id
    video_token_id = mock_qwen3_omni_config.video_token_id

    audio_count = updates.count(audio_token_id)
    video_count = updates.count(video_token_id)

    assert audio_count == audio_len, (
        f"Expected {audio_len} audio tokens, got {audio_count}"
    )

    # Calculate expected video token count
    spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
    height = video_grid_thw[1] // spatial_merge_size
    width = video_grid_thw[2] // spatial_merge_size
    expected_video_count = video_grid_thw[0] * height * width

    assert video_count == expected_video_count, (
        f"Expected {expected_video_count} video tokens, got {video_count}"
    )

    # Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
    expected_total = 1 + audio_len + expected_video_count + 1
    assert len(updates) == expected_total, (
        f"Expected {expected_total} total tokens, got {len(updates)}"
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
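
For reference, the expected counts asserted in the test above follow directly from its parameters: with video_grid_thw = [6, 36, 64] and spatial_merge_size = 2, each temporal grid contributes (36 / 2) * (64 / 2) = 576 video tokens, so 6 grids give 3456 video tokens; the 85 audio tokens are spread across the temporal chunks as 25 + 25 + 25 + 10 (matching the commented layout), and the full update is 1 + 85 + 3456 + 1 = 3543 tokens including the <|audio_start|> and <|audio_end|> delimiters. A quick standalone check of that arithmetic:

# Arithmetic check of the totals asserted in the test above (illustrative only).
video_grid_thw = [6, 36, 64]
spatial_merge_size = 2
video_tokens = video_grid_thw[0] * (video_grid_thw[1] // spatial_merge_size) * (
    video_grid_thw[2] // spatial_merge_size
)
assert video_tokens == 3456
assert 25 + 25 + 25 + 10 == 85  # audio tokens, as in the commented layout
assert 1 + 85 + video_tokens + 1 == 3543  # start + audio + video + end
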
vllm/model_executor/models/qwen2_5_omni_thinker.py

@@ -23,7 +23,6 @@
"""Inference-only Qwen2.5-Omni model (thinker part)."""

from collections.abc import Callable, Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import Annotated, Any, Literal

@@ -387,15 +386,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

        use_audio_in_video = False
        if "video" in mm_kwargs:
            video_items = [item for item in mm_kwargs["video"] if item is not None]
            # only check video items (if there are any)
            if video_items:
                use_audio_in_video = all(
                    item["use_audio_in_video"].data for item in video_items
                )

        if is_update_applied:
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
@@ -404,7 +394,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
            self._validate_mm_placeholders(
                mm_placeholders,
                mm_item_counts,
                use_audio_in_video=use_audio_in_video,
            )
        else:
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
@@ -414,7 +403,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
            self._validate_mm_placeholders(
                mm_placeholders,
                mm_item_counts,
                use_audio_in_video=use_audio_in_video,
            )

        return prompt_ids, mm_placeholders
@@ -640,19 +628,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(

        return mm_processed_data

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
        use_audio_in_video: bool = False,
    ) -> None:
        if use_audio_in_video:
            mm_item_counts = copy(mm_item_counts)
            if "video" in mm_item_counts:
                assert "audio" in mm_item_counts
                mm_item_counts["audio"] -= mm_item_counts["video"]
        super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)


class Qwen2_5OmniConditionalGenerationMixin:
    def _parse_and_validate_audio_input(

vllm/model_executor/models/qwen3_omni_moe_thinker.py

@@ -68,11 +68,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    MultiModalPromptUpdates,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors

@@ -87,7 +87,6 @@ from .qwen2_5_omni_thinker import (
    Qwen2_5OmniConditionalGenerationMixin,
    Qwen2_5OmniThinkerDummyInputsBuilder,
    Qwen2_5OmniThinkerMultiModalProcessor,
    Qwen2_5OmniThinkerProcessingInfo,
)
from .qwen2_5_vl import (
    Qwen2_5_VisionAttention,
@@ -807,24 +806,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
        else:
            use_audio_in_video = False

        if use_audio_in_video and "video" in mm_item_counts:
            assert "audio" in mm_item_counts
            mm_item_counts["audio"] -= mm_item_counts["video"]

        # Special case with `use_audio_in_video=True`
        if use_audio_in_video:
            if is_update_applied:
                prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
            (
                prompt_ids,
                mm_placeholders,
            ) = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
        # normal case with `use_audio_in_video=False`
        elif is_update_applied:
        if is_update_applied:
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
                mm_prompt_updates,
@@ -834,10 +817,24 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
                mm_item_counts,
            )
        else:
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )
            if use_audio_in_video and "audio" in mm_prompt_updates:
                filtered_updates = {
                    k: v for k, v in mm_prompt_updates.items() if k != "audio"
                }
                prompt_ids, mm_placeholders = self._apply_prompt_updates(
                    prompt_ids,
                    filtered_updates,
                )
                # Derive audio placeholders from video placeholders
                mm_placeholders = self._derive_audio_from_video_placeholders(
                    mm_placeholders, mm_prompt_updates
                )
            else:
                prompt_ids, mm_placeholders = self._apply_prompt_updates(
                    prompt_ids,
                    mm_prompt_updates,
                )

        self._validate_mm_placeholders(
            mm_placeholders,
            mm_item_counts,
@@ -962,7 +959,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(

        def get_replacement_qwen2_use_audio_in_video(item_idx: int):
            nonlocal audio_in_video_item_idx
            audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
            audio_num_features = audio_output_lengths[
                audio_in_video_item_idx + item_idx
            ]
            video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

            audio_in_video_item_idx += 1
@@ -971,14 +970,17 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
            if second_per_grid_ts:
                video_second_per_grid_t = second_per_grid_ts[item_idx]
            else:
                video_second_per_grid_t = 1.0
                video_second_per_grid_t = 2.0

            return self.get_updates_use_audio_in_video(
            placeholder = self.get_updates_use_audio_in_video(
                thinker_config=thinker_config,
                audio_len=audio_num_features,
                video_grid_thw=video_grid_thw,
                video_second_per_grid_t=video_second_per_grid_t,
            )
            return PromptUpdateDetails.select_token_id(
                placeholder, embed_token_id=video_token_id
            )

        video_replacement_fn = (
            get_replacement_qwen2_use_audio_in_video
@@ -1004,14 +1006,50 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
            ),
        ]

    def _validate_mm_placeholders(
    def _derive_audio_from_video_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        BaseMultiModalProcessor[
            Qwen2_5OmniThinkerProcessingInfo
        ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
        placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Helper to derive audio placeholders from video placeholders when
        use_audio_in_video=True.
        """
        if "video" not in placeholders:
            return placeholders

        # Validate audio and video counts match
        num_videos = len(placeholders["video"])
        num_audios = len(mm_prompt_updates.get("audio", []))
        if num_audios != num_videos:
            raise ValueError(
                f"use_audio_in_video requires equal number of audio and video items, "
                f"got {num_audios=}, {num_videos=}"
            )

        tokenizer = self.info.get_tokenizer()
        processor = self.info.get_hf_processor()
        audio_token_id = tokenizer.get_vocab()[processor.audio_token]

        result_placeholders = dict(placeholders)
        audio_placeholders = []

        # Each video is paired with one audio
        for video_idx, video_placeholder in enumerate(placeholders["video"]):
            # Create is_embed mask selecting only audio tokens
            audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id

            audio_placeholder = PlaceholderFeaturesInfo(
                modality="audio",
                item_idx=video_idx,
                start_idx=video_placeholder.start_idx,
                tokens=video_placeholder.tokens,
                is_embed=audio_is_embed,
            )
            audio_placeholders.append(audio_placeholder)

        result_placeholders["audio"] = audio_placeholders
        return result_placeholders

    def _get_raw_input_ids(
        self,
@@ -1454,7 +1492,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
        )

        if not len(second_per_grid_ts) and len(video_grid_thw):
            second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
            second_per_grid_ts = 2.0
            second_per_grids = (
                torch.ones(len(video_grid_thw), dtype=torch.float32)
                * second_per_grid_ts
            )
        else:
            second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
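
As a note on the new _derive_audio_from_video_placeholders helper above: the derived audio placeholder reuses the video placeholder's token span and simply masks in the audio positions via is_embed. A standalone sketch of that masking step, using the pad-token IDs from the test file (this is illustrative code, not part of vLLM):

import torch

AUDIO_PAD_ID = 151675  # <|audio_pad|>
VIDEO_PAD_ID = 151656  # <|video_pad|>
AUDIO_START_ID, AUDIO_END_ID = 151669, 151670

# A toy combined placeholder for one video that carries its own audio track.
tokens = [AUDIO_START_ID, VIDEO_PAD_ID, VIDEO_PAD_ID, AUDIO_PAD_ID,
          VIDEO_PAD_ID, VIDEO_PAD_ID, AUDIO_PAD_ID, AUDIO_END_ID]

# The derived audio placeholder keeps the same span but marks only the
# audio pad positions for embedding, just like the is_embed mask above.
audio_is_embed = torch.tensor(tokens) == AUDIO_PAD_ID
print(audio_is_embed.tolist())
# [False, False, False, True, False, False, True, False]
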