From 2eddd437ba5e7ce80d7341bf87a3078802b01ba7 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Date: Fri, 25 Jul 2025 17:07:26 -0700
Subject: [PATCH] Add interleaved RoPE test for Llama4 (Maverick) (#21478)

Signed-off-by: Yong Hoon Shin
---
 .../multimodal/generation/test_maverick.py | 92 +++++++++++++++----
 1 file changed, 73 insertions(+), 19 deletions(-)

diff --git a/tests/models/multimodal/generation/test_maverick.py b/tests/models/multimodal/generation/test_maverick.py
index 306cf39002df2..bacc9ef94f49d 100644
--- a/tests/models/multimodal/generation/test_maverick.py
+++ b/tests/models/multimodal/generation/test_maverick.py
@@ -22,6 +22,9 @@ from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
                           GenerationConfig)
 
 from vllm import LLM, SamplingParams
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec)
 
 from ....utils import multi_gpu_test
 
@@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
         raise
 
 
+def get_rope_layers_config(model_path: str) -> list[int]:
+    """
+    Get the interleaved RoPE configuration from the HuggingFace config.
+
+    Args:
+        model_path: Path to the local directory containing the reduced
+            Maverick model checkpoint
+
+    Returns:
+        List of 0s and 1s indicating whether each layer uses RoPE and
+        local attention; 0 means RoPE is not used, 1 means RoPE is used.
+    """
+    config_path = Path(model_path) / "config.json"
+    model_config = json.loads(config_path.read_text())
+    text_config = model_config["text_config"]
+    no_rope_layers = text_config["no_rope_layers"]
+    print(f"Found no_rope_layers: {no_rope_layers}")
+    return no_rope_layers
+
+
 def create_reduced_maverick_model(
     original_model_name:
     str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
@@ -113,7 +136,6 @@ def create_reduced_maverick_model(
     print("Loading original model configuration...")
     original_config = AutoConfig.from_pretrained(original_model_name,
                                                  trust_remote_code=True)
-
     print("Creating reduced configuration...")
     reduced_config = create_reduced_config(original_config, text_layers,
                                            num_experts, vision_layers)
@@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
           f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")
 
 
-def run_reduced_model(model_path: str,
-                      should_profile: bool = False,
-                      **kwargs) -> None:
-    """Test the created reduced model with vLLM."""
-
-    print(f"\nTesting reduced model at {model_path}...")
-
-    llm = LLM(
-        model=model_path,
-        trust_remote_code=True,
-        max_model_len=512,  # Small context for testing
-        gpu_memory_utilization=0.3,  # Conservative memory usage
-        **kwargs,
+def check_attention_spec_interleaved_rope(
+    llm: LLM,
+    num_attention_layers: int,
+    num_ranks: int,
+    rope_layers: list[int],
+):
+    """Check that each layer's KV cache spec matches its RoPE setting."""
+    assert isinstance(llm.llm_engine.model_executor, Executor)
+    kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
     )
+    for rank in range(num_ranks):
+        kv_cache_specs = kv_cache_specs_per_rank[rank]
+        assert len(kv_cache_specs.keys()) == num_attention_layers
+        for i in range(num_attention_layers):
+            if rope_layers[i] == 0:
+                expected_spec = FullAttentionSpec
+            else:
+                expected_spec = ChunkedLocalAttentionSpec
+            assert isinstance(
+                kv_cache_specs[
+                    f"language_model.model.layers.{i}.self_attn.attn"],
+                expected_spec)
+
+def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
+    """Test the created reduced model with vLLM."""
 
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=50)
@@ -551,6 +584,7 @@ def run_reduced_model(model_path: str,
 @pytest.mark.parametrize("tp,ep", [(2, True)])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_dummy_maverick(
+    monkeypatch,
     original_model_name: str,
     text_layers: int,
     num_experts: int,
@@ -562,6 +596,10 @@ def test_dummy_maverick(
     force_recreate: bool = True,
     profile: bool = False,
 ) -> None:
+    # Disabling multiprocessing lets us access the model executor directly
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
     model_path = create_reduced_maverick_model(
         original_model_name=original_model_name,
         output_dir=output_dir,
@@ -573,11 +611,27 @@ def test_dummy_maverick(
     print(f"\nReduced model created successfully at: {model_path}")
 
-    run_reduced_model(model_path=model_path,
-                      should_profile=profile,
-                      enforce_eager=enforce_eager,
-                      tensor_parallel_size=tp,
-                      enable_expert_parallel=ep)
+    rope_layers = get_rope_layers_config(model_path)
+
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=512,  # Small context for testing
+        gpu_memory_utilization=0.3,  # Conservative memory usage
+        enforce_eager=enforce_eager,
+        tensor_parallel_size=tp,
+        enable_expert_parallel=ep,
+    )
+
+    check_attention_spec_interleaved_rope(
+        llm,
+        text_layers,
+        tp,
+        rope_layers,
+    )
+
+    print(f"\nTesting reduced model at {model_path}...")
+    run_reduced_model(llm=llm, should_profile=profile)
 
 
 def main():