mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 09:42:14 +08:00
[MLA] Simplification to batch P/D reordering (#16673)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
e4755f7fac
commit
0377b8310b
@ -415,20 +415,18 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
# the above loop
|
# the above loop
|
||||||
num_decodes = len(decodes)
|
num_decodes = len(decodes)
|
||||||
num_prefills = len(prefills)
|
num_prefills = len(prefills)
|
||||||
first_prefill = 0
|
|
||||||
modified_batch = False
|
modified_batch = False
|
||||||
|
|
||||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||||
# If the decode is at the "back" of the batch, i, we can swap it
|
# If the decode is at the "back" of the batch, i, we can swap it
|
||||||
# with the prefill closest to the front of the batch
|
# with the prefill closest to the front of the batch
|
||||||
if decodes[num_decodes - i] >= num_decodes:
|
decode_idx = decodes[num_decodes - i]
|
||||||
input_batch.swap_states(prefills[first_prefill],
|
if decode_idx < num_decodes:
|
||||||
decodes[num_decodes - i])
|
|
||||||
first_prefill += 1
|
|
||||||
modified_batch = True
|
|
||||||
else:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||||
|
modified_batch = True
|
||||||
|
|
||||||
# Save for next `build` call
|
# Save for next `build` call
|
||||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||||
# better way of doing this
|
# better way of doing this
|
||||||
|
|||||||
@ -458,7 +458,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if removed_req_indices:
|
if removed_req_indices:
|
||||||
self.input_batch.condense(removed_req_indices)
|
self.input_batch.condense(removed_req_indices)
|
||||||
|
|
||||||
if batch_changed:
|
# Some attention backends (namely MLA) may want to separate requests
|
||||||
|
# based on if the attention computation will be compute-bound or
|
||||||
|
# memory-bound. This gives them a hook to do that.
|
||||||
|
batch_reordered = self.attn_metadata_builder.reorder_batch(
|
||||||
|
self.input_batch, scheduler_output)
|
||||||
|
|
||||||
|
if batch_changed or batch_reordered:
|
||||||
self.input_batch.refresh_sampling_metadata()
|
self.input_batch.refresh_sampling_metadata()
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
@ -471,14 +477,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
assert num_reqs > 0
|
assert num_reqs > 0
|
||||||
|
|
||||||
# Some attention backends (namely MLA) may want to separate requests
|
|
||||||
# based on if the attention computation will be compute-bound or
|
|
||||||
# memory-bound. This gives them a hook to do that.
|
|
||||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
|
||||||
self.input_batch, scheduler_output)
|
|
||||||
if modified_batch:
|
|
||||||
self.input_batch.refresh_sampling_metadata()
|
|
||||||
|
|
||||||
# OPTIMIZATION: Start copying the block table first.
|
# OPTIMIZATION: Start copying the block table first.
|
||||||
# This way, we can overlap the copy with the following CPU operations.
|
# This way, we can overlap the copy with the following CPU operations.
|
||||||
self.input_batch.block_table.commit(num_reqs)
|
self.input_batch.block_table.commit(num_reqs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user