From 90d43db442ee0251ac259209b557c10ac0529f8d Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 15 Aug 2025 19:10:01 +0000
Subject: [PATCH] [Optimization] Truncate kv page indices for sliding window
 attention

Signed-off-by: Woosuk Kwon
---
 vllm/v1/attention/backends/flashinfer.py | 25 +++++++++++++++++++-----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 02decb171fc05..ca6976f74822e 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -248,7 +248,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         self.block_table_arange = torch.arange(max_num_pages_per_req,
                                                dtype=torch.int32,
-                                               device=self.device)
+                                               device="cpu")
+
+        self.sliding_window_size = getattr(kv_cache_spec, "sliding_window", None)
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
@@ -488,10 +490,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             shared_kv_last_page_len_cpu = None
 
         max_num_blocks = block_table_bounds_cpu.max()
-        block_table_bounds = block_table_bounds_cpu.to(self.device,
-                                                       non_blocking=True)
-        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
-                < block_table_bounds.unsqueeze(1))
+        arange = self.block_table_arange[:max_num_blocks].unsqueeze(0)
+        mask = arange < block_table_bounds_cpu.unsqueeze(1)
+        if (self.sliding_window_size is not None and not use_cascade
+                and num_decodes > 0):
+            # NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the
+            # kv outside the sliding window and only masks it out, we manually
+            # manipulate the seq_lens and block table to skip those pages.
+            # NOTE: Don't apply this optimization to prefill requests.
+            decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
+            num_skipped_pages = (
+                torch.relu(decode_seq_lens_cpu - self.sliding_window_size) //
+                page_size)
+
+            block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
+            mask[:num_decodes] &= (arange[:num_decodes]
+                                   >= num_skipped_pages.unsqueeze(1))
 
+        # Write self.paged_kv_indices in place.
         num_actual_pages = torch.sum(mask)
         paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
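
For reviewers, here is a standalone sketch of the truncation arithmetic above, runnable on toy CPU tensors. It is not part of the patch: the variable names mirror the diff, but the concrete sizes (page_size=16, a 64-token window, two decode requests, 8 pages per request) are made up for illustration.

import torch

page_size = 16
sliding_window_size = 64
max_num_pages_per_req = 8
num_decodes = 2

# Two decode requests: seq_len 100 (> window) and seq_len 40 (< window).
seq_lens_cpu = torch.tensor([100, 40], dtype=torch.int32)

# Pages currently referenced by each request: ceil(seq_len / page_size).
block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size  # [7, 3]

arange = torch.arange(max_num_pages_per_req, dtype=torch.int32).unsqueeze(0)
mask = arange < block_table_bounds_cpu.unsqueeze(1)

# For seq_len 100, the first 100 - 64 = 36 tokens fall outside the window,
# so 36 // 16 = 2 full pages can be skipped. relu() clamps the second
# request (40 - 64 < 0) to zero skipped pages.
num_skipped_pages = (
    torch.relu(seq_lens_cpu[:num_decodes] - sliding_window_size) // page_size)

block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
mask[:num_decodes] &= arange >= num_skipped_pages.unsqueeze(1)

print(num_skipped_pages)       # tensor([2, 0])
print(block_table_bounds_cpu)  # tensor([5, 3])
print(mask.sum(dim=1))         # tensor([5, 3])
# Request 0 keeps pages 2..6, i.e. tokens 32..111, which still cover the
# last 64 tokens (36..99) of its 100-token sequence.

Note that the floor division makes the skip conservative: only pages lying entirely outside the sliding window are dropped, while a partially covered page at the window boundary is kept, so the kernel's own masking still handles the few out-of-window tokens on that page.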