mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 23:24:29 +08:00
[BugFix] Reordering extend logic fix (#27739)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
b8c48c5d72
commit
b5d70751d8
@ -53,7 +53,7 @@ REORDER_TEST_CASES = {
|
|||||||
expected_modified=True,
|
expected_modified=True,
|
||||||
),
|
),
|
||||||
"already_ordered": ReorderTestCase(
|
"already_ordered": ReorderTestCase(
|
||||||
requests=[(1, 10), (1, 20), (100, 100), (200, 200)],
|
requests=[(1, 10), (1, 20), (100, 100), (200, 0)],
|
||||||
expected_order=[0, 1, 2, 3],
|
expected_order=[0, 1, 2, 3],
|
||||||
expected_modified=False,
|
expected_modified=False,
|
||||||
),
|
),
|
||||||
@ -74,15 +74,30 @@ REORDER_TEST_CASES = {
|
|||||||
expected_modified=True,
|
expected_modified=True,
|
||||||
),
|
),
|
||||||
"decode_extend_prefill": ReorderTestCase(
|
"decode_extend_prefill": ReorderTestCase(
|
||||||
requests=[(100, 100), (10, 50), (1, 10)],
|
requests=[(100, 0), (10, 50), (1, 10)],
|
||||||
expected_order=[2, 1, 0],
|
expected_order=[2, 1, 0],
|
||||||
expected_modified=True,
|
expected_modified=True,
|
||||||
),
|
),
|
||||||
"extend_prefill_only": ReorderTestCase(
|
"extend_prefill_only": ReorderTestCase(
|
||||||
requests=[(100, 100), (10, 50), (200, 200), (20, 75)],
|
requests=[(100, 0), (10, 50), (200, 0), (20, 75)],
|
||||||
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
|
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
|
||||||
expected_modified=True,
|
expected_modified=True,
|
||||||
),
|
),
|
||||||
|
"complicated_mixed_interleaved": ReorderTestCase(
|
||||||
|
requests=[
|
||||||
|
(1, 20),
|
||||||
|
(1, 50),
|
||||||
|
(374, 0),
|
||||||
|
(300, 20),
|
||||||
|
(1, 20),
|
||||||
|
(256, 0),
|
||||||
|
(1, 5),
|
||||||
|
(27, 0),
|
||||||
|
(1, 4),
|
||||||
|
],
|
||||||
|
expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5],
|
||||||
|
expected_modified=True,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -811,8 +811,8 @@ def reorder_batch_to_split_decodes_and_prefills(
|
|||||||
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
|
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||||
|
|
||||||
is_decode = num_scheduled_tokens_np <= decode_threshold
|
is_decode = num_scheduled_tokens_np <= decode_threshold
|
||||||
is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np)
|
is_extend = (~is_decode) & (num_computed_tokens_np > 0)
|
||||||
is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np)
|
is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
|
||||||
|
|
||||||
# Desired order: decode → extend → prefill
|
# Desired order: decode → extend → prefill
|
||||||
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
|
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
|
||||||
@ -832,11 +832,11 @@ def reorder_batch_to_split_decodes_and_prefills(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Extract indices that need swapping and sort by target region
|
# Extract indices that need swapping and sort by target region
|
||||||
swap_indices = np.where(needs_swap)[0]
|
orig_indices = np.where(needs_swap)[0]
|
||||||
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
|
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
|
||||||
dest_indices = swap_indices[sorted_order]
|
src_indices = orig_indices[sorted_order]
|
||||||
|
|
||||||
src_dest_map = {int(src): int(dst) for src, dst in zip(swap_indices, dest_indices)}
|
src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
|
||||||
|
|
||||||
for src in src_dest_map:
|
for src in src_dest_map:
|
||||||
dst = src_dest_map[src]
|
dst = src_dest_map[src]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user