From cec7c288333339028f6fe8e0ac3222e3924da90b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?= <54138269+Flechman@users.noreply.github.com>
Date: Mon, 3 Nov 2025 08:22:46 +0100
Subject: [PATCH] [Bugfix] Padded Eagle Specdec with Chunked Prefill (#26263)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Signed-off-by: remi
Co-authored-by: Benjamin Chislett
---
 tests/v1/e2e/test_spec_decode.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 45b48e585893..ea7fcdf3174e 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -202,9 +202,9 @@ def test_speculators_model_integration(
 
 
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
+    ["model_setup", "mm_enabled", "chunked_prefill_enabled"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
         pytest.param(
             (
                 "eagle3",
@@ -213,11 +213,12 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a a multiple of 32"
             ),
         ),
-        (
+        pytest.param(
             (
                 "eagle",
                 "meta-llama/Llama-3.1-8B-Instruct",
@@ -225,7 +226,9 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
-        ),
+            True,
+            marks=large_gpu_mark(min_gb=40),
+        ),  # works on 4x H100
         (
             (
                 "eagle3",
@@ -234,6 +237,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
         pytest.param(
             (
@@ -243,6 +247,7 @@ def test_speculators_model_integration(
                 4,
             ),
             False,
+            False,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
@@ -253,6 +258,7 @@ def test_speculators_model_integration(
                 4,
             ),
             True,
+            True,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -263,6 +269,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
     ],
     ids=[
@@ -281,6 +288,7 @@ def test_eagle_correctness(
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    chunked_prefill_enabled: bool,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -317,9 +325,13 @@ def test_eagle_correctness(
         m.setenv("VLLM_ROCM_USE_AITER", "1")
 
     method, model_name, spec_model_name, tp_size = model_setup
+    max_model_len = 2048
+    max_num_batched_tokens = max_model_len
+    if chunked_prefill_enabled:
+        max_num_batched_tokens = 128
 
     ref_llm = LLM(
-        model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
+        model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
     )
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
@@ -334,9 +346,11 @@ def test_eagle_correctness(
             "method": method,
             "model": spec_model_name,
             "num_speculative_tokens": 3,
-            "max_model_len": 2048,
+            "max_model_len": max_model_len,
         },
-        max_model_len=2048,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        enable_chunked_prefill=chunked_prefill_enabled,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
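
Note for reviewers: the chunked-prefill rows work by capping max_num_batched_tokens
at 128 while max_model_len stays at 2048, so any prompt longer than 128 tokens must
be prefilled across several scheduler steps (a full 2048-token prompt takes
ceil(2048 / 128) = 16 chunks). The sketch below reproduces that configuration
outside pytest; it is illustrative only, not part of the patch. The model and
draft names are copied from the test matrix above, and the kwargs mirror the
ones this patch adds.

from vllm import LLM, SamplingParams

max_model_len = 2048
# Capping the batch-token budget below max_model_len forces the scheduler to
# split long prefills into chunks, which is the padded-Eagle code path this
# bugfix exercises.
max_num_batched_tokens = 128

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # target model from the matrix
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",  # draft model from the matrix
        "num_speculative_tokens": 3,
        "max_model_len": max_model_len,
    },
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    enable_chunked_prefill=True,
)

# Single chat request; the test itself compares ref_llm vs spec_llm outputs.
outputs = llm.chat(
    [[{"role": "user", "content": "Briefly explain speculative decoding."}]],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)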
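When chunked prefill is disabled, the test sets max_num_batched_tokens equal to
max_model_len so every prompt is still prefilled in a single step, keeping the
reference and speculative runs comparable. Chunking is switched on for two rows
of the matrix: the text-only Llama-3.1-8B-Instruct EAGLE setup (which now also
carries large_gpu_mark(min_gb=40)) and the mm_enabled row gated behind
large_gpu_mark(min_gb=80), so both the text-only and multimodal chunked-prefill
paths are covered; every other row keeps chunking off.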