[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3; re-enable test for LlamaForCausalLMEagle3 (#24392)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Wenlong Wang 2025-09-09 21:24:23 -07:00 committed by GitHub
parent 309d7aa401
commit 53b42f4102
7 changed files with 58 additions and 41 deletions

View File

@@ -602,11 +602,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                             trust_remote_code=True,
                             speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                             tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-    # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
-    #                     trust_remote_code=True,
-    #                     speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
-    #                     tokenizer="Qwen/Qwen3-8B"),
+    "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+                                              trust_remote_code=True,
+                                              speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+                                              tokenizer="Qwen/Qwen3-8B"),
     "EagleLlama4ForCausalLM": _HfExamplesInfo(
         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
         trust_remote_code=True,
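
A quick way to confirm the restored registry entry initializes again (a hedged
sketch; the -k filter string is illustrative and assumes the standard vLLM
test layout):

    import pytest

    # Re-run the previously skipped initialization test for the
    # restored "LlamaForCausalLMEagle3" registry entry.
    pytest.main([
        "tests/models/test_initialization.py",
        "-k", "LlamaForCausalLMEagle3",
    ])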

View File

@@ -125,11 +125,8 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()


-@pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
-    [
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-    # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
     (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
       "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
     (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
@@ -148,12 +145,8 @@ def test_ngram_correctness(
       "eagle618/eagle-deepseek-v3-random", 1), False),
     ],
     ids=[
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-        # "qwen3_eagle3",
-        "llama3_eagle",
-        "llama3_eagle3",
-        "llama4_eagle",
-        "llama4_eagle_mm",
+        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
+        "llama4_eagle", "llama4_eagle_mm",
         "deepseek_eagle"
     ])
 @pytest.mark.parametrize("attn_backend",
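
Each model_setup tuple reads as (method, target model, draft model,
tensor-parallel size); the field meaning is inferred from the surrounding
test. A minimal sketch of how the re-enabled qwen3_eagle3 case maps onto the
public API (num_speculative_tokens is an illustrative value):

    from vllm import LLM

    method, target_model, draft_model, tp_size = (
        "eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1)

    llm = LLM(
        model=target_model,
        speculative_config={
            "method": method,
            "model": draft_model,
            "num_speculative_tokens": 3,  # illustrative
        },
        tensor_parallel_size=tp_size,
    )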

View File

@@ -2191,9 +2191,14 @@ class SpeculativeConfig:
                 # Automatically detect the method
                 if self.method in ('eagle', 'eagle3'):
                     pass
-                elif "eagle-" in self.draft_model_config.model.lower() or \
-                        "eagle3-" in self.draft_model_config.model.lower():
+                # examples:
+                # yuhuili/EAGLE-LLaMA3-Instruct-8B
+                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
+                # AngelSlim/Qwen3-8B_eagle3
+                elif "eagle-" in self.draft_model_config.model.lower():
                     self.method = "eagle"
+                elif "eagle3" in self.draft_model_config.model.lower():
+                    self.method = "eagle3"
                 elif self.draft_model_config.hf_config.model_type == "medusa":
                     self.method = "medusa"
                 elif (self.draft_model_config.hf_config.model_type ==
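
The old branch required a trailing hyphen ("eagle-" or "eagle3-") and mapped
both spellings to "eagle", so a name like AngelSlim/Qwen3-8B_eagle3 matched
neither pattern. A standalone sketch of the corrected ordering
(detect_method is a hypothetical helper, not vLLM API):

    from typing import Optional

    def detect_method(draft_model: str) -> Optional[str]:
        name = draft_model.lower()
        if "eagle-" in name:    # e.g. yuhuili/EAGLE-LLaMA3-Instruct-8B
            return "eagle"
        if "eagle3" in name:    # hyphenated and suffix styles alike
            return "eagle3"
        return None

    assert detect_method("yuhuili/EAGLE-LLaMA3-Instruct-8B") == "eagle"
    assert detect_method("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B") == "eagle3"
    assert detect_method("AngelSlim/Qwen3-8B_eagle3") == "eagle3"

Note the order matters: "eagle-" is tested first, and eagle3 names never
contain that substring because a "3" follows "eagle".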

View File

@@ -171,7 +171,22 @@ class LlamaAttention(nn.Module):
         sliding_window = None
         if layer_types := getattr(config, "layer_types", None):
-            is_sliding = layer_types[layer_idx] == "sliding_attention"
+            # Fix for Eagle3 compatibility: draft models subtract the
+            # target layer count to get a draft-relative layer index
+            # starting from 0.
+            if hasattr(config, 'target_layer_count'):
+                # This is a draft model; make layer_idx relative to the
+                # draft's own layers.
+                effective_layer_idx = layer_idx - config.target_layer_count
+            else:
+                # This is a target model; use layer_idx directly.
+                effective_layer_idx = layer_idx
+            assert effective_layer_idx < len(layer_types), (
+                f"effective_layer_idx: {effective_layer_idx} is out of "
+                f"bounds for layer_types: {layer_types}")
+            is_sliding = layer_types[
+                effective_layer_idx] == "sliding_attention"
             if is_sliding:
                 sliding_window = config.sliding_window
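
The failure this guards against: the Eagle3 wrapper (next file) builds its
LlamaModel with start_layer_id equal to the target depth, so the draft
layer's global layer_idx points past the end of the draft config's
layer_types list. A standalone sketch of the index math with illustrative
numbers:

    # Illustrative numbers: a 36-layer target plus a one-layer Eagle3 draft.
    target_layer_count = 36
    layer_types = ["full_attention"]  # draft config covers only its own layer
    layer_idx = 36                    # global index assigned to the draft layer

    # Before the fix: layer_types[layer_idx] -> IndexError (out of range).
    # After the fix: draft models subtract the stored target layer count.
    effective_layer_idx = layer_idx - target_layer_count   # -> 0
    assert effective_layer_idx < len(layer_types)
    is_sliding = layer_types[effective_layer_idx] == "sliding_attention"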

View File

@@ -199,6 +199,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             speculative_config.draft_model_config.hf_config
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
+        # Store target layer count in draft config for proper
+        # layer_types indexing in draft models.
+        self.config.target_layer_count = target_layer_num
         self.model = LlamaModel(vllm_config=vllm_config,
                                 prefix="model",
                                 start_layer_id=target_layer_num)
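
Why the draft layer carries a global index at all: start_layer_id offsets
the layer numbering so the draft layer is named and indexed after the
target's layers (an assumption about intent, drawn from the visible
start_layer_id plumbing). A sketch with an illustrative 36-layer target:

    target_layer_num = 36              # illustrative target depth
    start_layer_id = target_layer_num

    # The single draft decoder layer is built with layer_idx == 36 and a
    # prefix such as "model.layers.36"; config.target_layer_count lets the
    # attention module map 36 back to draft-relative index 0.
    draft_layer_prefix = f"model.layers.{start_layer_id}"
    assert draft_layer_prefix == "model.layers.36"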

View File

@@ -277,8 +277,7 @@ _SPECULATIVE_DECODING_MODELS = {
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),

View File

@@ -46,6 +46,7 @@ class EAGLEConfig(PretrainedConfig):
         # Eagle model name should follow naming convention of
         # LlamaForCausalLM -> EagleLlamaForCausalLM
         # LlamaForCausalLM -> Eagle3LlamaForCausalLM
+        # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
         if method == "eagle":
             assert self.model is not None, \
                 "model should not be None when method is eagle"
@@ -53,6 +54,7 @@ class EAGLEConfig(PretrainedConfig):
                 f"Eagle{arch}" if not arch.startswith("Eagle") \
                     else arch for arch in self.model.architectures
             ]
         elif method == "eagle3":
             assert self.model is not None, \
                 "model should not be None when method is eagle3"