[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3; re-enable test for LlamaForCausalLMEagle3 (#24392)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
Wenlong Wang 2025-09-09 21:24:23 -07:00 committed by GitHub
parent 309d7aa401
commit 53b42f4102
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 58 additions and 41 deletions

View File

@ -602,11 +602,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"), tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
# "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501 trust_remote_code=True,
# trust_remote_code=True, speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
# speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 tokenizer="Qwen/Qwen3-8B"),
# tokenizer="Qwen/Qwen3-8B"),
"EagleLlama4ForCausalLM": _HfExamplesInfo( "EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True, trust_remote_code=True,

View File

@ -125,37 +125,30 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.parametrize( @pytest.mark.parametrize(["model_setup", "mm_enabled"], [
["model_setup", "mm_enabled"], (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
[ (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
(("eagle", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), pytest.param(
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct", ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
pytest.param( False,
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), pytest.param(
False, ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
pytest.param( True,
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), (("eagle", "eagle618/deepseek-v3-random",
True, "eagle618/eagle-deepseek-v3-random", 1), False),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), ],
(("eagle", "eagle618/deepseek-v3-random", ids=[
"eagle618/eagle-deepseek-v3-random", 1), False), "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
], "llama4_eagle", "llama4_eagle_mm",
ids=[ "deepseek_eagle"
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 ])
# "qwen3_eagle3",
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
"llama4_eagle_mm",
"deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend", @pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform()) get_attn_backend_list_based_on_platform())
def test_eagle_correctness( def test_eagle_correctness(

View File

@ -2191,9 +2191,14 @@ class SpeculativeConfig:
# Automatically detect the method # Automatically detect the method
if self.method in ('eagle', 'eagle3'): if self.method in ('eagle', 'eagle3'):
pass pass
elif "eagle-" in self.draft_model_config.model.lower() or \ # examples:
"eagle3-" in self.draft_model_config.model.lower(): # yuhuili/EAGLE-LLaMA3-Instruct-8B
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
# AngelSlim/Qwen3-8B_eagle3
elif "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle" self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3"
elif self.draft_model_config.hf_config.model_type == "medusa": elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa" self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type == elif (self.draft_model_config.hf_config.model_type ==

View File

@ -171,7 +171,22 @@ class LlamaAttention(nn.Module):
sliding_window = None sliding_window = None
if layer_types := getattr(config, "layer_types", None): if layer_types := getattr(config, "layer_types", None):
is_sliding = layer_types[layer_idx] == "sliding_attention" # Fix for Eagle3 compatibility:
# for draft models, subtract target layer count
# to get draft-relative layer index starting from 0
if hasattr(config, 'target_layer_count'):
# This is a draft model,
# adjust layer_idx to be relative to draft layers
effective_layer_idx = layer_idx - config.target_layer_count
else:
# This is a target model, use layer_idx directly
effective_layer_idx = layer_idx
assert effective_layer_idx < len(layer_types), \
f"effective_layer_idx: {effective_layer_idx} \
is out of bounds for layer_types: {layer_types}"
is_sliding = layer_types[
effective_layer_idx] == "sliding_attention"
if is_sliding: if is_sliding:
sliding_window = config.sliding_window sliding_window = config.sliding_window

View File

@ -199,6 +199,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
speculative_config.draft_model_config.hf_config speculative_config.draft_model_config.hf_config
target_layer_num = vllm_config.model_config.get_num_layers( target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config) vllm_config.parallel_config)
# Store target layer count in draft config for
# proper layer_types indexing in draft models
self.config.target_layer_count = target_layer_num
self.model = LlamaModel(vllm_config=vllm_config, self.model = LlamaModel(vllm_config=vllm_config,
prefix="model", prefix="model",
start_layer_id=target_layer_num) start_layer_id=target_layer_num)

View File

@ -277,8 +277,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
# "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),

View File

@ -46,6 +46,7 @@ class EAGLEConfig(PretrainedConfig):
# Eagle model name should follow naming convention of # Eagle model name should follow naming convention of
# LlamaForCausalLM -> EagleLlamaForCausalLM # LlamaForCausalLM -> EagleLlamaForCausalLM
# LlamaForCausalLM -> Eagle3LlamaForCausalLM # LlamaForCausalLM -> Eagle3LlamaForCausalLM
# LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
if method == "eagle": if method == "eagle":
assert self.model is not None, \ assert self.model is not None, \
"model should not be None when method is eagle" "model should not be None when method is eagle"
@ -53,6 +54,7 @@ class EAGLEConfig(PretrainedConfig):
f"Eagle{arch}" if not arch.startswith("Eagle") \ f"Eagle{arch}" if not arch.startswith("Eagle") \
else arch for arch in self.model.architectures else arch for arch in self.model.architectures
] ]
elif method == "eagle3": elif method == "eagle3":
assert self.model is not None, \ assert self.model is not None, \
"model should not be None when method is eagle3" "model should not be None when method is eagle3"