# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock

import pytest
from transformers import PretrainedConfig

from vllm.multimodal.processing import InputProcessingContext

# Helper function to print input IDs with coalesced audio/video tokens.
def print_input_ids(input_ids):
    """
    Print input IDs, compressing consecutive special tokens.
    - 151675: <|audio_pad|>
    - 151656: <|video_pad|>
    """
    if not input_ids:
        print("[]")
        return

    result = []
    i = 0

    while i < len(input_ids):
        current_id = input_ids[i]

        # Check if it's a special token that should be compressed
        if current_id in [151675, 151656]:
            # Count consecutive occurrences
            count = 1
            while i + count < len(input_ids) and input_ids[i + count] == current_id:
                count += 1

            # Add compressed representation
            token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
            result.append(f"{token_name} * {count}")
            i += count
        else:
            # Regular token, just add it
            result.append(str(current_id))
            i += 1

    print(", ".join(result))

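# For example, print_input_ids([151669, 151675, 151675, 151675, 151670])
# prints: 151669, <|audio_pad|> * 3, 151670
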
@pytest.fixture
def mock_qwen3_omni_config():
    """Create a mock Qwen3OmniMoeThinker config."""
    config = Mock(spec=PretrainedConfig)
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    config.audio_token_id = 151675  # <|audio_pad|>
    config.video_token_id = 151656  # <|video_pad|>
    config.image_token_id = 151655  # <|image_pad|>
    config.audio_start_token_id = 151669  # <|audio_start|>
    config.audio_end_token_id = 151670  # <|audio_end|>
    config.vision_start_token_id = 151652  # <|vision_start|>
    config.position_id_per_seconds = 12.5

    # Vision config
    vision_config = Mock()
    vision_config.spatial_merge_size = 2
    config.vision_config = vision_config

    return config

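# With spatial_merge_size=2, a video_grid_thw of [t, h, w] should yield
# t * (h // 2) * (w // 2) <|video_pad|> tokens after merging,
# e.g. [6, 36, 64] -> 6 * 18 * 32 = 3456 (mirrors the test's expected count).
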
@pytest.fixture
def mock_processor():
    """Create a mock HF processor."""
    from transformers.models.whisper import WhisperFeatureExtractor

    processor = Mock()
    processor.audio_token = "<|audio_pad|>"
    processor.image_token = "<|image_pad|>"
    processor.video_token = "<|video_pad|>"

    # Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
    feature_extractor = WhisperFeatureExtractor()
    processor.feature_extractor = feature_extractor

    return processor

@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer."""
    tokenizer = Mock()
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    tokenizer.get_vocab = Mock(
        return_value={
            "<|audio_pad|>": 151675,
            "<|video_pad|>": 151656,
            "<|image_pad|>": 151655,
            "<|audio_start|>": 151669,
            "<|audio_end|>": 151670,
            "<|vision_start|>": 151652,
            "<|vision_end|>": 151653,
        }
    )
    tokenizer.encode = Mock(
        side_effect=lambda x: {
            "<|vision_start|>": [151652],
            "<|vision_end|>": [151653],
            "<|audio_start|>": [151669],
            "<|audio_end|>": [151670],
            "<|audio_pad|>": [151675],
            "<|image_pad|>": [151655],
            "<|video_pad|>": [151656],
        }.get(x, [0])
    )
    tokenizer.vision_bos_token = "<|vision_start|>"
    tokenizer.vision_eos_token = "<|vision_end|>"
    tokenizer.audio_bos_token = "<|audio_start|>"
    tokenizer.audio_eos_token = "<|audio_end|>"
    return tokenizer

@pytest.fixture
def mock_image_processor():
    """Create a mock image processor."""
    image_processor = Mock()
    image_processor.merge_size = 2
    return image_processor

def test_qwen3_omni_get_updates_use_audio_in_video(
    mock_qwen3_omni_config,
    mock_processor,
    mock_tokenizer,
    mock_image_processor,
):
    """Test the get_updates_use_audio_in_video method directly."""
    from vllm.model_executor.models.qwen3_omni_moe_thinker import (
        Qwen3OmniMoeThinkerMultiModalProcessor,
        Qwen3OmniMoeThinkerProcessingInfo,
    )

    # Create a mock context
    mock_ctx = Mock(spec=InputProcessingContext)

    # Create processing info
    info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
    info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
    info.get_hf_processor = Mock(return_value=mock_processor)
    info.get_tokenizer = Mock(return_value=mock_tokenizer)
    info.get_image_processor = Mock(return_value=mock_image_processor)

    # Create a mock dummy_inputs builder
    mock_dummy_inputs = Mock()

    # Create the processor
    processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)

    # Test parameters from reference video
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
    audio_len = 85
    video_grid_thw = [6, 36, 64]
    video_second_per_grid_t = 2.0

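    # With position_id_per_seconds=12.5 and video_second_per_grid_t=2.0, each
    # 2-second video chunk should be interleaved with 12.5 * 2.0 = 25 audio
    # tokens, so audio_len=85 splits as 25 + 25 + 25 + 10 across the first
    # four chunks (see the expected layout below).
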
    # Call the method
    updates = processor.get_updates_use_audio_in_video(
        thinker_config=mock_qwen3_omni_config,
        audio_len=audio_len,
        video_grid_thw=video_grid_thw,
        video_second_per_grid_t=video_second_per_grid_t,
    )

    # The updated input IDs should align with the HF implementation:
    # 151669,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 10,
    # <|video_pad|> * 1152,
    # 151670
    print_input_ids(updates)

    # Verify structure
    assert isinstance(updates, list)
    assert len(updates) > 0

    # Verify start and end tokens
    audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
    audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id

    assert updates[0] == audio_start_token_id
    assert updates[-1] == audio_end_token_id

    # Verify both audio and video tokens are present
    audio_token_id = mock_qwen3_omni_config.audio_token_id
    video_token_id = mock_qwen3_omni_config.video_token_id

    audio_count = updates.count(audio_token_id)
    video_count = updates.count(video_token_id)

    assert audio_count == audio_len, (
        f"Expected {audio_len} audio tokens, got {audio_count}"
    )

    # Calculate expected video token count
    spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
    height = video_grid_thw[1] // spatial_merge_size
    width = video_grid_thw[2] // spatial_merge_size
    expected_video_count = video_grid_thw[0] * height * width

    assert video_count == expected_video_count, (
        f"Expected {expected_video_count} video tokens, got {video_count}"
    )

    # Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
    expected_total = 1 + audio_len + expected_video_count + 1
    assert len(updates) == expected_total, (
        f"Expected {expected_total} total tokens, got {len(updates)}"
    )
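    # For this input: 1 + 85 + 6 * 18 * 32 + 1 = 3543 tokens in total.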


if __name__ == "__main__":
    pytest.main([__file__, "-v"])