Add interleaved RoPE test for Llama4 (Maverick) (#21478)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
parent 75d29cf4e1
commit 2eddd437ba
@@ -22,6 +22,9 @@ from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
                           GenerationConfig)
 
 from vllm import LLM, SamplingParams
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec)
 
 from ....utils import multi_gpu_test
 
@@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
         raise
 
 
+def get_rope_layers_config(model_path: str) -> list[int]:
+    """
+    Get the interleaved RoPE configuration from HuggingFace config
+
+    Args:
+        model_path: Path to the local directory containing the reduced
+            Maverick model checkpoint
+
+    Returns:
+        List of 0 or 1 indicating whether each layer uses RoPE and local attn
+        0 indicates that RoPE is not used while 1 indicates that RoPE is used.
+    """
+    config_path = Path(model_path) / "config.json"
+    model_config = json.loads(config_path.read_text())
+    text_config = model_config["text_config"]
+    no_rope_layers = text_config["no_rope_layers"]
+    print(f"Found no_rope_layers: {no_rope_layers}")
+    return no_rope_layers
+
+
 def create_reduced_maverick_model(
         original_model_name:
         str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
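For context, a minimal, self-contained sketch of what the new helper reads. The config contents below are hypothetical; the only assumption is a Llama-4-style config.json whose text_config carries a per-layer no_rope_layers list of 0/1 flags, as the docstring above describes.

import json
import tempfile
from pathlib import Path

# Hypothetical reduced config.json; only the field read by get_rope_layers_config
# is populated. Per the docstring above, 1 means the layer uses RoPE (and local
# attention), 0 means it does not.
dummy_config = {"text_config": {"no_rope_layers": [1, 1, 1, 0]}}

with tempfile.TemporaryDirectory() as model_dir:
    (Path(model_dir) / "config.json").write_text(json.dumps(dummy_config))

    # Same lookup path as the helper above.
    config_path = Path(model_dir) / "config.json"
    text_config = json.loads(config_path.read_text())["text_config"]
    assert text_config["no_rope_layers"] == [1, 1, 1, 0]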
@@ -113,7 +136,6 @@ def create_reduced_maverick_model(
     print("Loading original model configuration...")
     original_config = AutoConfig.from_pretrained(original_model_name,
                                                  trust_remote_code=True)
-
     print("Creating reduced configuration...")
     reduced_config = create_reduced_config(original_config, text_layers,
                                            num_experts, vision_layers)
@@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
           f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")
 
 
-def run_reduced_model(model_path: str,
-                      should_profile: bool = False,
-                      **kwargs) -> None:
-    """Test the created reduced model with vLLM."""
-
-    print(f"\nTesting reduced model at {model_path}...")
-
-    llm = LLM(
-        model=model_path,
-        trust_remote_code=True,
-        max_model_len=512,  # Small context for testing
-        gpu_memory_utilization=0.3,  # Conservative memory usage
-        **kwargs,
+def check_attention_spec_interleaved_rope(
+    llm: LLM,
+    num_attention_layers: int,
+    num_ranks: int,
+    rope_layers: list[int],
+):
+    """Check that the attention spec is correct."""
+    assert isinstance(llm.llm_engine.model_executor, Executor)
+    kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
     )
-
+    for rank in range(num_ranks):
+        kv_cache_specs = kv_cache_specs_per_rank[rank]
+        assert len(kv_cache_specs.keys()) == num_attention_layers
+        for i in range(num_attention_layers):
+            if rope_layers[i] == 0:
+                expected_spec = FullAttentionSpec
+            else:
+                expected_spec = ChunkedLocalAttentionSpec
+            assert isinstance(
+                kv_cache_specs[
+                    f"language_model.model.layers.{i}.self_attn.attn"],
+                expected_spec)
+
+
+def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
+    """Test the created reduced model with vLLM."""
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=50)
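The expectation the new check encodes can be stated without vLLM: layers whose no_rope_layers flag is 0 should be registered with FullAttentionSpec, while RoPE layers should use ChunkedLocalAttentionSpec. A dependency-free sketch of that mapping, using a hypothetical 4-layer rope_layers list:

# Hypothetical flags for a 4-layer reduced model; the last layer skips RoPE.
rope_layers = [1, 1, 1, 0]

expected_specs = [
    "FullAttentionSpec" if flag == 0 else "ChunkedLocalAttentionSpec"
    for flag in rope_layers
]
assert expected_specs == [
    "ChunkedLocalAttentionSpec",
    "ChunkedLocalAttentionSpec",
    "ChunkedLocalAttentionSpec",
    "FullAttentionSpec",
]

# In the test itself, the spec for layer i is looked up under the key
# f"language_model.model.layers.{i}.self_attn.attn" in the per-rank KV cache specs.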
@@ -551,6 +584,7 @@ def run_reduced_model(model_path: str,
 @pytest.mark.parametrize("tp,ep", [(2, True)])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_dummy_maverick(
+    monkeypatch,
     original_model_name: str,
     text_layers: int,
     num_experts: int,
@@ -562,6 +596,10 @@ def test_dummy_maverick(
     force_recreate: bool = True,
     profile: bool = False,
 ) -> None:
+    # Disable multiprocessing allows us to access model executor from LLM engine
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
     model_path = create_reduced_maverick_model(
         original_model_name=original_model_name,
         output_dir=output_dir,
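Outside of pytest (for example in main()), the same effect can be had by setting the environment variables before constructing the LLM; a sketch, assuming they are set early enough in the process:

import os

# Keep the V1 engine in-process so llm.llm_engine.model_executor is reachable
# from the calling process, as the monkeypatch lines above do for the pytest run.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"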
@@ -573,11 +611,27 @@ def test_dummy_maverick(
 
     print(f"\nReduced model created successfully at: {model_path}")
 
-    run_reduced_model(model_path=model_path,
-                      should_profile=profile,
-                      enforce_eager=enforce_eager,
-                      tensor_parallel_size=tp,
-                      enable_expert_parallel=ep)
+    rope_layers = get_rope_layers_config(model_path)
+
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=512,  # Small context for testing
+        gpu_memory_utilization=0.3,  # Conservative memory usage
+        enforce_eager=enforce_eager,
+        tensor_parallel_size=tp,
+        enable_expert_parallel=ep,
+    )
+
+    check_attention_spec_interleaved_rope(
+        llm,
+        text_layers,
+        tp,
+        rope_layers,
+    )
+
+    print(f"\nTesting reduced model at {model_path}...")
+    run_reduced_model(llm=llm, should_profile=profile)
 
 
 def main():
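For a local run, the new test can be selected by name; a sketch assuming at least two visible CUDA GPUs (the parametrization above uses tensor_parallel_size=2 with expert parallelism):

import pytest

# Run only the Maverick dummy-model test, quietly; exit with pytest's status code.
raise SystemExit(pytest.main(["-q", "-k", "test_dummy_maverick"]))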