mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:45:00 +08:00
Add interleaved RoPE test for Llama4 (Maverick) (#21478)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
parent
75d29cf4e1
commit
2eddd437ba
@ -22,6 +22,9 @@ from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
|
||||
GenerationConfig)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
FullAttentionSpec)
|
||||
|
||||
from ....utils import multi_gpu_test
|
||||
|
||||
@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
|
||||
raise
|
||||
|
||||
|
||||
def get_rope_layers_config(model_path: str) -> list[int]:
|
||||
"""
|
||||
Get the interleaved RoPE configuration from HuggingFace config
|
||||
|
||||
Args:
|
||||
model_path: Path to the local directory containing the reduced
|
||||
Maverick model checkpoint
|
||||
|
||||
Returns:
|
||||
List of 0 or 1 indicating whether each layer uses RoPE and local attn
|
||||
0 indicates that RoPE is not used while 1 indicates that RoPE is used.
|
||||
"""
|
||||
config_path = Path(model_path) / "config.json"
|
||||
model_config = json.loads(config_path.read_text())
|
||||
text_config = model_config["text_config"]
|
||||
no_rope_layers = text_config["no_rope_layers"]
|
||||
print(f"Found no_rope_layers: {no_rope_layers}")
|
||||
return no_rope_layers
|
||||
|
||||
|
||||
def create_reduced_maverick_model(
|
||||
original_model_name:
|
||||
str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
@ -113,7 +136,6 @@ def create_reduced_maverick_model(
|
||||
print("Loading original model configuration...")
|
||||
original_config = AutoConfig.from_pretrained(original_model_name,
|
||||
trust_remote_code=True)
|
||||
|
||||
print("Creating reduced configuration...")
|
||||
reduced_config = create_reduced_config(original_config, text_layers,
|
||||
num_experts, vision_layers)
|
||||
@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
|
||||
f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")
|
||||
|
||||
|
||||
def run_reduced_model(model_path: str,
|
||||
should_profile: bool = False,
|
||||
**kwargs) -> None:
|
||||
"""Test the created reduced model with vLLM."""
|
||||
|
||||
print(f"\nTesting reduced model at {model_path}...")
|
||||
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=512, # Small context for testing
|
||||
gpu_memory_utilization=0.3, # Conservative memory usage
|
||||
**kwargs,
|
||||
def check_attention_spec_interleaved_rope(
|
||||
llm: LLM,
|
||||
num_attention_layers: int,
|
||||
num_ranks: int,
|
||||
rope_layers: list[int],
|
||||
):
|
||||
"""Check that the attention spec is correct."""
|
||||
assert isinstance(llm.llm_engine.model_executor, Executor)
|
||||
kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
|
||||
)
|
||||
for rank in range(num_ranks):
|
||||
kv_cache_specs = kv_cache_specs_per_rank[rank]
|
||||
assert len(kv_cache_specs.keys()) == num_attention_layers
|
||||
for i in range(num_attention_layers):
|
||||
if rope_layers[i] == 0:
|
||||
expected_spec = FullAttentionSpec
|
||||
else:
|
||||
expected_spec = ChunkedLocalAttentionSpec
|
||||
assert isinstance(
|
||||
kv_cache_specs[
|
||||
f"language_model.model.layers.{i}.self_attn.attn"],
|
||||
expected_spec)
|
||||
|
||||
|
||||
def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
|
||||
"""Test the created reduced model with vLLM."""
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=50)
|
||||
@ -551,6 +584,7 @@ def run_reduced_model(model_path: str,
|
||||
@pytest.mark.parametrize("tp,ep", [(2, True)])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_dummy_maverick(
|
||||
monkeypatch,
|
||||
original_model_name: str,
|
||||
text_layers: int,
|
||||
num_experts: int,
|
||||
@ -562,6 +596,10 @@ def test_dummy_maverick(
|
||||
force_recreate: bool = True,
|
||||
profile: bool = False,
|
||||
) -> None:
|
||||
# Disable multiprocessing allows us to access model executor from LLM engine
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
model_path = create_reduced_maverick_model(
|
||||
original_model_name=original_model_name,
|
||||
output_dir=output_dir,
|
||||
@ -573,11 +611,27 @@ def test_dummy_maverick(
|
||||
|
||||
print(f"\nReduced model created successfully at: {model_path}")
|
||||
|
||||
run_reduced_model(model_path=model_path,
|
||||
should_profile=profile,
|
||||
enforce_eager=enforce_eager,
|
||||
tensor_parallel_size=tp,
|
||||
enable_expert_parallel=ep)
|
||||
rope_layers = get_rope_layers_config(model_path)
|
||||
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=512, # Small context for testing
|
||||
gpu_memory_utilization=0.3, # Conservative memory usage
|
||||
enforce_eager=enforce_eager,
|
||||
tensor_parallel_size=tp,
|
||||
enable_expert_parallel=ep,
|
||||
)
|
||||
|
||||
check_attention_spec_interleaved_rope(
|
||||
llm,
|
||||
text_layers,
|
||||
tp,
|
||||
rope_layers,
|
||||
)
|
||||
|
||||
print(f"\nTesting reduced model at {model_path}...")
|
||||
run_reduced_model(llm=llm, should_profile=profile)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user