mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 07:34:58 +08:00
[Bugfix] Padded Eagle Specdec with Chunked Prefill (#26263)
Signed-off-by: Rémi Delacourt <remi@mistral.ai> Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com> Signed-off-by: remi <remi@mistral.ai> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent
18961c5ea6
commit
cec7c28833
@ -202,9 +202,9 @@ def test_speculators_model_integration(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
["model_setup", "mm_enabled"],
|
["model_setup", "mm_enabled", "chunked_prefill_enabled"],
|
||||||
[
|
[
|
||||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
(
|
(
|
||||||
"eagle3",
|
"eagle3",
|
||||||
@ -213,11 +213,12 @@ def test_speculators_model_integration(
|
|||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
|
False,
|
||||||
marks=pytest.mark.skip(
|
marks=pytest.mark.skip(
|
||||||
reason="Skipping due to its head_dim not being a a multiple of 32"
|
reason="Skipping due to its head_dim not being a a multiple of 32"
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
(
|
pytest.param(
|
||||||
(
|
(
|
||||||
"eagle",
|
"eagle",
|
||||||
"meta-llama/Llama-3.1-8B-Instruct",
|
"meta-llama/Llama-3.1-8B-Instruct",
|
||||||
@ -225,7 +226,9 @@ def test_speculators_model_integration(
|
|||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
True,
|
||||||
|
marks=large_gpu_mark(min_gb=40),
|
||||||
|
), # works on 4x H100
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
"eagle3",
|
"eagle3",
|
||||||
@ -234,6 +237,7 @@ def test_speculators_model_integration(
|
|||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
|
False,
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
(
|
(
|
||||||
@ -243,6 +247,7 @@ def test_speculators_model_integration(
|
|||||||
4,
|
4,
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
|
False,
|
||||||
marks=large_gpu_mark(min_gb=80),
|
marks=large_gpu_mark(min_gb=80),
|
||||||
), # works on 4x H100
|
), # works on 4x H100
|
||||||
pytest.param(
|
pytest.param(
|
||||||
@ -253,6 +258,7 @@ def test_speculators_model_integration(
|
|||||||
4,
|
4,
|
||||||
),
|
),
|
||||||
True,
|
True,
|
||||||
|
True,
|
||||||
marks=large_gpu_mark(min_gb=80),
|
marks=large_gpu_mark(min_gb=80),
|
||||||
), # works on 4x H100
|
), # works on 4x H100
|
||||||
(
|
(
|
||||||
@ -263,6 +269,7 @@ def test_speculators_model_integration(
|
|||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
|
False,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
@ -281,6 +288,7 @@ def test_eagle_correctness(
|
|||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_setup: tuple[str, str, str, int],
|
model_setup: tuple[str, str, str, int],
|
||||||
mm_enabled: bool,
|
mm_enabled: bool,
|
||||||
|
chunked_prefill_enabled: bool,
|
||||||
attn_backend: str,
|
attn_backend: str,
|
||||||
):
|
):
|
||||||
if attn_backend == "TREE_ATTN":
|
if attn_backend == "TREE_ATTN":
|
||||||
@ -317,9 +325,13 @@ def test_eagle_correctness(
|
|||||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||||
|
|
||||||
method, model_name, spec_model_name, tp_size = model_setup
|
method, model_name, spec_model_name, tp_size = model_setup
|
||||||
|
max_model_len = 2048
|
||||||
|
max_num_batched_tokens = max_model_len
|
||||||
|
if chunked_prefill_enabled:
|
||||||
|
max_num_batched_tokens = 128
|
||||||
|
|
||||||
ref_llm = LLM(
|
ref_llm = LLM(
|
||||||
model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
|
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
|
||||||
)
|
)
|
||||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||||
del ref_llm
|
del ref_llm
|
||||||
@ -334,9 +346,11 @@ def test_eagle_correctness(
|
|||||||
"method": method,
|
"method": method,
|
||||||
"model": spec_model_name,
|
"model": spec_model_name,
|
||||||
"num_speculative_tokens": 3,
|
"num_speculative_tokens": 3,
|
||||||
"max_model_len": 2048,
|
"max_model_len": max_model_len,
|
||||||
},
|
},
|
||||||
max_model_len=2048,
|
max_model_len=max_model_len,
|
||||||
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
|
enable_chunked_prefill=chunked_prefill_enabled,
|
||||||
)
|
)
|
||||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||||
matches = 0
|
matches = 0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user