diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/base_model_arch_groundtruth.json similarity index 100% rename from tests/config/model_arch_groundtruth.json rename to tests/config/base_model_arch_groundtruth.json diff --git a/tests/config/draft_model_arch_groundtruth.json b/tests/config/draft_model_arch_groundtruth.json new file mode 100644 index 0000000000000..dfe6f3d39e93b --- /dev/null +++ b/tests/config/draft_model_arch_groundtruth.json @@ -0,0 +1,87 @@ +{ + "abhigoyal/vllm-medusa-llama-68m-random": { + "architectures": [ + "MedusaModel" + ], + "model_type": "medusa", + "text_model_type": "medusa", + "hidden_size": 768, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 0, + "head_size": "Error: integer division or modulo by zero", + "vocab_size": 32000, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.float32" + }, + "luccafong/deepseek_mtp_draft_random": { + "architectures": [ + "DeepSeekMTPModel" + ], + "model_type": "deepseek_mtp", + "text_model_type": "deepseek_mtp", + "hidden_size": 2560, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 32, + "num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16" + }, + "eagle618/eagle-deepseek-v3-random": { + "architectures": [ + "EagleDeepSeekMTPModel" + ], + "model_type": "eagle", + "text_model_type": "deepseek_mtp", + "hidden_size": 2560, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 32, + "num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "bfloat16" + }, + "yuhuili/EAGLE-LLaMA3-Instruct-8B": { + "architectures": [ + "EagleLlamaForCausalLM" + ], + "model_type": "eagle", + "text_model_type": "llama", + "hidden_size": 4096, + "total_num_hidden_layers": 1, 
+ "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 128256, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "float16" + }, + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": { + "architectures": [ + "Eagle3LlamaForCausalLM" + ], + "model_type": "eagle", + "text_model_type": "llama", + "hidden_size": 4096, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 128256, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "float16" + } +} diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index b024b5ebec83e..6419461b930ce 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -3,10 +3,10 @@ import json from pathlib import Path -from vllm.config import ModelConfig +from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig -def test_model_arch_config(): +def test_basic(): trust_remote_code_models = [ "nvidia/Llama-3_3-Nemotron-Super-49B-v1", "XiaomiMiMo/MiMo-7B-RL", @@ -38,7 +38,7 @@ def test_model_arch_config(): "meta-llama/Llama-4-Scout-17B-16E-Instruct", ] + trust_remote_code_models - groundtruth_path = Path(__file__).parent / "model_arch_groundtruth.json" + groundtruth_path = Path(__file__).parent / "base_model_arch_groundtruth.json" with open(groundtruth_path) as f: model_arch_groundtruth = json.load(f) @@ -81,3 +81,71 @@ def test_model_arch_config(): model_config.get_total_num_hidden_layers() == expected["total_num_hidden_layers"] ) + + +def test_draft_models(): + speculative_models = [ + ("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False), + ("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True), + ("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True), + ("meta-llama/Meta-Llama-3-8B-Instruct", 
"yuhuili/EAGLE-LLaMA3-Instruct-8B", True), + ("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True), + ] + + groundtruth_path = Path(__file__).parent / "draft_model_arch_groundtruth.json" + with open(groundtruth_path) as f: + model_arch_groundtruth = json.load(f) + + for target_model, draft_model, trust_remote_code in speculative_models: + print(f"testing {target_model=} {draft_model=}") + target_model_config = ModelConfig( + target_model, trust_remote_code=trust_remote_code + ) + speculative_config = { + "model": draft_model, + "num_speculative_tokens": 1, + "target_model_config": target_model_config, + "target_parallel_config": ParallelConfig(), + } + + speculative_config = SpeculativeConfig(**speculative_config) + model_config = speculative_config.draft_model_config + + model_arch_config = model_config.model_arch_config + expected = model_arch_groundtruth[draft_model] + assert model_arch_config.architectures == expected["architectures"] + assert model_arch_config.model_type == expected["model_type"] + assert model_arch_config.text_model_type == expected["text_model_type"] + assert model_arch_config.hidden_size == expected["hidden_size"] + assert ( + model_arch_config.total_num_hidden_layers + == expected["total_num_hidden_layers"] + ) + assert ( + model_arch_config.total_num_attention_heads + == expected["total_num_attention_heads"] + ) + + assert model_arch_config.vocab_size == expected["vocab_size"] + assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] + assert model_arch_config.num_experts == expected["num_experts"] + assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] + dtype = model_arch_config.torch_dtype + assert str(dtype) == expected["dtype"] + + # Ensure model_config methods return expected values + assert model_config.architectures == expected["architectures"] + assert model_config.get_vocab_size() == expected["vocab_size"] + assert model_config.get_hidden_size() == 
expected["hidden_size"] + assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"] + assert model_config.get_num_experts() == expected["num_experts"] + assert ( + model_config.get_total_num_hidden_layers() + == expected["total_num_hidden_layers"] + ) + + if isinstance(expected["head_size"], int): + # Before model_arch_config was introduced, get_head_size() for a medusa + # model config would raise an `integer division or modulo by zero` error. + assert model_arch_config.head_size == expected["head_size"] + assert model_config.get_head_size() == expected["head_size"] diff --git a/vllm/config/model.py b/vllm/config/model.py index 370f5c9b11935..5ed08bdbe949f 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -716,7 +716,7 @@ class ModelConfig: convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( self.hf_config.model_type, ModelArchConfigConvertorBase ) - convertor = convertor_cls(self.hf_config) + convertor = convertor_cls(self.hf_config, self.hf_text_config) return convertor.convert(self.model, self.revision) @field_validator("tokenizer_mode", mode="after") diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index ed6ba0adb5e20..d785dce3d32e5 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -14,7 +14,6 @@ from vllm.config.model_arch import ( from vllm.config.utils import getattr_iter from vllm.logger import init_logger from vllm.transformers_utils.config import ( - get_hf_text_config, try_get_safetensors_metadata, ) from vllm.utils.torch_utils import common_broadcastable_dtype @@ -23,9 +22,9 @@ logger = init_logger(__name__) class ModelArchConfigConvertorBase: - def __init__(self, hf_config: PretrainedConfig): + def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig): self.hf_config = hf_config - self.hf_text_config = get_hf_text_config(hf_config) + 
self.hf_text_config = hf_text_config def get_num_hidden_layers(self) -> int: return getattr(self.hf_text_config, "num_hidden_layers", 0)