From da03cb8f0b912c04376da584cbca47959356cd0a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 15 Aug 2025 14:00:26 -0700 Subject: [PATCH] fix Signed-off-by: Woosuk Kwon --- vllm/v1/attention/backends/flashinfer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index ca6976f74822e..a221ed2530445 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -250,7 +250,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device="cpu") - self.sliding_window_size = getattr(kv_cache_spec, "sliding_window", None) + self.sliding_window = getattr(kv_cache_spec, "sliding_window", None) def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -489,10 +489,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): shared_kv_page_indices_cpu = None shared_kv_last_page_len_cpu = None - max_num_blocks = block_table_bounds_cpu.max() + max_num_blocks = block_table_bounds_cpu.max().item() arange = self.block_table_arange[:max_num_blocks].unsqueeze(0) mask = arange < block_table_bounds_cpu.unsqueeze(1) - if (self.sliding_window_size is not None and not use_cascade + if (self.sliding_window is not None and not use_cascade and num_decodes > 0): # NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the kv # outside the sliding window and only do masking, we manually @@ -500,7 +500,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # NOTE: Don't apply this optimization to prefill requests. decode_seq_lens_cpu = seq_lens_cpu[:num_decodes] num_skipped_pages = ( - torch.relu(decode_seq_lens_cpu - self.sliding_window_size) // + torch.relu(decode_seq_lens_cpu - self.sliding_window) // page_size) block_table_bounds_cpu[:num_decodes] -= num_skipped_pages @@ -511,7 +511,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_actual_pages = torch.sum(mask) paged_kv_indices = self.paged_kv_indices[:num_actual_pages] torch.masked_select(block_table_tensor[:, :max_num_blocks], - mask, + mask.to(self.device, non_blocking=True), out=paged_kv_indices) # write self.paged_kv_indptr_cpu inplace (0-index is always 0)