[Attention] Make split_decodes_and_prefills(..., require_uniform=True) support padding (#29644)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
Lucas Wilkinson 2025-12-09 02:24:01 -05:00 committed by GitHub
parent e4605d225e
commit aed846917f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 4 deletions

View File

@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
def apply_split_decodes_and_prefills(
query_lens: list[int], decode_threshold: int, require_uniform: bool
query_lens: list[int],
decode_threshold: int,
require_uniform: bool,
padded_num_tokens: int | None = None,
):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
block_size=16,
device=device,
)
if padded_num_tokens is not None:
common_metadata.num_actual_tokens = padded_num_tokens
return split_decodes_and_prefills(
common_metadata,
decode_threshold=decode_threshold,
@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
    """Uniform batch where every real request has the same query length and
    trailing zero-length entries act as padding (full cudagraph capture case)."""
    # Every real request has query length 2; decode_threshold=3, so 2 <= 3.
    # The trailing 0-length request is padding, exercising the padded
    # uniform-batch fast path in split_decodes_and_prefills.
    query_lens = [2, 2, 2, 0]
    padded_num_tokens = 8

    result = apply_split_decodes_and_prefills(query_lens, 3, True, padded_num_tokens)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = result

    # A uniform (padded) batch is classified entirely as decodes.
    assert num_decodes == 4
    assert num_prefills == 0
    assert num_decode_tokens == padded_num_tokens
    assert num_prefill_tokens == 0
@pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[

View File

@ -883,11 +883,15 @@ def split_decodes_and_prefills(
return 0, num_reqs, 0, num_tokens
if require_uniform:
# check if we are in a padded uniform batch; this is used for full-CGs, some
# requests may have a query length of 0, but since they are padding it's fine
# to treat them as decodes (ensures num_decodes matches the captured size)
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
return num_reqs, 0, num_tokens, 0 # all decodes
is_prefill = query_lens != query_lens[0]
else:
# 0-query len indicates a padded request; leave this at the back
# of the batch with the prefills
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0