# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import Mock

import pytest
from transformers import PretrainedConfig

from vllm.multimodal.processing import InputProcessingContext


# Helper function to print input IDs with coalesced audio/video tokens.
def print_input_ids(input_ids):
    """
    Print input IDs, compressing consecutive special tokens.

    - 151675: <|audio_pad|>
    - 151656: <|video_pad|>
    """
    if not input_ids:
        print("[]")
        return

    result = []
    i = 0
    while i < len(input_ids):
        current_id = input_ids[i]

        # Check if it's a special token that should be compressed
        if current_id in [151675, 151656]:
            # Count consecutive occurrences
            count = 1
            while i + count < len(input_ids) and input_ids[i + count] == current_id:
                count += 1

            # Add compressed representation
            token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
            result.append(f"{token_name} * {count}")
            i += count
        else:
            # Regular token, just add it
            result.append(str(current_id))
            i += 1

    print(", ".join(result))
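
# ---------------------------------------------------------------------------
# Illustrative sketch, NOT the vLLM/HF implementation, of how the expected
# interleaved layout in the test below can be derived. The chunking rule
# (one run of audio pads after each temporal grid of video pads, consuming
# position_id_per_seconds * video_second_per_grid_t = 12.5 * 2.0 = 25 audio
# tokens per chunk until the audio is exhausted) is an assumption inferred
# from the reference output, not taken from the model source. Token IDs
# match the fixtures in this file.
# ---------------------------------------------------------------------------
def sketch_interleaved_updates(
    audio_len: int,
    video_grid_thw: list[int],
    video_second_per_grid_t: float,
    position_id_per_seconds: float = 12.5,
    spatial_merge_size: int = 2,
) -> list[int]:
    """Build <|audio_start|> [video/audio chunks] <|audio_end|> token IDs."""
    audio_start_id, audio_end_id = 151669, 151670
    audio_pad_id, video_pad_id = 151675, 151656

    t, h, w = video_grid_thw
    # Each temporal grid yields (h / merge) * (w / merge) video pad tokens.
    tokens_per_grid = (h // spatial_merge_size) * (w // spatial_merge_size)
    # Assumed number of audio pads consumed per temporal grid.
    audio_per_chunk = int(position_id_per_seconds * video_second_per_grid_t)

    ids = [audio_start_id]
    remaining_audio = audio_len
    for _ in range(t):
        ids += [video_pad_id] * tokens_per_grid
        take = min(audio_per_chunk, remaining_audio)
        ids += [audio_pad_id] * take
        remaining_audio -= take
    ids.append(audio_end_id)
    return ids


# Sanity check: print_input_ids(sketch_interleaved_updates(85, [6, 36, 64], 2.0))
# reproduces the compressed layout in the test's expected-output comment:
# 25 + 25 + 25 + 10 audio pads interleaved, with the trailing audio-free video
# grids coalescing into a single run of 1152 video pads.
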
@pytest.fixture
def mock_qwen3_omni_config():
    """Create a mock Qwen3OmniMoeThinker config."""
    config = Mock(spec=PretrainedConfig)
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    config.audio_token_id = 151675  # <|audio_pad|>
    config.video_token_id = 151656  # <|video_pad|>
    config.image_token_id = 151655  # <|image_pad|>
    config.audio_start_token_id = 151669  # <|audio_start|>
    config.audio_end_token_id = 151670  # <|audio_end|>
    config.vision_start_token_id = 151652  # <|vision_start|>
    config.position_id_per_seconds = 12.5

    # Vision config
    vision_config = Mock()
    vision_config.spatial_merge_size = 2
    config.vision_config = vision_config

    return config


@pytest.fixture
def mock_processor():
    """Create a mock HF processor."""
    from transformers.models.whisper import WhisperFeatureExtractor

    processor = Mock()
    processor.audio_token = "<|audio_pad|>"
    processor.image_token = "<|image_pad|>"
    processor.video_token = "<|video_pad|>"

    # Create a real WhisperFeatureExtractor instance for the
    # feature_extractor attribute
    feature_extractor = WhisperFeatureExtractor()
    processor.feature_extractor = feature_extractor

    return processor


@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer."""
    tokenizer = Mock()
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    tokenizer.get_vocab = Mock(
        return_value={
            "<|audio_pad|>": 151675,
            "<|video_pad|>": 151656,
            "<|image_pad|>": 151655,
            "<|audio_start|>": 151669,
            "<|audio_end|>": 151670,
            "<|vision_start|>": 151652,
            "<|vision_end|>": 151653,
        }
    )
    tokenizer.encode = Mock(
        side_effect=lambda x: {
            "<|vision_start|>": [151652],
            "<|vision_end|>": [151653],
            "<|audio_start|>": [151669],
            "<|audio_end|>": [151670],
            "<|audio_pad|>": [151675],
            "<|image_pad|>": [151655],
            "<|video_pad|>": [151656],
        }.get(x, [0])
    )
    tokenizer.vision_bos_token = "<|vision_start|>"
    tokenizer.vision_eos_token = "<|vision_end|>"
    tokenizer.audio_bos_token = "<|audio_start|>"
    tokenizer.audio_eos_token = "<|audio_end|>"
    return tokenizer


@pytest.fixture
def mock_image_processor():
    """Create a mock image processor."""
    image_processor = Mock()
    image_processor.merge_size = 2
    return image_processor


def test_qwen3_omni_get_updates_use_audio_in_video(
    mock_qwen3_omni_config,
    mock_processor,
    mock_tokenizer,
    mock_image_processor,
):
    """Test the get_updates_use_audio_in_video method directly."""
    from vllm.model_executor.models.qwen3_omni_moe_thinker import (
        Qwen3OmniMoeThinkerMultiModalProcessor,
        Qwen3OmniMoeThinkerProcessingInfo,
    )

    # Create a mock context
    mock_ctx = Mock(spec=InputProcessingContext)

    # Create processing info
    info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
    info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
    info.get_hf_processor = Mock(return_value=mock_processor)
    info.get_tokenizer = Mock(return_value=mock_tokenizer)
    info.get_image_processor = Mock(return_value=mock_image_processor)

    # Create a mock dummy_inputs builder
    mock_dummy_inputs = Mock()

    # Create the processor
    processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)

    # Test parameters from reference video
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
    audio_len = 85
    video_grid_thw = [6, 36, 64]
    video_second_per_grid_t = 2.0

    # Call the method
    updates = processor.get_updates_use_audio_in_video(
        thinker_config=mock_qwen3_omni_config,
        audio_len=audio_len,
        video_grid_thw=video_grid_thw,
        video_second_per_grid_t=video_second_per_grid_t,
    )

    # Updated input IDs should align with the HF implementation:
    # 151669,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 10,
    # <|video_pad|> * 1152,
    # 151670
    print_input_ids(updates)

    # Verify structure
    assert isinstance(updates, list)
    assert len(updates) > 0

    # Verify start and end tokens
    audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
    audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id
    assert updates[0] == audio_start_token_id
    assert updates[-1] == audio_end_token_id

    # Verify both audio and video tokens are present
    audio_token_id = mock_qwen3_omni_config.audio_token_id
    video_token_id = mock_qwen3_omni_config.video_token_id
    audio_count = updates.count(audio_token_id)
    video_count = updates.count(video_token_id)

    assert audio_count == audio_len, (
        f"Expected {audio_len} audio tokens, got {audio_count}"
    )

    # Calculate expected video token count
    spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
    height = video_grid_thw[1] // spatial_merge_size
    width = video_grid_thw[2] // spatial_merge_size
    expected_video_count = video_grid_thw[0] * height * width

    assert video_count == expected_video_count, (
        f"Expected {expected_video_count} video tokens, got {video_count}"
    )

    # Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
    expected_total = 1 + audio_len + expected_video_count + 1
    assert len(updates) == expected_total, (
        f"Expected {expected_total} total tokens, got {len(updates)}"
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])