Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 11:06:08 +08:00)

commit 53b42f4102 (parent 309d7aa401)

[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3; re-enable test for LlamaForCausalLMEagle3 (#24392)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@@ -602,11 +602,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                                               trust_remote_code=True,
                                               speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                                               tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-    # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
-    #                         trust_remote_code=True,
-    #                         speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
-    #                         tokenizer="Qwen/Qwen3-8B"),
+    "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+                                              trust_remote_code=True,
+                                              speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+                                              tokenizer="Qwen/Qwen3-8B"),
     "EagleLlama4ForCausalLM": _HfExamplesInfo(
         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
         trust_remote_code=True,
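Note (not part of the diff): the re-enabled entry pairs the AngelSlim/Qwen3-8B_eagle3 draft with the Qwen/Qwen3-8B tokenizer, mirroring how the checkpoint would be used for speculative decoding at serving time. A minimal offline sketch of that pairing follows; the speculative_config keys shown ("method", "model", "num_speculative_tokens") are assumed from current vLLM usage rather than taken from this change.

# Hedged sketch: exercising the re-enabled eagle3 example pairing offline.
# The speculative_config keys are assumptions, not defined by this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-8B",
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-8B_eagle3",
        "num_speculative_tokens": 3,
    },
)
out = llm.generate(["Speculative decoding lets the draft model"],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)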
@@ -125,37 +125,30 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
-    [
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-        # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            False,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            True,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        (("eagle", "eagle618/deepseek-v3-random",
-          "eagle618/eagle-deepseek-v3-random", 1), False),
-    ],
-    ids=[
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-        # "qwen3_eagle3",
-        "llama3_eagle",
-        "llama3_eagle3",
-        "llama4_eagle",
-        "llama4_eagle_mm",
-        "deepseek_eagle"
-    ])
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        False,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        True,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    (("eagle", "eagle618/deepseek-v3-random",
+      "eagle618/eagle-deepseek-v3-random", 1), False),
+],
+                         ids=[
+                             "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
+                             "llama4_eagle", "llama4_eagle_mm",
+                             "deepseek_eagle"
+                         ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
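Note (not part of the diff): each case bundles a (method, target model, draft model, tensor-parallel size) tuple with an mm_enabled flag, and the restored first case carries the id "qwen3_eagle3", so it can be selected with pytest -k qwen3_eagle3. A small sketch of how such a tuple unpacks follows; the variable names are illustrative, since the test body is not shown in this hunk.

# Illustrative unpacking of the restored parametrized case (names assumed).
model_setup = ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1)
mm_enabled = False

method, target_model, draft_model, tp_size = model_setup
assert method in ("eagle", "eagle3")
print(f"{method}: target={target_model}, draft={draft_model}, "
      f"tp={tp_size}, multimodal={mm_enabled}")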
@@ -2191,9 +2191,14 @@ class SpeculativeConfig:
             # Automatically detect the method
             if self.method in ('eagle', 'eagle3'):
                 pass
-            elif "eagle-" in self.draft_model_config.model.lower() or \
-                    "eagle3-" in self.draft_model_config.model.lower():
+            # examples:
+            # yuhuili/EAGLE-LLaMA3-Instruct-8B
+            # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
+            # AngelSlim/Qwen3-8B_eagle3
+            elif "eagle-" in self.draft_model_config.model.lower():
                 self.method = "eagle"
+            elif "eagle3" in self.draft_model_config.model.lower():
+                self.method = "eagle3"
             elif self.draft_model_config.hf_config.model_type == "medusa":
                 self.method = "medusa"
             elif (self.draft_model_config.hf_config.model_type ==
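Note (not part of the diff): the previous check only matched the literal substrings "eagle-" or "eagle3-", so a draft name such as AngelSlim/Qwen3-8B_eagle3 (where "eagle3" appears as a suffix) fell through, and names that did match were always labeled "eagle". The new branch order, restated standalone below, checks "eagle-" first and then "eagle3" anywhere in the name.

# Standalone restatement of the new auto-detection order from the hunk above.
def detect_method(draft_model_name: str) -> str | None:
    name = draft_model_name.lower()
    if "eagle-" in name:        # e.g. yuhuili/EAGLE-LLaMA3-Instruct-8B
        return "eagle"
    if "eagle3" in name:        # e.g. AngelSlim/Qwen3-8B_eagle3
        return "eagle3"
    return None

assert detect_method("yuhuili/EAGLE-LLaMA3-Instruct-8B") == "eagle"
assert detect_method("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B") == "eagle3"
assert detect_method("AngelSlim/Qwen3-8B_eagle3") == "eagle3"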
@@ -171,7 +171,22 @@ class LlamaAttention(nn.Module):
 
         sliding_window = None
         if layer_types := getattr(config, "layer_types", None):
-            is_sliding = layer_types[layer_idx] == "sliding_attention"
+            # Fix for Eagle3 compatibility:
+            # for draft models, subtract target layer count
+            # to get draft-relative layer index starting from 0
+            if hasattr(config, 'target_layer_count'):
+                # This is a draft model,
+                # adjust layer_idx to be relative to draft layers
+                effective_layer_idx = layer_idx - config.target_layer_count
+            else:
+                # This is a target model, use layer_idx directly
+                effective_layer_idx = layer_idx
+            assert effective_layer_idx < len(layer_types), \
+                f"effective_layer_idx: {effective_layer_idx} \
+                is out of bounds for layer_types: {layer_types}"
+
+            is_sliding = layer_types[
+                effective_layer_idx] == "sliding_attention"
             if is_sliding:
                 sliding_window = config.sliding_window
 
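Note (not part of the diff): this is the out-of-range index from the commit title. An eagle3 draft layer is built with a global layer index equal to the target's layer count (see the start_layer_id wiring in the next hunk), while the draft config's layer_types list only covers the draft's own layers, so indexing it with the global index overflowed. A hedged numeric sketch, with illustrative sizes:

# Illustrative numbers: a 36-layer target with a single-layer eagle3 draft head.
target_layer_count = 36
layer_types = ["full_attention"]   # draft config describes only its own layer
layer_idx = 36                     # draft layer created with start_layer_id=36

# Old behavior: layer_types[36] raised IndexError.
# New behavior: draft models carry target_layer_count and shift the index.
effective_layer_idx = layer_idx - target_layer_count
assert effective_layer_idx < len(layer_types)
is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
print(is_sliding)  # False, so sliding_window stays None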
@@ -199,6 +199,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             speculative_config.draft_model_config.hf_config
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
+
+        # Store target layer count in draft config for
+        # proper layer_types indexing in draft models
+        self.config.target_layer_count = target_layer_num
         self.model = LlamaModel(vllm_config=vllm_config,
                                 prefix="model",
                                 start_layer_id=target_layer_num)
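Note (not part of the diff): recording target_layer_count on the draft's hf_config before building LlamaModel is what the hasattr(config, 'target_layer_count') gate in the previous hunk keys on; target models never receive the attribute, so their indexing is unchanged. A small sketch with plain objects, names illustrative:

# Sketch of the draft-vs-target gate, using plain namespaces instead of
# real vLLM config objects.
from types import SimpleNamespace

target_cfg = SimpleNamespace(layer_types=["full_attention"] * 36)
draft_cfg = SimpleNamespace(layer_types=["full_attention"],
                            target_layer_count=36)

def resolve_layer_type(config, layer_idx):
    # Draft configs carry target_layer_count; shift to a draft-relative index.
    offset = getattr(config, "target_layer_count", 0)
    return config.layer_types[layer_idx - offset]

print(resolve_layer_type(target_cfg, 10))  # target model: direct indexing
print(resolve_layer_type(draft_cfg, 36))   # draft layer 36 -> local index 0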
@@ -277,8 +277,7 @@ _SPECULATIVE_DECODING_MODELS = {
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
-    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
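Note (not part of the diff): with the entry restored, a draft checkpoint whose config.json advertises architectures: ["LlamaForCausalLMEagle3"] resolves to the same llama_eagle3 implementation as the Eagle3LlamaForCausalLM name. A simplified sketch of that lookup (the real resolution goes through vLLM's model registry rather than a bare dict):

# Simplified architecture -> implementation lookup, mirroring the restored entry.
_SPECULATIVE_DECODING_MODELS = {
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
}

arch = "LlamaForCausalLMEagle3"  # as advertised by the draft checkpoint
module_name, class_name = _SPECULATIVE_DECODING_MODELS[arch]
print(f"vllm.model_executor.models.{module_name}.{class_name}")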
@@ -46,6 +46,7 @@ class EAGLEConfig(PretrainedConfig):
         # Eagle model name should follow naming convention of
         # LlamaForCausalLM -> EagleLlamaForCausalLM
         # LlamaForCausalLM -> Eagle3LlamaForCausalLM
+        # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
         if method == "eagle":
             assert self.model is not None, \
                 "model should not be None when method is eagle"
@@ -53,6 +54,7 @@ class EAGLEConfig(PretrainedConfig):
                 f"Eagle{arch}" if not arch.startswith("Eagle") \
                     else arch for arch in self.model.architectures
             ]
+
         elif method == "eagle3":
             assert self.model is not None, \
                 "model should not be None when method is eagle3"
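Note (not part of the diff): the comment added in the previous hunk records that a checkpoint already named LlamaForCausalLMEagle3 is kept as-is rather than being re-prefixed. The eagle branch's renaming rule visible in that hunk is restated below; the eagle3 branch's handling is only hinted at by the comment and is not reproduced here.

# Restatement of the eagle-branch renaming convention from the hunk above:
# LlamaForCausalLM -> EagleLlamaForCausalLM, existing Eagle* names unchanged.
def rename_for_eagle(architectures):
    return [
        arch if arch.startswith("Eagle") else f"Eagle{arch}"
        for arch in architectures
    ]

print(rename_for_eagle(["LlamaForCausalLM"]))       # ['EagleLlamaForCausalLM']
print(rename_for_eagle(["EagleLlamaForCausalLM"]))  # unchanged
# Per the new comment, "LlamaForCausalLMEagle3" maps to itself under eagle3.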