From 53b42f4102dadd77e16a439b282addab19e2a625 Mon Sep 17 00:00:00 2001
From: Wenlong Wang
Date: Tue, 9 Sep 2025 21:24:23 -0700
Subject: [PATCH] [BugFix][Spec Decode] Fix out-of-range index triggered by
 eagle3; re-enable test for LlamaForCausalLMEagle3 (#24392)

Signed-off-by: wwl2755
---
 tests/models/registry.py                   |  9 ++--
 tests/v1/e2e/test_spec_decode.py           | 55 ++++++++++------------
 vllm/config/__init__.py                    |  9 +++-
 vllm/model_executor/models/llama.py        | 17 ++++++-
 vllm/model_executor/models/llama_eagle3.py |  4 ++
 vllm/model_executor/models/registry.py     |  3 +-
 vllm/transformers_utils/configs/eagle.py   |  2 +
 7 files changed, 58 insertions(+), 41 deletions(-)

diff --git a/tests/models/registry.py b/tests/models/registry.py
index 755a37b109d7..a5e83bc11f14 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -602,11 +602,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
         trust_remote_code=True,
         speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
         tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
-    # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3",  # noqa: E501
-    #                                           trust_remote_code=True,
-    #                                           speculative_model="AngelSlim/Qwen3-8B_eagle3",  # noqa: E501
-    #                                           tokenizer="Qwen/Qwen3-8B"),
+    "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3",  # noqa: E501
+                                              trust_remote_code=True,
+                                              speculative_model="AngelSlim/Qwen3-8B_eagle3",  # noqa: E501
+                                              tokenizer="Qwen/Qwen3-8B"),
     "EagleLlama4ForCausalLM": _HfExamplesInfo(
         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
         trust_remote_code=True,
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 6848f204358c..469464f777fb 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -125,37 +125,30 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
-    [
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
-        # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            False,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            True,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        (("eagle", "eagle618/deepseek-v3-random",
-          "eagle618/eagle-deepseek-v3-random", 1), False),
-    ],
-    ids=[
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
-        # "qwen3_eagle3",
-        "llama3_eagle",
-        "llama3_eagle3",
-        "llama4_eagle",
-        "llama4_eagle_mm",
-        "deepseek_eagle"
-    ])
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + True, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), +], + ids=[ + "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", + "llama4_eagle", "llama4_eagle_mm", + "deepseek_eagle" + ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 4f4673ac6e67..651cd7339101 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2191,9 +2191,14 @@ class SpeculativeConfig: # Automatically detect the method if self.method in ('eagle', 'eagle3'): pass - elif "eagle-" in self.draft_model_config.model.lower() or \ - "eagle3-" in self.draft_model_config.model.lower(): + # examples: + # yuhuili/EAGLE-LLaMA3-Instruct-8B + # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B + # AngelSlim/Qwen3-8B_eagle3 + elif "eagle-" in self.draft_model_config.model.lower(): self.method = "eagle" + elif "eagle3" in self.draft_model_config.model.lower(): + self.method = "eagle3" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" elif (self.draft_model_config.hf_config.model_type == diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a22bde194f5d..96530562b072 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -171,7 +171,22 @@ class LlamaAttention(nn.Module): sliding_window = None if layer_types := getattr(config, "layer_types", None): - is_sliding = layer_types[layer_idx] == "sliding_attention" + # Fix for Eagle3 compatibility: + # for draft models, subtract target layer count + # to get draft-relative layer index starting from 0 + if hasattr(config, 'target_layer_count'): + # This is a draft model, + # adjust layer_idx to be relative to draft layers + effective_layer_idx = layer_idx - config.target_layer_count + else: + # This is a target model, use layer_idx directly + effective_layer_idx = layer_idx + assert effective_layer_idx < len(layer_types), \ + f"effective_layer_idx: {effective_layer_idx} \ + is out of bounds for layer_types: {layer_types}" + + is_sliding = layer_types[ + effective_layer_idx] == "sliding_attention" if is_sliding: sliding_window = config.sliding_window diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 572930c39a84..bceb6cc42768 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -199,6 +199,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): speculative_config.draft_model_config.hf_config target_layer_num = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) + + # Store target layer count in draft config for + # proper layer_types indexing in draft models + self.config.target_layer_count = target_layer_num self.model = LlamaModel(vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 43075956b450..ed7797ce9d4f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -277,8 +277,7 @@ 
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
-    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 6aabf9e5262e..444ed70de3d0 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -46,6 +46,7 @@ class EAGLEConfig(PretrainedConfig):
         # Eagle model name should follow naming convention of
         # LlamaForCausalLM -> EagleLlamaForCausalLM
         # LlamaForCausalLM -> Eagle3LlamaForCausalLM
+        # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
         if method == "eagle":
             assert self.model is not None, \
                 "model should not be None when method is eagle"
@@ -53,6 +54,7 @@ class EAGLEConfig(PretrainedConfig):
                 f"Eagle{arch}" if not arch.startswith("Eagle") \
                 else arch for arch in self.model.architectures
             ]
+
         elif method == "eagle3":
             assert self.model is not None, \
                 "model should not be None when method is eagle3"
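
Note on the root cause, with a minimal sketch (not part of the patch itself):
Eagle3 draft layers are constructed with start_layer_id equal to the target
model's layer count, so LlamaAttention receives a layer_idx that overruns the
draft config's short layer_types list; storing target_layer_count on the draft
config and subtracting it restores a zero-based index. In the sketch below,
DraftConfig and TARGET_LAYERS are hypothetical stand-ins for the real HF
config objects, assuming a one-layer draft on top of a 36-layer target.

TARGET_LAYERS = 36  # assumed layer count of the target model


class DraftConfig:
    """Hypothetical stand-in for the draft model's HF config."""
    layer_types = ["full_attention"]    # covers only the draft's own layer
    target_layer_count = TARGET_LAYERS  # stored by Eagle3LlamaForCausalLM


def resolve_layer_type(config, layer_idx):
    # Draft layers are created with start_layer_id == target layer count,
    # so layer_idx arrives here as 36; without the subtraction,
    # layer_types[36] on the 1-element list above is the IndexError
    # this patch fixes.
    if hasattr(config, "target_layer_count"):
        effective_layer_idx = layer_idx - config.target_layer_count
    else:
        effective_layer_idx = layer_idx
    assert effective_layer_idx < len(config.layer_types)
    return config.layer_types[effective_layer_idx]


print(resolve_layer_type(DraftConfig(), TARGET_LAYERS))  # -> full_attention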