[Bug] [Spec Decode] Fix model_initialization test and mismatch in aux_hidden_layers (#24613)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Wenlong Wang 2025-09-10 21:23:18 -07:00 committed by GitHub
parent 55b823ba0f
commit 6c8deacd72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 11 deletions

View File

@@ -97,6 +97,12 @@ class _HfExamplesInfo:
max_num_seqs: Optional[int] = None
"""Maximum number of sequences to be processed in a single iteration."""
use_original_num_layers: bool = False
"""
If True, use the original number of layers from the model config
instead of minimal layers for testing.
"""
def check_transformers_version(
self,
*,
@@ -597,18 +603,21 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random",
speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501
trust_remote_code=True),
"EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
"EagleLlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B-Instruct", # noqa: E501
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
"Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
"Eagle3LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.1-8B-Instruct", # noqa: E501
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
"LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501
tokenizer="meta-llama/Llama-3.1-8B-Instruct",
use_original_num_layers=True,
max_model_len=10240),
"LlamaForCausalLMEagle3": _HfExamplesInfo("Qwen/Qwen3-8B", # noqa: E501
trust_remote_code=True,
speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
tokenizer="Qwen/Qwen3-8B"),
tokenizer="Qwen/Qwen3-8B",
use_original_num_layers=True),
"EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True,

View File

@@ -36,7 +36,10 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
hf_overrides_fn = partial(dummy_hf_overrides,
model_arch=model_arch,
exist_overrides=model_info.hf_overrides)
exist_overrides=model_info.hf_overrides,
use_original_num_layers=getattr(
model_info, 'use_original_num_layers',
False))
# Avoid calling model.forward()
def _initialize_kv_caches_v0(self) -> None:

View File

@@ -396,6 +396,7 @@ def dummy_hf_overrides(
*,
model_arch: str = "",
exist_overrides: Optional[dict[str, Any]] = None,
use_original_num_layers: bool = False,
) -> PretrainedConfig:
"""
Dummy HF overrides function used to create dummy model
@@ -412,10 +413,18 @@ def dummy_hf_overrides(
# we use three layers for Gemma-3n to check
# both normal layer and kv_shared_layer
num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration"
else 1)
if use_original_num_layers:
# Use the original number of layers from the config
num_layers = getattr(text_config, 'num_layers', 1)
num_hidden_layers = getattr(text_config, 'num_hidden_layers', 1)
else:
# Use minimal layers for testing
num_layers = 1
num_hidden_layers = (3 if model_arch
== "Gemma3nForConditionalGeneration" else 1)
text_config.update({
"num_layers": 1,
"num_layers": num_layers,
"num_hidden_layers": num_hidden_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,