mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 11:22:25 +08:00
speculative tests
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
parent
65c6d2565d
commit
0cd72dc438
@ -3,7 +3,7 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from vllm.config import ModelConfig, SpeculativeConfig, ParallelConfig
|
||||
from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig
|
||||
|
||||
|
||||
def test_basic():
|
||||
@ -86,10 +86,22 @@ def test_basic():
|
||||
def test_draft_models():
|
||||
speculative_models = [
|
||||
("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False),
|
||||
("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True),
|
||||
(
|
||||
"luccafong/deepseek_mtp_main_random",
|
||||
"luccafong/deepseek_mtp_draft_random",
|
||||
True,
|
||||
),
|
||||
("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True),
|
||||
("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True),
|
||||
("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True),
|
||||
(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
True,
|
||||
),
|
||||
]
|
||||
|
||||
groundtruth_path = Path(__file__).parent / "draft_model_arch_groundtruth.json"
|
||||
@ -108,7 +120,7 @@ def test_draft_models():
|
||||
"target_parallel_config": ParallelConfig(),
|
||||
}
|
||||
|
||||
speculative_config = SpeculativeConfig(**speculative_config)
|
||||
speculative_config = SpeculativeConfig(**speculative_config)
|
||||
model_config = speculative_config.draft_model_config
|
||||
|
||||
model_arch_config = model_config.model_arch_config
|
||||
@ -145,7 +157,7 @@ def test_draft_models():
|
||||
)
|
||||
|
||||
if isinstance(expected["head_size"], int):
|
||||
# Before model_arch_config is introduced, get_head_size() for medusa
|
||||
# Before model_arch_config was introduced, get_head_size() for the medusa
|
||||
# model config raised an `integer division or modulo by zero` error.
|
||||
assert model_arch_config.head_size == expected["head_size"]
|
||||
assert model_config.get_head_size() == expected["head_size"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user