mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 07:07:05 +08:00
reorder
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
caf963f2e9
commit
5c133fc860
@ -129,7 +129,7 @@ class BlockTables:
|
||||
new_block_ids_gpu = new_block_ids_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
_append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
|
||||
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
self.req_indices.copy_to_gpu(num_reqs),
|
||||
self.cu_num_new_blocks.copy_to_gpu(),
|
||||
self.cu_num_new_blocks.gpu.stride(0),
|
||||
@ -148,7 +148,7 @@ class BlockTables:
|
||||
idx_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
batch_size = idx_mapping.shape[0]
|
||||
_compute_block_tables_kernel[(batch_size, self.num_kv_cache_groups)](
|
||||
_compute_block_tables_kernel[(self.num_kv_cache_groups, batch_size)](
|
||||
idx_mapping,
|
||||
self.buffer_ptrs,
|
||||
self.block_table_ptrs,
|
||||
@ -167,7 +167,7 @@ class BlockTables:
|
||||
num_reqs = query_start_loc.shape[0] - 1
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
_compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
|
||||
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
query_start_loc,
|
||||
@ -200,8 +200,8 @@ def _append_block_ids_kernel(
|
||||
# Constants
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
group_id = tl.program_id(1)
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(req_indices + batch_idx)
|
||||
do_overwrite = tl.load(overwrite + batch_idx)
|
||||
|
||||
@ -246,9 +246,9 @@ def _compute_block_tables_kernel(
|
||||
num_blocks_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(1)
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
@ -280,12 +280,12 @@ def _compute_slot_mappings_kernel(
|
||||
PAD_ID: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(1)
|
||||
group_id = tl.program_id(0)
|
||||
req_idx = tl.program_id(1)
|
||||
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
|
||||
|
||||
if req_idx == tl.num_programs(0) - 1:
|
||||
if req_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user