speculative tests

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
Xingyu Liu 2025-12-09 23:18:29 -08:00
parent 65c6d2565d
commit 0cd72dc438

View File

@ -3,7 +3,7 @@
import json
from pathlib import Path
from vllm.config import ModelConfig, SpeculativeConfig, ParallelConfig
from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig
def test_basic():
@ -86,10 +86,22 @@ def test_basic():
def test_draft_models():
speculative_models = [
("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False),
("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True),
(
"luccafong/deepseek_mtp_main_random",
"luccafong/deepseek_mtp_draft_random",
True,
),
("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True),
("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True),
("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True),
(
"meta-llama/Meta-Llama-3-8B-Instruct",
"yuhuili/EAGLE-LLaMA3-Instruct-8B",
True,
),
(
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
True,
),
]
groundtruth_path = Path(__file__).parent / "draft_model_arch_groundtruth.json"
@ -108,7 +120,7 @@ def test_draft_models():
"target_parallel_config": ParallelConfig(),
}
speculative_config = SpeculativeConfig(**speculative_config)
speculative_config = SpeculativeConfig(**speculative_config)
model_config = speculative_config.draft_model_config
model_arch_config = model_config.model_arch_config
@ -145,7 +157,7 @@ def test_draft_models():
)
if isinstance(expected["head_size"], int):
# Before model_arch_config is introduced, get_head_size() for medusa
# Before model_arch_config is introduced, get_head_size() for medusa
# model config will throw out `integer division or modulo by zero` error.
assert model_arch_config.head_size == expected["head_size"]
assert model_config.get_head_size() == expected["head_size"]