Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-15 14:00:26 -07:00
parent 90d43db442
commit da03cb8f0b

View File

@ -250,7 +250,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int32, dtype=torch.int32,
device="cpu") device="cpu")
self.sliding_window_size = getattr(kv_cache_spec, "sliding_window", None) self.sliding_window = getattr(kv_cache_spec, "sliding_window", None)
def _get_workspace_buffer(self): def _get_workspace_buffer(self):
if self._workspace_buffer is None: if self._workspace_buffer is None:
@ -489,10 +489,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
shared_kv_page_indices_cpu = None shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None shared_kv_last_page_len_cpu = None
max_num_blocks = block_table_bounds_cpu.max() max_num_blocks = block_table_bounds_cpu.max().item()
arange = self.block_table_arange[:max_num_blocks].unsqueeze(0) arange = self.block_table_arange[:max_num_blocks].unsqueeze(0)
mask = arange < block_table_bounds_cpu.unsqueeze(1) mask = arange < block_table_bounds_cpu.unsqueeze(1)
if (self.sliding_window_size is not None and not use_cascade if (self.sliding_window is not None and not use_cascade
and num_decodes > 0): and num_decodes > 0):
# NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the kv # NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the kv
# outside the sliding window and only do masking, we manually # outside the sliding window and only do masking, we manually
@ -500,7 +500,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# NOTE: Don't apply this optimization to prefill requests. # NOTE: Don't apply this optimization to prefill requests.
decode_seq_lens_cpu = seq_lens_cpu[:num_decodes] decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
num_skipped_pages = ( num_skipped_pages = (
torch.relu(decode_seq_lens_cpu - self.sliding_window_size) // torch.relu(decode_seq_lens_cpu - self.sliding_window) //
page_size) page_size)
block_table_bounds_cpu[:num_decodes] -= num_skipped_pages block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
@ -511,7 +511,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_actual_pages = torch.sum(mask) num_actual_pages = torch.sum(mask)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages] paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
torch.masked_select(block_table_tensor[:, :max_num_blocks], torch.masked_select(block_table_tensor[:, :max_num_blocks],
mask, mask.to(self.device, non_blocking=True),
out=paged_kv_indices) out=paged_kv_indices)
# write self.paged_kv_indptr_cpu inplace (0-index is always 0) # write self.paged_kv_indptr_cpu inplace (0-index is always 0)