diff --git a/tests/v1/attention/test_batch_reordering.py b/tests/v1/attention/test_batch_reordering.py
new file mode 100644
index 0000000000000..b271409b92955
--- /dev/null
+++ b/tests/v1/attention/test_batch_reordering.py
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+
+import numpy as np
+import pytest
+
+from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_prefills
+
+
+class MockInputBatch:
+    def __init__(self, req_ids, num_computed_tokens_cpu):
+        self.req_ids = req_ids
+        self.num_computed_tokens_cpu = num_computed_tokens_cpu
+
+    def swap_states(self, i, j):
+        self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]
+        self.num_computed_tokens_cpu[i], self.num_computed_tokens_cpu[j] = (
+            self.num_computed_tokens_cpu[j],
+            self.num_computed_tokens_cpu[i],
+        )
+
+
+class MockSchedulerOutput:
+    def __init__(self, num_scheduled_tokens):
+        self.num_scheduled_tokens = num_scheduled_tokens
+
+
+@dataclass
+class ReorderTestCase:
+    requests: list[tuple[int, int]]  # (num_scheduled_tokens, num_computed_tokens)
+    expected_order: list[int]
+    expected_modified: bool
+    decode_threshold: int = 1
+
+
+# Test cases for batch reordering
+REORDER_TEST_CASES = {
+    "all_decodes": ReorderTestCase(
+        requests=[(1, 10), (1, 20), (1, 30)],
+        expected_order=[0, 1, 2],
+        expected_modified=False,
+    ),
+    "all_prefills": ReorderTestCase(
+        requests=[(100, 100), (200, 200), (300, 300)],
+        expected_order=[0, 1, 2],
+        expected_modified=False,
+    ),
+    "mixed_interleaved": ReorderTestCase(
+        requests=[(100, 100), (1, 10), (200, 200), (1, 20)],
+        expected_order=[3, 1, 2, 0],  # Only swap 0↔3, keep 1 and 2 in place
+        expected_modified=True,
+    ),
+    "already_ordered": ReorderTestCase(
+        requests=[(1, 10), (1, 20), (100, 100), (200, 200)],
+        expected_order=[0, 1, 2, 3],
+        expected_modified=False,
+    ),
+    "single_request": ReorderTestCase(
+        requests=[(1, 10)],
+        expected_order=[0],
+        expected_modified=False,
+    ),
+    "higher_threshold": ReorderTestCase(
+        requests=[(2, 10), (3, 20), (5, 30), (6, 40)],
+        expected_order=[0, 1, 2, 3],
+        expected_modified=False,
+        decode_threshold=4,
+    ),
+    "decodes_at_end": ReorderTestCase(
+        requests=[(100, 100), (200, 200), (1, 10), (1, 20)],
+        expected_order=[2, 3, 0, 1],
+        expected_modified=True,
+    ),
+    "decode_extend_prefill": ReorderTestCase(
+        requests=[(100, 100), (10, 50), (1, 10)],
+        expected_order=[2, 1, 0],
+        expected_modified=True,
+    ),
+    "extend_prefill_only": ReorderTestCase(
+        requests=[(100, 100), (10, 50), (200, 200), (20, 75)],
+        expected_order=[3, 1, 2, 0],  # Only swap 0↔3, keep 1 and 2 in place
+        expected_modified=True,
+    ),
+}
+
+
+@pytest.mark.parametrize(
+    "test_case", REORDER_TEST_CASES.values(), ids=REORDER_TEST_CASES.keys()
+)
+def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase):
+    req_ids = [f"r{i}" for i in range(len(test_case.requests))]
+    num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32)
+    num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)}
+
+    input_batch = MockInputBatch(req_ids, num_computed_tokens)
+    scheduler_output = MockSchedulerOutput(num_scheduled_tokens)
+
+    modified = reorder_batch_to_split_decodes_and_prefills(
+        input_batch, scheduler_output, decode_threshold=test_case.decode_threshold
+    )
+
+    expected_req_ids = [f"r{i}" for i in test_case.expected_order]
+
+    assert modified == test_case.expected_modified, (
f"Expected modified={test_case.expected_modified}, got {modified}" + ) + assert input_batch.req_ids == expected_req_ids, ( + f"Expected order {expected_req_ids}, got {input_batch.req_ids}" + ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index a0d354df06ca3..389baf1488be0 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -795,51 +795,59 @@ def reorder_batch_to_split_decodes_and_prefills( Returns: True if the batch was modified, False otherwise. """ - # We now want to reorder the batch so that the "decode" requests are at - # the front and the "prefill" requests are at the back using the least - # amount of swaps possible. (NOTE for now we loosely use "decode" to mean - # requests where attention is likely memory-bound and "prefill" to mean - # requests where attention is likely compute-bound, TODO(lucas): figure out - # a better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 + # We now want to reorder the batch into decode → extend → prefill order + # where: + # decode: request with num_scheduled_tokens <= decode_threshold + # extend: non-decode request with existing context + # prefill: non-decode request with no existing context + # NOTE for now we loosely use "decode" to mean requests where attention is + # likely memory-bound and "prefill" to mean requests where attention is + # likely compute-bound, + num_reqs = len(input_batch.req_ids) + num_scheduled_tokens = [ + scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids + ] + num_scheduled_tokens_np = np.array(num_scheduled_tokens) + num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs] - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - if num_tokens <= decode_threshold: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens + is_decode = num_scheduled_tokens_np <= decode_threshold + is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np) + is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np) - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. 
-    # `decodes` and `prefills` are already in ascending order just based on
-    # the above loop
-    num_decodes = len(decodes)
-    num_prefills = len(prefills)
-    modified_batch = False
+    # Desired order: decode → extend → prefill
+    req_regions = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
+    req_regions[is_extend] = 1
+    req_regions[is_prefill] = 2
 
-    for i in range(1, min(num_decodes, num_prefills) + 1):
-        # If the decode is at the "back" of the batch, i, we can swap it
-        # with the prefill closest to the front of the batch
-        decode_idx = decodes[num_decodes - i]
-        if decode_idx < num_decodes:
-            break
+    num_decodes = int(is_decode.sum())
+    num_extends = int(is_extend.sum())
 
-        input_batch.swap_states(prefills[i - 1], decode_idx)
-        modified_batch = True
+    target_regions = np.zeros(num_reqs, dtype=np.int32)
+    target_regions[num_decodes : num_decodes + num_extends] = 1
+    target_regions[num_decodes + num_extends :] = 2
 
-    return modified_batch
+    needs_swap = req_regions != target_regions
+
+    if not needs_swap.any():
+        return False
+
+    # Misplaced slots (ascending) are filled by the misplaced elements
+    # in stable region-sorted order: src_indices[k] moves to swap_indices[k]
+    swap_indices = np.where(needs_swap)[0]
+    sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
+    src_indices = swap_indices[sorted_order]
+
+    src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, swap_indices)}
+
+    for src in src_dest_map:
+        dst = src_dest_map[src]
+        while src != dst:
+            input_batch.swap_states(src, dst)
+            # Mark dst as done by updating its destination to itself
+            next_dst = src_dest_map.get(dst, dst)
+            src_dest_map[dst] = dst
+            dst = next_dst
+
+    return True
 
 
 def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
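
For reviewers, a small worked example of the new decode/extend/prefill classification. This is an illustrative sketch, not part of the PR: the array names are made up, and it assumes (as the test cases and the predicates imply) that `num_computed_tokens_cpu` already includes the tokens scheduled this step, so a fresh prefill has `computed == scheduled` while an extend carries prior context and has `computed > scheduled`.

```python
import numpy as np

# Hand-checkable batch mirroring the "mixed_interleaved" test case.
decode_threshold = 1
num_scheduled = np.array([100, 1, 200, 1])   # tokens scheduled this step
num_computed = np.array([100, 10, 200, 20])  # includes this step's tokens

is_decode = num_scheduled <= decode_threshold
is_extend = (~is_decode) & (num_computed > num_scheduled)
is_prefill = (~is_decode) & (num_computed == num_scheduled)

regions = np.zeros(len(num_scheduled), dtype=np.int32)  # 0 = decode
regions[is_extend] = 1   # extend
regions[is_prefill] = 2  # prefill
print(regions)  # [2 0 2 0]: requests 0 and 2 are prefills, 1 and 3 decodes
```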
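The swap loop itself is a cycle-following permutation apply: each misplaced element is swapped straight to its destination, so every `swap_states` call places at least one element in its final slot and a cycle of length k costs only k - 1 swaps. Below is a self-contained sketch of that scheme on a plain list (`reorder_by_region` and the list-swap stand-in for `swap_states` are hypothetical, for illustration only). Note the parametrized tests above only exercise pairwise swaps (2-cycles); the sketch also runs a genuine 3-cycle, where the direction of the source-to-destination mapping matters.

```python
import numpy as np


def reorder_by_region(items: list[str], regions: list[int]) -> int:
    """Reorder items into region order 0 -> 1 -> 2; return the swap count."""
    regions_np = np.asarray(regions, dtype=np.int32)
    counts = np.bincount(regions_np, minlength=3)
    # Target layout: all region-0 items, then region-1, then region-2.
    target = np.repeat(np.arange(3, dtype=np.int32), counts)

    needs_swap = regions_np != target
    if not needs_swap.any():
        return 0

    # Misplaced slots (ascending) are filled by the misplaced elements
    # in stable region-sorted order: srcs[k] must move to slots[k].
    slots = np.where(needs_swap)[0]
    srcs = slots[np.argsort(regions_np[needs_swap], kind="stable")]
    dest_of = {int(s): int(d) for s, d in zip(srcs, slots)}

    swaps = 0
    for src in dest_of:
        dst = dest_of[src]
        while src != dst:  # walk the permutation cycle through src
            items[src], items[dst] = items[dst], items[src]
            swaps += 1
            next_dst = dest_of.get(dst, dst)
            dest_of[dst] = dst  # slot dst now holds its final element
            dst = next_dst
    return swaps


# 2-cycle, as in "mixed_interleaved": one swap fixes two positions.
items = ["p0", "d1", "p2", "d3"]
assert reorder_by_region(items, [2, 0, 2, 0]) == 1
assert items == ["d3", "d1", "p2", "p0"]

# 3-cycle: prefill, decode, extend -> decode, extend, prefill in two swaps.
items = ["prefill", "decode", "extend"]
assert reorder_by_region(items, [2, 0, 1]) == 2
assert items == ["decode", "extend", "prefill"]
```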