mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:26:12 +08:00
[Config] add "qwen" as a native eagle3 target supported model (#22333)
Signed-off-by: lechen <lecself@163.com> Signed-off-by: LeChen <lecself@163.com>
This commit is contained in:
parent
0c5254b82a
commit
3d7363e61c
@ -525,6 +525,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||||
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
||||||
|
"LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
|
||||||
|
trust_remote_code=True,
|
||||||
|
speculative_model="AngelSlim/Qwen3-8B_eagle3",
|
||||||
|
tokenizer="Qwen/Qwen3-8B"),
|
||||||
"EagleLlama4ForCausalLM": _HfExamplesInfo(
|
"EagleLlama4ForCausalLM": _HfExamplesInfo(
|
||||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
|
|||||||
@ -125,8 +125,8 @@ def test_ngram_correctness(
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
|
||||||
["model_setup", "mm_enabled"], [
|
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
||||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
@ -141,8 +141,11 @@ def test_ngram_correctness(
|
|||||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||||
True,
|
True,
|
||||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||||
],
|
],
|
||||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
|
ids=[
|
||||||
|
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
|
||||||
|
"llama4_eagle", "llama4_eagle_mm"
|
||||||
|
])
|
||||||
@pytest.mark.parametrize("attn_backend",
|
@pytest.mark.parametrize("attn_backend",
|
||||||
get_attn_backend_list_based_on_platform())
|
get_attn_backend_list_based_on_platform())
|
||||||
def test_eagle_correctness(
|
def test_eagle_correctness(
|
||||||
|
|||||||
@ -2852,13 +2852,7 @@ class SpeculativeConfig:
|
|||||||
"speculative decoding is > 1, but got "
|
"speculative decoding is > 1, but got "
|
||||||
f"{self.disable_by_batch_size=}")
|
f"{self.disable_by_batch_size=}")
|
||||||
|
|
||||||
from vllm.transformers_utils.configs import SpeculatorsConfig
|
eagle3_target_supported = ["llama", "qwen"]
|
||||||
|
|
||||||
eagle3_target_supported = ["llama"]
|
|
||||||
if self.draft_model_config and isinstance(
|
|
||||||
self.draft_model_config.hf_config, SpeculatorsConfig):
|
|
||||||
eagle3_target_supported.append("qwen")
|
|
||||||
|
|
||||||
if self.method == "eagle3" and self.target_model_config and not any(
|
if self.method == "eagle3" and self.target_model_config and not any(
|
||||||
supported_model in
|
supported_model in
|
||||||
self.target_model_config.hf_text_config.model_type
|
self.target_model_config.hf_text_config.model_type
|
||||||
|
|||||||
@ -259,6 +259,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
|||||||
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
||||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||||
|
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||||
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
|
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
|
||||||
"MedusaModel": ("medusa", "Medusa"),
|
"MedusaModel": ("medusa", "Medusa"),
|
||||||
|
|||||||
@ -45,6 +45,7 @@ class EAGLEConfig(PretrainedConfig):
|
|||||||
|
|
||||||
# Eagle model name should follow naming convention of
|
# Eagle model name should follow naming convention of
|
||||||
# LlamaForCausalLM -> EagleLlamaForCausalLM
|
# LlamaForCausalLM -> EagleLlamaForCausalLM
|
||||||
|
# LlamaForCausalLM -> Eagle3LlamaForCausalLM / LlamaForCausalLMEagle3
|
||||||
if method == "eagle":
|
if method == "eagle":
|
||||||
assert self.model is not None, \
|
assert self.model is not None, \
|
||||||
"model should not be None when method is eagle"
|
"model should not be None when method is eagle"
|
||||||
@ -56,8 +57,8 @@ class EAGLEConfig(PretrainedConfig):
|
|||||||
assert self.model is not None, \
|
assert self.model is not None, \
|
||||||
"model should not be None when method is eagle3"
|
"model should not be None when method is eagle3"
|
||||||
kwargs["architectures"] = [
|
kwargs["architectures"] = [
|
||||||
f"Eagle3{arch}" if not arch.startswith("Eagle3") \
|
arch if arch.startswith("Eagle3") or arch.endswith("Eagle3")
|
||||||
else arch for arch in self.model.architectures
|
else f"Eagle3{arch}" for arch in self.model.architectures
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid method {method}. \
|
raise ValueError(f"Invalid method {method}. \
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user