mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 05:16:32 +08:00
[Attention] Make split_decodes_and_prefills(..., require_uniform=True) support padding (#29644)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
parent
e4605d225e
commit
aed846917f
@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
|
||||
|
||||
|
||||
def apply_split_decodes_and_prefills(
|
||||
query_lens: list[int], decode_threshold: int, require_uniform: bool
|
||||
query_lens: list[int],
|
||||
decode_threshold: int,
|
||||
require_uniform: bool,
|
||||
padded_num_tokens: int | None = None,
|
||||
):
|
||||
"""Helper function to apply split_decodes_and_prefills and return
|
||||
the results."""
|
||||
@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if padded_num_tokens is not None:
|
||||
common_metadata.num_actual_tokens = padded_num_tokens
|
||||
|
||||
return split_decodes_and_prefills(
|
||||
common_metadata,
|
||||
decode_threshold=decode_threshold,
|
||||
@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
|
||||
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
    """Uniform batch: identical query lengths plus zero-length padded requests."""
    # Three real requests of query length 2 (each <= decode_threshold of 3)
    # and one zero-length padding request; token count is padded out to 8.
    # This exercises the padded-uniform fast path in split_decodes_and_prefills.
    padded_num_tokens = 8
    result = apply_split_decodes_and_prefills([2, 2, 2, 0], 3, True, padded_num_tokens)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = result
    # In a padded uniform batch every request — padding included — is a decode.
    assert num_decodes == 4
    assert num_prefills == 0
    assert num_decode_tokens == padded_num_tokens
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
|
||||
[
|
||||
|
||||
@ -883,11 +883,15 @@ def split_decodes_and_prefills(
|
||||
return 0, num_reqs, 0, num_tokens
|
||||
|
||||
if require_uniform:
|
||||
# check if we are in a padded uniform batch; this is used for full-CGs, some
|
||||
# requests may have a query length of 0 but since they are padding its fine
|
||||
# to treat them as decodes (ensures num_decodes matches the captured size)
|
||||
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
|
||||
assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
|
||||
return num_reqs, 0, num_tokens, 0 # all decodes
|
||||
is_prefill = query_lens != query_lens[0]
|
||||
else:
|
||||
# 0-query len indicates a padded request; leave this at the back
|
||||
# of the batch with the prefills
|
||||
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
|
||||
is_prefill = query_lens > decode_threshold
|
||||
|
||||
if not torch.any(is_prefill):
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user