mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 10:14:26 +08:00
[Attention] Make split_decodes_and_prefills(..., require_uniform=True) support padding (#29644)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
parent
e4605d225e
commit
aed846917f
@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
|
|||||||
|
|
||||||
|
|
||||||
def apply_split_decodes_and_prefills(
|
def apply_split_decodes_and_prefills(
|
||||||
query_lens: list[int], decode_threshold: int, require_uniform: bool
|
query_lens: list[int],
|
||||||
|
decode_threshold: int,
|
||||||
|
require_uniform: bool,
|
||||||
|
padded_num_tokens: int | None = None,
|
||||||
):
|
):
|
||||||
"""Helper function to apply split_decodes_and_prefills and return
|
"""Helper function to apply split_decodes_and_prefills and return
|
||||||
the results."""
|
the results."""
|
||||||
@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
|
|||||||
block_size=16,
|
block_size=16,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if padded_num_tokens is not None:
|
||||||
|
common_metadata.num_actual_tokens = padded_num_tokens
|
||||||
|
|
||||||
return split_decodes_and_prefills(
|
return split_decodes_and_prefills(
|
||||||
common_metadata,
|
common_metadata,
|
||||||
decode_threshold=decode_threshold,
|
decode_threshold=decode_threshold,
|
||||||
@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
|
|||||||
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
|
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
|
||||||
|
"""uniform batch where all query lengths are identical with 0 length padded reqs."""
|
||||||
|
# All non-padded query lengths are 2, with decode_threshold=3 (so 2 <= 3)
|
||||||
|
# This triggers the padded uniform early return in split_decodes_and_prefills
|
||||||
|
query_lens = [2, 2, 2, 0]
|
||||||
|
padded_num_tokens = 8
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 3, True, padded_num_tokens)
|
||||||
|
)
|
||||||
|
# With uniform batch, all requests are treated as decodes
|
||||||
|
assert num_decodes == 4
|
||||||
|
assert num_prefills == 0
|
||||||
|
assert num_decode_tokens == padded_num_tokens
|
||||||
|
assert num_prefill_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
|
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
|
||||||
[
|
[
|
||||||
|
|||||||
@ -883,11 +883,15 @@ def split_decodes_and_prefills(
|
|||||||
return 0, num_reqs, 0, num_tokens
|
return 0, num_reqs, 0, num_tokens
|
||||||
|
|
||||||
if require_uniform:
|
if require_uniform:
|
||||||
|
# check if we are in a padded uniform batch; this is used for full-CGs, some
|
||||||
|
# requests may have a query length of 0, but since they are padding it's fine
|
||||||
|
# to treat them as decodes (ensures num_decodes matches the captured size)
|
||||||
|
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
|
||||||
|
assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
|
||||||
|
return num_reqs, 0, num_tokens, 0 # all decodes
|
||||||
is_prefill = query_lens != query_lens[0]
|
is_prefill = query_lens != query_lens[0]
|
||||||
else:
|
else:
|
||||||
# 0-query len indicates a padded request; leave this at the back
|
is_prefill = query_lens > decode_threshold
|
||||||
# of the batch with the prefills
|
|
||||||
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
|
|
||||||
|
|
||||||
if not torch.any(is_prefill):
|
if not torch.any(is_prefill):
|
||||||
return num_reqs, 0, num_tokens, 0
|
return num_reqs, 0, num_tokens, 0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user