mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 16:31:18 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
90d43db442
commit
da03cb8f0b
@ -250,7 +250,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
|
||||||
self.sliding_window_size = getattr(kv_cache_spec, "sliding_window", None)
|
self.sliding_window = getattr(kv_cache_spec, "sliding_window", None)
|
||||||
|
|
||||||
def _get_workspace_buffer(self):
|
def _get_workspace_buffer(self):
|
||||||
if self._workspace_buffer is None:
|
if self._workspace_buffer is None:
|
||||||
@ -489,10 +489,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
shared_kv_page_indices_cpu = None
|
shared_kv_page_indices_cpu = None
|
||||||
shared_kv_last_page_len_cpu = None
|
shared_kv_last_page_len_cpu = None
|
||||||
|
|
||||||
max_num_blocks = block_table_bounds_cpu.max()
|
max_num_blocks = block_table_bounds_cpu.max().item()
|
||||||
arange = self.block_table_arange[:max_num_blocks].unsqueeze(0)
|
arange = self.block_table_arange[:max_num_blocks].unsqueeze(0)
|
||||||
mask = arange < block_table_bounds_cpu.unsqueeze(1)
|
mask = arange < block_table_bounds_cpu.unsqueeze(1)
|
||||||
if (self.sliding_window_size is not None and not use_cascade
|
if (self.sliding_window is not None and not use_cascade
|
||||||
and num_decodes > 0):
|
and num_decodes > 0):
|
||||||
# NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the kv
|
# NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the kv
|
||||||
# outside the sliding window and only does masking, we manually
|
# outside the sliding window and only does masking, we manually
|
||||||
@ -500,7 +500,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
# NOTE: Don't apply this optimization to prefill requests.
|
# NOTE: Don't apply this optimization to prefill requests.
|
||||||
decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
|
decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
|
||||||
num_skipped_pages = (
|
num_skipped_pages = (
|
||||||
torch.relu(decode_seq_lens_cpu - self.sliding_window_size) //
|
torch.relu(decode_seq_lens_cpu - self.sliding_window) //
|
||||||
page_size)
|
page_size)
|
||||||
|
|
||||||
block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
|
block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
|
||||||
@ -511,7 +511,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
num_actual_pages = torch.sum(mask)
|
num_actual_pages = torch.sum(mask)
|
||||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||||
torch.masked_select(block_table_tensor[:, :max_num_blocks],
|
torch.masked_select(block_table_tensor[:, :max_num_blocks],
|
||||||
mask,
|
mask.to(self.device, non_blocking=True),
|
||||||
out=paged_kv_indices)
|
out=paged_kv_indices)
|
||||||
|
|
||||||
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
|
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user