Optimize input preparation for FlashInfer [2/N] (#23174)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent 5bd9f84158
commit 6578e87365
@@ -6,6 +6,7 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import ClassVar, Optional, Union
 
+import numpy as np
 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
@@ -22,6 +23,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import (supports_trtllm_attention,
                                    use_trtllm_attention)
@@ -230,6 +232,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=pin_memory)
+        self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
         self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
                                                 dtype=torch.int32,
                                                 device="cpu",
@@ -238,10 +241,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                       dtype=torch.int32,
                                                       device="cpu",
                                                       pin_memory=pin_memory)
-
-        self.block_table_arange = torch.arange(max_num_pages_per_req,
-                                               dtype=torch.int32,
-                                               device=self.device)
+        self.paged_kv_last_page_len_np = (
+            self.paged_kv_last_page_len_cpu.numpy())
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
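The two hunks above cache `numpy()` views (`paged_kv_indptr_np`, `paged_kv_last_page_len_np`) of pinned CPU tensors that are allocated once in `__init__`. A CPU `torch.Tensor` and the array returned by `.numpy()` share the same memory, so per-step bookkeeping can be done with cheap numpy writes that land directly in the pinned buffer later copied to the GPU. A minimal sketch of the aliasing this relies on (sizes and values are made up; `pin_memory=True` assumes a CUDA build):

```python
import numpy as np
import torch

num_reqs = 8
# Pinned CPU tensor, allocated once; use pin_memory=False on CPU-only builds.
indptr_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, pin_memory=True)
indptr_np = indptr_cpu.numpy()  # zero-copy view of the same buffer

num_blocks = np.array([3, 1, 4, 2, 2, 5, 1, 3], dtype=np.int32)
np.cumsum(num_blocks, dtype=np.int32, out=indptr_np[1:num_reqs + 1])

assert indptr_cpu[-1].item() == num_blocks.sum()  # torch sees numpy's writes
```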
@@ -317,9 +318,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        seq_lens_np = seq_lens_cpu.numpy()
         block_table_tensor = common_attn_metadata.block_table_tensor
 
-        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
+        num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
 
         use_cascade = common_prefix_len > 0
         if use_cascade:
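`num_blocks_np` is the per-request page count computed by integer ceiling division, now on the numpy view instead of a CPU `torch.Tensor`; numpy's per-op overhead on small arrays is lower than torch's CPU dispatch, which is presumably the point of the swap. A worked example with illustrative values:

```python
import numpy as np

page_size = 16
seq_lens_np = np.array([1, 16, 17, 100], dtype=np.int32)
# ceil(seq_len / page_size) in pure integer math: pages occupied per request.
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
print(num_blocks_np)  # [1 1 2 7]
```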
@@ -342,37 +344,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds_cpu -= num_common_kv_blocks
+            num_blocks_np -= num_common_kv_blocks
         else:
             shared_qo_indptr_cpu = None
             shared_kv_page_indptr_cpu = None
             shared_kv_page_indices_cpu = None
             shared_kv_last_page_len_cpu = None
 
-        max_num_blocks = block_table_bounds_cpu.max().item()
-        block_table_bounds = block_table_bounds_cpu.to(self.device,
-                                                       non_blocking=True)
-        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
-                < block_table_bounds.unsqueeze(1))
-        # write self.paged_kv_indices inplace
-        num_actual_pages = torch.sum(mask)
-        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
-        torch.masked_select(block_table_tensor[:, :max_num_blocks],
-                            mask,
-                            out=paged_kv_indices)
-
         # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
-        torch.cumsum(block_table_bounds_cpu,
-                     dim=0,
-                     dtype=torch.int32,
-                     out=self.paged_kv_indptr_cpu[1:1 + num_reqs])
+        np.cumsum(
+            num_blocks_np,
+            dtype=np.int32,
+            out=self.paged_kv_indptr_np[1:num_reqs + 1],
+        )
+        paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
+        paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
+                              non_blocking=True)
+
+        # write self.paged_kv_indices inplace
+        num_actual_pages = num_blocks_np.sum().item()
+        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+        _copy_page_indices_kernel[(num_reqs, )](
+            paged_kv_indices,
+            block_table_tensor,
+            block_table_tensor.stride(0),
+            paged_kv_indptr,
+            BLOCK_SIZE=1024,
+        )
 
-        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
-        # write self.paged_kv_last_page_len_cpu inplace
-        torch.where(paged_kv_last_page_len_cpu == 0,
-                    torch.tensor(page_size),
-                    paged_kv_last_page_len_cpu,
-                    out=self.paged_kv_last_page_len_cpu[:num_reqs])
+        paged_kv_last_page_len_np = seq_lens_np % page_size
+        self.paged_kv_last_page_len_np[:num_reqs] = np.where(
+            paged_kv_last_page_len_np == 0,
+            page_size,
+            paged_kv_last_page_len_np,
+        )
 
         # Check if any layer uses sinks (requires TRTLLM attention)
         has_sinks = self.global_hyperparameters.has_sinks
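The deleted path built a `(num_reqs, max_num_blocks)` boolean mask on the GPU and compacted the block table with `torch.masked_select`; the new path launches `_copy_page_indices_kernel` (defined below) with one program per request, using `paged_kv_indptr` to place each row, so neither the mask nor the `block_table_arange` helper is needed. A plain-PyTorch reference for what the kernel computes; the helper name here is made up and intended only for testing:

```python
import torch


def copy_page_indices_reference(block_table: torch.Tensor,
                                cu_num_blocks: torch.Tensor) -> torch.Tensor:
    # For request i, take the first n_i entries of its block-table row,
    # where n_i = cu_num_blocks[i + 1] - cu_num_blocks[i], and pack them
    # contiguously, i.e. the flat page list FlashInfer consumes.
    out = torch.empty(int(cu_num_blocks[-1]),
                      dtype=block_table.dtype,
                      device=block_table.device)
    for i in range(block_table.shape[0]):
        start = int(cu_num_blocks[i])
        end = int(cu_num_blocks[i + 1])
        out[start:end] = block_table[i, :end - start]
    return out
```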
@@ -1002,3 +1008,25 @@ def fast_plan_decode(
     self._sm_scale = sm_scale
     self._rope_scale = rope_scale
     self._rope_theta = rope_theta
+
+
+@triton.jit
+def _copy_page_indices_kernel(
+    page_indices,
+    block_table,
+    block_table_stride,
+    cu_num_blocks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    row_ptr = block_table + req_idx * block_table_stride
+    start_idx = tl.load(cu_num_blocks + req_idx)
+    end_idx = tl.load(cu_num_blocks + req_idx + 1)
+    num_blocks = end_idx - start_idx
+
+    offset = tl.arange(0, BLOCK_SIZE)
+    for i in tl.range(0, num_blocks, BLOCK_SIZE):
+        block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks)
+        tl.store(page_indices + start_idx + i + offset,
+                 block_ids,
+                 mask=i + offset < num_blocks)
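A quick smoke test for the kernel, assuming a CUDA device with Triton available; the `num_blocks` values are made up. One program handles one request and walks its row in `BLOCK_SIZE`-wide chunks, so the `tl.range` loop also covers rows longer than `BLOCK_SIZE`:

```python
import torch

num_reqs, max_blocks = 4, 10
block_table = torch.randint(0, 1000, (num_reqs, max_blocks),
                            dtype=torch.int32, device="cuda")
num_blocks = torch.tensor([3, 10, 1, 7], dtype=torch.int32, device="cuda")
cu_num_blocks = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cuda")
torch.cumsum(num_blocks, dim=0, dtype=torch.int32, out=cu_num_blocks[1:])

page_indices = torch.empty(int(cu_num_blocks[-1]),
                           dtype=torch.int32, device="cuda")
_copy_page_indices_kernel[(num_reqs, )](
    page_indices,
    block_table,
    block_table.stride(0),
    cu_num_blocks,
    BLOCK_SIZE=1024,
)

expected = torch.cat([block_table[i, :n] for i, n in enumerate(num_blocks)])
assert torch.equal(page_indices, expected)
```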