diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 6419461b930ce..d7f12e1d5a6f8 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -3,7 +3,7 @@ import json from pathlib import Path -from vllm.config import ModelConfig, SpeculativeConfig, ParallelConfig +from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig def test_basic(): @@ -86,10 +86,22 @@ def test_basic(): def test_draft_models(): speculative_models = [ ("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False), - ("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True), + ( + "luccafong/deepseek_mtp_main_random", + "luccafong/deepseek_mtp_draft_random", + True, + ), ("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True), - ("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True), - ("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True), + ( + "meta-llama/Meta-Llama-3-8B-Instruct", + "yuhuili/EAGLE-LLaMA3-Instruct-8B", + True, + ), + ( + "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + True, + ), ] groundtruth_path = Path(__file__).parent / "draft_model_arch_groundtruth.json" @@ -108,7 +120,7 @@ def test_draft_models(): "target_parallel_config": ParallelConfig(), } - speculative_config = SpeculativeConfig(**speculative_config) + speculative_config = SpeculativeConfig(**speculative_config) model_config = speculative_config.draft_model_config model_arch_config = model_config.model_arch_config @@ -145,7 +157,7 @@ def test_draft_models(): ) if isinstance(expected["head_size"], int): - # Before model_arch_config is introduced, get_head_size() for medusa + # Before model_arch_config is introduced, get_head_size() for medusa # model config will throw out `integer division or modulo by zero` error. assert model_arch_config.head_size == expected["head_size"] assert model_config.get_head_size() == expected["head_size"]