[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3; re-enable test for LlamaForCausalLMEagle3 (#24392)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
parent 309d7aa401
commit 53b42f4102
@@ -602,11 +602,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                                          trust_remote_code=True,
                                          speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                                          tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-    # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
-    #                                           trust_remote_code=True,
-    #                                           speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
-    #                                           tokenizer="Qwen/Qwen3-8B"),
+    "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+                                              trust_remote_code=True,
+                                              speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+                                              tokenizer="Qwen/Qwen3-8B"),
     "EagleLlama4ForCausalLM": _HfExamplesInfo(
         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
         trust_remote_code=True,
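The registry entry above is plain test metadata. For orientation, a rough Python stand-in for the fields this diff exercises (the real _HfExamplesInfo in vLLM's test registry carries more options; this dataclass is illustrative only):

from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in, not the real _HfExamplesInfo.
@dataclass
class ExampleInfo:
    default: str                             # HF repo of the example model
    trust_remote_code: bool = False
    speculative_model: Optional[str] = None  # draft model for spec decode
    tokenizer: Optional[str] = None          # target-model tokenizer

entry = ExampleInfo("AngelSlim/Qwen3-8B_eagle3",
                    trust_remote_code=True,
                    speculative_model="AngelSlim/Qwen3-8B_eagle3",
                    tokenizer="Qwen/Qwen3-8B")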
@@ -125,37 +125,30 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
-    [
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-        # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            False,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            True,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        (("eagle", "eagle618/deepseek-v3-random",
-          "eagle618/eagle-deepseek-v3-random", 1), False),
-    ],
-    ids=[
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-        # "qwen3_eagle3",
-        "llama3_eagle",
-        "llama3_eagle3",
-        "llama4_eagle",
-        "llama4_eagle_mm",
-        "deepseek_eagle"
-    ])
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        False,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        True,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    (("eagle", "eagle618/deepseek-v3-random",
+      "eagle618/eagle-deepseek-v3-random", 1), False),
+],
+                         ids=[
+                             "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
+                             "llama4_eagle", "llama4_eagle_mm",
+                             "deepseek_eagle"
+                         ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
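One detail worth calling out in the re-enabled parametrization: when ids is given as a list, pytest requires exactly one id per param set, so the "qwen3_eagle3" id has to be restored together with its tuple. A minimal self-contained illustration:

import pytest

# ids must line up one-to-one with the param sets; dropping a param
# without dropping its id (or vice versa) fails at collection time.
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
], ids=["qwen3_eagle3"])
def test_id_param_pairing(model_setup, mm_enabled):
    method, target, draft, tp_size = model_setup
    assert method == "eagle3"
    assert not mm_enabled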
@@ -2191,9 +2191,14 @@ class SpeculativeConfig:
                 # Automatically detect the method
                 if self.method in ('eagle', 'eagle3'):
                     pass
-                elif "eagle-" in self.draft_model_config.model.lower() or \
-                        "eagle3-" in self.draft_model_config.model.lower():
+                # examples:
+                # yuhuili/EAGLE-LLaMA3-Instruct-8B
+                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
+                # AngelSlim/Qwen3-8B_eagle3
+                elif "eagle-" in self.draft_model_config.model.lower():
                     self.method = "eagle"
+                elif "eagle3" in self.draft_model_config.model.lower():
+                    self.method = "eagle3"
                 elif self.draft_model_config.hf_config.model_type == "medusa":
                     self.method = "medusa"
                 elif (self.draft_model_config.hf_config.model_type ==
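To see the new detection order in isolation: an explicit method wins; the hyphenated "eagle-" check runs before the looser "eagle3" substring check, so EAGLE checkpoints still resolve to eagle while underscore-style names such as AngelSlim/Qwen3-8B_eagle3, which previously matched neither pattern, now resolve to eagle3. A minimal standalone sketch (the helper is hypothetical, not vLLM API):

from typing import Optional

def detect_method(model_name: str,
                  method: Optional[str] = None) -> Optional[str]:
    if method in ('eagle', 'eagle3'):
        return method
    name = model_name.lower()
    if "eagle-" in name:   # e.g. yuhuili/EAGLE-LLaMA3-Instruct-8B
        return "eagle"
    if "eagle3" in name:   # e.g. yuhuili/EAGLE3-LLaMA3.1-Instruct-8B,
        return "eagle3"    #      AngelSlim/Qwen3-8B_eagle3
    return None            # fall through to the medusa/MTP checks

assert detect_method("yuhuili/EAGLE-LLaMA3-Instruct-8B") == "eagle"
assert detect_method("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B") == "eagle3"
# The underscore-style name that previously fell through now resolves:
assert detect_method("AngelSlim/Qwen3-8B_eagle3") == "eagle3"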
@@ -171,7 +171,22 @@ class LlamaAttention(nn.Module):
 
         sliding_window = None
         if layer_types := getattr(config, "layer_types", None):
-            is_sliding = layer_types[layer_idx] == "sliding_attention"
+            # Fix for Eagle3 compatibility:
+            # for draft models, subtract target layer count
+            # to get draft-relative layer index starting from 0
+            if hasattr(config, 'target_layer_count'):
+                # This is a draft model,
+                # adjust layer_idx to be relative to draft layers
+                effective_layer_idx = layer_idx - config.target_layer_count
+            else:
+                # This is a target model, use layer_idx directly
+                effective_layer_idx = layer_idx
+            assert effective_layer_idx < len(layer_types), \
+                f"effective_layer_idx: {effective_layer_idx} \
+                is out of bounds for layer_types: {layer_types}"
+
+            is_sliding = layer_types[
+                effective_layer_idx] == "sliding_attention"
             if is_sliding:
                 sliding_window = config.sliding_window
 
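The failure mode this guards against: Eagle3 draft layers are numbered globally, continuing after the target model's layers, while the draft config's layer_types list only covers the draft layers. A self-contained sketch of the bug and the fix, using illustrative numbers (32 target layers, one draft layer):

from types import SimpleNamespace

# Hypothetical numbers: a 32-layer target whose Eagle3 draft head has a
# single decoder layer, numbered globally as layer 32.
draft_config = SimpleNamespace(
    layer_types=["full_attention"],  # one entry: one draft layer
    target_layer_count=32,           # stored by Eagle3LlamaForCausalLM
)
layer_idx = 32  # global index = start_layer_id = target layer count

layer_types = draft_config.layer_types
# Old code: layer_types[32] -> IndexError on a 1-element list.
# New code: shift into the draft-relative range first.
if hasattr(draft_config, "target_layer_count"):
    effective_layer_idx = layer_idx - draft_config.target_layer_count
else:
    effective_layer_idx = layer_idx
assert 0 <= effective_layer_idx < len(layer_types)
is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
print(is_sliding)  # False: the draft layer uses full attention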
@@ -199,6 +199,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             speculative_config.draft_model_config.hf_config
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
+
+        # Store target layer count in draft config for
+        # proper layer_types indexing in draft models
+        self.config.target_layer_count = target_layer_num
         self.model = LlamaModel(vllm_config=vllm_config,
                                 prefix="model",
                                 start_layer_id=target_layer_num)
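These two hunks cooperate: the draft model is constructed with start_layer_id=target_layer_num, so its layers carry global indices, and the same count is stashed on the draft config so LlamaAttention can undo the offset. A toy sketch of that wiring (all class names hypothetical):

# Toy sketch of the wiring between the two hunks (names hypothetical).
class ToyAttention:
    def __init__(self, config, layer_idx):
        offset = getattr(config, "target_layer_count", 0)
        self.effective_layer_idx = layer_idx - offset

class ToyDraftModel:
    def __init__(self, config, start_layer_id):
        # Draft layers are numbered globally, after the target's layers.
        self.layers = [ToyAttention(config, start_layer_id + i)
                       for i in range(len(config.layer_types))]

class ToyConfig:
    layer_types = ["full_attention"]

config = ToyConfig()
target_layer_num = 32
config.target_layer_count = target_layer_num  # what the hunk above adds
model = ToyDraftModel(config, start_layer_id=target_layer_num)
assert model.layers[0].effective_layer_idx == 0  # back in range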
@@ -277,8 +277,7 @@ _SPECULATIVE_DECODING_MODELS = {
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
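With this change, both architecture strings resolve to the same module/class pair, so checkpoints using either naming convention load the same Eagle3 implementation. Schematically, the lookup this table drives (a sketch, not the actual vLLM loader):

import importlib

# Sketch only: architecture string -> (module name, class name), with
# modules assumed to live under vllm.model_executor.models.
_SPECULATIVE_DECODING_MODELS = {
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
}

def resolve(architecture: str):
    module_name, class_name = _SPECULATIVE_DECODING_MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)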
@@ -46,6 +46,7 @@ class EAGLEConfig(PretrainedConfig):
         # Eagle model name should follow naming convention of
         # LlamaForCausalLM -> EagleLlamaForCausalLM
         # LlamaForCausalLM -> Eagle3LlamaForCausalLM
+        # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
         if method == "eagle":
             assert self.model is not None, \
                 "model should not be None when method is eagle"
@@ -53,6 +54,7 @@ class EAGLEConfig(PretrainedConfig):
                 f"Eagle{arch}" if not arch.startswith("Eagle") \
                     else arch for arch in self.model.architectures
             ]
+
         elif method == "eagle3":
             assert self.model is not None, \
                 "model should not be None when method is eagle3"
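As a quick check of the rewrite rule in the eagle branch above: architectures already prefixed with Eagle pass through unchanged, others gain the prefix; per the new naming-convention comment, the eagle3 branch instead leaves names like LlamaForCausalLMEagle3 untouched. The comprehension, extracted and exercised on sample inputs:

# The eagle-branch comprehension from the hunk above, applied to
# illustrative architecture lists.
def eagle_architectures(architectures):
    return [
        f"Eagle{arch}" if not arch.startswith("Eagle") else arch
        for arch in architectures
    ]

assert eagle_architectures(["LlamaForCausalLM"]) == \
    ["EagleLlamaForCausalLM"]
assert eagle_architectures(["EagleLlamaForCausalLM"]) == \
    ["EagleLlamaForCausalLM"]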