mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 06:47:02 +08:00
[Optimization] Truncate kv page indices for sliding window attention
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
df5afa82e5
commit
90d43db442
@ -248,7 +248,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
self.block_table_arange = torch.arange(max_num_pages_per_req,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
device="cpu")
|
||||
|
||||
self.sliding_window_size = getattr(kv_cache_spec, "sliding_window", None)
|
||||
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
@ -488,10 +490,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
shared_kv_last_page_len_cpu = None
|
||||
|
||||
max_num_blocks = block_table_bounds_cpu.max()
|
||||
block_table_bounds = block_table_bounds_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
|
||||
< block_table_bounds.unsqueeze(1))
|
||||
arange = self.block_table_arange[:max_num_blocks].unsqueeze(0)
|
||||
mask = arange < block_table_bounds_cpu.unsqueeze(1)
|
||||
if (self.sliding_window_size is not None and not use_cascade
|
||||
and num_decodes > 0):
|
||||
# NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the kv
|
||||
# outside the sliding window and only do masking, we manually
|
||||
# manipulate the seq_lens and block table for skipping.
|
||||
# NOTE: Don't apply this optimization to prefill requests.
|
||||
decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
|
||||
num_skipped_pages = (
|
||||
torch.relu(decode_seq_lens_cpu - self.sliding_window_size) //
|
||||
page_size)
|
||||
|
||||
block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
|
||||
mask[:num_decodes] &= (arange[:num_decodes]
|
||||
>= num_skipped_pages.unsqueeze(1))
|
||||
|
||||
# write self.paged_kv_indices inplace
|
||||
num_actual_pages = torch.sum(mask)
|
||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user