Add interleaved RoPE test for Llama4 (Maverick) (#21478)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin 2025-07-25 17:07:26 -07:00 committed by GitHub
parent 75d29cf4e1
commit 2eddd437ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,6 +22,9 @@ from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
GenerationConfig)
from vllm import LLM, SamplingParams
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec)
from ....utils import multi_gpu_test
@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
raise
def get_rope_layers_config(model_path: str) -> list[int]:
"""
Get the interleaved RoPE configuration from HuggingFace config
Args:
model_path: Path to the local directory containing the reduced
Maverick model checkpoint
Returns:
List of 0 or 1 indicating whether each layer uses RoPE and local attn
0 indicates that RoPE is not used while 1 indicates that RoPE is used.
"""
config_path = Path(model_path) / "config.json"
model_config = json.loads(config_path.read_text())
text_config = model_config["text_config"]
no_rope_layers = text_config["no_rope_layers"]
print(f"Found no_rope_layers: {no_rope_layers}")
return no_rope_layers
def create_reduced_maverick_model(
original_model_name:
str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
@ -113,7 +136,6 @@ def create_reduced_maverick_model(
print("Loading original model configuration...")
original_config = AutoConfig.from_pretrained(original_model_name,
trust_remote_code=True)
print("Creating reduced configuration...")
reduced_config = create_reduced_config(original_config, text_layers,
num_experts, vision_layers)
@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")
def run_reduced_model(model_path: str,
should_profile: bool = False,
**kwargs) -> None:
"""Test the created reduced model with vLLM."""
print(f"\nTesting reduced model at {model_path}...")
llm = LLM(
model=model_path,
trust_remote_code=True,
max_model_len=512, # Small context for testing
gpu_memory_utilization=0.3, # Conservative memory usage
**kwargs,
def check_attention_spec_interleaved_rope(
llm: LLM,
num_attention_layers: int,
num_ranks: int,
rope_layers: list[int],
):
"""Check that the attention spec is correct."""
assert isinstance(llm.llm_engine.model_executor, Executor)
kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
)
for rank in range(num_ranks):
kv_cache_specs = kv_cache_specs_per_rank[rank]
assert len(kv_cache_specs.keys()) == num_attention_layers
for i in range(num_attention_layers):
if rope_layers[i] == 0:
expected_spec = FullAttentionSpec
else:
expected_spec = ChunkedLocalAttentionSpec
assert isinstance(
kv_cache_specs[
f"language_model.model.layers.{i}.self_attn.attn"],
expected_spec)
def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
"""Test the created reduced model with vLLM."""
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=50)
@ -551,6 +584,7 @@ def run_reduced_model(model_path: str,
@pytest.mark.parametrize("tp,ep", [(2, True)])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dummy_maverick(
monkeypatch,
original_model_name: str,
text_layers: int,
num_experts: int,
@ -562,6 +596,10 @@ def test_dummy_maverick(
force_recreate: bool = True,
profile: bool = False,
) -> None:
# Disable multiprocessing allows us to access model executor from LLM engine
monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
model_path = create_reduced_maverick_model(
original_model_name=original_model_name,
output_dir=output_dir,
@ -573,11 +611,27 @@ def test_dummy_maverick(
print(f"\nReduced model created successfully at: {model_path}")
run_reduced_model(model_path=model_path,
should_profile=profile,
enforce_eager=enforce_eager,
tensor_parallel_size=tp,
enable_expert_parallel=ep)
rope_layers = get_rope_layers_config(model_path)
llm = LLM(
model=model_path,
trust_remote_code=True,
max_model_len=512, # Small context for testing
gpu_memory_utilization=0.3, # Conservative memory usage
enforce_eager=enforce_eager,
tensor_parallel_size=tp,
enable_expert_parallel=ep,
)
check_attention_spec_interleaved_rope(
llm,
text_layers,
tp,
rope_layers,
)
print(f"\nTesting reduced model at {model_path}...")
run_reduced_model(llm=llm, should_profile=profile)
def main():