From aed846917fee3b27df876d21ade4d3a30d4de402 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 9 Dec 2025 02:24:01 -0500
Subject: [PATCH] [Attention] Make `split_decodes_and_prefills(..., require_uniform=True)` support padding (#29644)

Signed-off-by: Lucas Wilkinson
Signed-off-by: Lucas Wilkinson
Co-authored-by: Benjamin Chislett
---
 .../v1/attention/test_attention_splitting.py | 25 ++++++++++++++++++-
 vllm/v1/attention/backends/utils.py          | 10 +++++---
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py
index f60861e3489d6..f08e2f480e30f 100644
--- a/tests/v1/attention/test_attention_splitting.py
+++ b/tests/v1/attention/test_attention_splitting.py
@@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
 
 
 def apply_split_decodes_and_prefills(
-    query_lens: list[int], decode_threshold: int, require_uniform: bool
+    query_lens: list[int],
+    decode_threshold: int,
+    require_uniform: bool,
+    padded_num_tokens: int | None = None,
 ):
     """Helper function to apply split_decodes_and_prefills and return the
     results."""
@@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
         block_size=16,
         device=device,
     )
+
+    if padded_num_tokens is not None:
+        common_metadata.num_actual_tokens = padded_num_tokens
+
     return split_decodes_and_prefills(
         common_metadata,
         decode_threshold=decode_threshold,
@@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
     assert num_prefill_tokens == (sum(query_lens) - 2)  # rest of the tokens
 
 
+def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
+    """Uniform batch where all query lengths match, plus zero-length padded reqs."""
+    # All query lengths are 2, with decode_threshold=3 (so 2 <= 3).
+    # This triggers the padded-uniform path in split_decodes_and_prefills.
+    query_lens = [2, 2, 2, 0]
+    padded_num_tokens = 8
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, True, padded_num_tokens)
+    )
+    # With a uniform batch, all requests are treated as decodes.
+    assert num_decodes == 4
+    assert num_prefills == 0
+    assert num_decode_tokens == padded_num_tokens
+    assert num_prefill_tokens == 0
+
+
 @pytest.mark.parametrize(
     "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
     [
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 8edfbb5140bc9..5200bc48b3156 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -883,11 +883,15 @@ def split_decodes_and_prefills(
         return 0, num_reqs, 0, num_tokens
 
     if require_uniform:
+        # Check for a padded uniform batch (used for full CUDA graphs): some
+        # requests may have a query length of 0, but since they are padding it's
+        # fine to treat them as decodes (num_decodes then matches the captured size).
+        if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
+            assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
+            return num_reqs, 0, num_tokens, 0  # all decodes
         is_prefill = query_lens != query_lens[0]
     else:
-        # 0-query len indicates a padded request; leave this at the back
-        # of the batch with the prefills
-        is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
+        is_prefill = query_lens > decode_threshold
 
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
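
For readers outside the vLLM codebase, the heart of the change is the padded-uniform check added under `require_uniform`. The sketch below isolates that logic as a standalone function; `split_uniform_padded` is a hypothetical name introduced here for illustration, not vLLM's actual API, and only the `query_lens`/token arithmetic mirrors the patch:

```python
import torch


def split_uniform_padded(query_lens: torch.Tensor, num_tokens: int):
    """Minimal sketch of the padded-uniform fast path: if every request has
    the same query length or a length of 0 (a padding slot inserted for full
    CUDA-graph capture), treat the whole batch as decodes so the decode count
    matches the captured graph size."""
    num_reqs = query_lens.numel()
    if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
        # The padded token total must equal num_reqs * uniform length,
        # e.g. [2, 2, 2, 0] padded to 8 tokens: 4 reqs * 2 tokens each.
        assert num_reqs * int(query_lens[0]) == num_tokens
        # (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
        return num_reqs, 0, num_tokens, 0
    return None  # fall through to the non-uniform handling


print(split_uniform_padded(torch.tensor([2, 2, 2, 0]), 8))  # (4, 0, 8, 0)
```

This also explains why the old `require_uniform=False` branch no longer forces zero-length requests in with the prefills: under full CUDA graphs the captured batch shape is fixed, so counting padding slots as decodes keeps `num_decodes` consistent with the capture size instead of splitting the batch.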