mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-17 13:27:04 +08:00
simplify
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
b1d52734f7
commit
a851aaa0fc
@ -37,10 +37,6 @@ class BlockTables:
|
||||
self.block_tables: list[torch.Tensor] = []
|
||||
# [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
|
||||
self.block_table_buffers: list[torch.Tensor] = []
|
||||
# [num_kv_cache_groups, max_num_reqs]
|
||||
self.num_blocks: list[torch.Tensor] = []
|
||||
# [num_kv_cache_groups, max_num_tokens]
|
||||
self.slot_mappings: list[torch.Tensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size)
|
||||
@ -61,44 +57,40 @@ class BlockTables:
|
||||
)
|
||||
self.block_table_buffers.append(block_table_buffer)
|
||||
|
||||
num_blocks = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.num_blocks.append(num_blocks)
|
||||
|
||||
slot_mapping = torch.zeros(self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.slot_mappings.append(slot_mapping)
|
||||
|
||||
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
|
||||
self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.stride(0) for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.num_blocks_ptrs = self._make_ptr_tensor(self.num_blocks)
|
||||
self.block_sizes_tensor = torch.tensor(self.block_sizes,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mapping_ptrs = self._make_ptr_tensor(self.slot_mappings)
|
||||
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
|
||||
# Misc buffers.
|
||||
self.req_indices = self._make_buffer(self.max_num_reqs, torch.int32)
|
||||
self.overwrite = self._make_buffer(self.max_num_reqs, torch.bool)
|
||||
self.cu_num_new_blocks: list[CpuGpuBuffer] = []
|
||||
self.new_block_ids: list[CpuGpuBuffer] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
self.cu_num_new_blocks.append(
|
||||
self._make_buffer(self.max_num_reqs + 1, torch.int32))
|
||||
# NOTE(woosuk): Here, we assume that total number of new blocks
|
||||
# is ALWAYS less than max_num_batched_tokens.
|
||||
# TODO(woosuk): Rigorously verify that this assumption is correct.
|
||||
self.new_block_ids.append(
|
||||
self._make_buffer(self.max_num_batched_tokens, torch.int32))
|
||||
self.req_indices = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
|
||||
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
|
||||
self.max_num_reqs + 1,
|
||||
dtype=torch.int32)
|
||||
# NOTE(woosuk): Here, we assume that total number of new blocks
|
||||
# is ALWAYS less than max_num_batched_tokens.
|
||||
# TODO(woosuk): Rigorously verify that this assumption is correct.
|
||||
self.new_block_ids = self._make_buffer(self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int32)
|
||||
|
||||
def _make_buffer(self, n: int, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(n,
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*args,
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory,
|
||||
device=self.device)
|
||||
@ -126,24 +118,21 @@ class BlockTables:
|
||||
self.req_indices.np[:num_reqs] = req_indices
|
||||
self.overwrite.np[:num_reqs] = overwrite
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
self.cu_num_new_blocks[i].np[:num_reqs + 1] = cu_num_new_blocks[i]
|
||||
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
|
||||
n = len(new_block_ids[i])
|
||||
self.new_block_ids[i].np[:n] = new_block_ids[i]
|
||||
self.new_block_ids.np[i, :n] = new_block_ids[i]
|
||||
|
||||
cu_num_new_blocks_ptrs = self._make_ptr_tensor(
|
||||
[x.copy_to_gpu(num_reqs + 1) for x in self.cu_num_new_blocks])
|
||||
new_block_ids_ptrs = self._make_ptr_tensor([
|
||||
x.copy_to_gpu(len(new_block_ids[i]))
|
||||
for i, x in enumerate(self.new_block_ids)
|
||||
])
|
||||
_append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
|
||||
self.req_indices.copy_to_gpu(num_reqs),
|
||||
cu_num_new_blocks_ptrs,
|
||||
new_block_ids_ptrs,
|
||||
self.cu_num_new_blocks.copy_to_gpu(),
|
||||
self.cu_num_new_blocks.gpu.stride(0),
|
||||
self.new_block_ids.copy_to_gpu(),
|
||||
self.new_block_ids.gpu.stride(0),
|
||||
self.overwrite.copy_to_gpu(num_reqs),
|
||||
self.block_table_strides,
|
||||
self.buffer_ptrs,
|
||||
self.num_blocks_ptrs,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
@ -157,7 +146,8 @@ class BlockTables:
|
||||
self.buffer_ptrs,
|
||||
self.block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.num_blocks_ptrs,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return tuple(b[:batch_size] for b in self.block_tables)
|
||||
@ -178,7 +168,8 @@ class BlockTables:
|
||||
self.block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mapping_ptrs,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
@ -189,13 +180,16 @@ class BlockTables:
|
||||
def _append_block_ids_kernel(
|
||||
# Inputs
|
||||
req_indices, # [num_reqs]
|
||||
cu_num_new_block_ptrs, # [num_kv_cache_groups, num_reqs + 1]
|
||||
new_block_id_ptrs, # [num_kv_cache_groups, num_new_blocks]
|
||||
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
|
||||
cu_num_new_blocks_stride,
|
||||
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
|
||||
new_block_ids_stride,
|
||||
overwrite, # [num_reqs]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
# Outputs
|
||||
block_table_buffer_ptrs, # [num_kv_cache_groups]
|
||||
num_block_ptrs, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
# Constants
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
@ -204,19 +198,19 @@ def _append_block_ids_kernel(
|
||||
req_idx = tl.load(req_indices + batch_idx)
|
||||
do_overwrite = tl.load(overwrite + batch_idx)
|
||||
|
||||
cu_num_new_blocks_ptr = _load_ptr(cu_num_new_block_ptrs + group_id,
|
||||
tl.int32)
|
||||
start_idx = tl.load(cu_num_new_blocks_ptr + batch_idx)
|
||||
end_idx = tl.load(cu_num_new_blocks_ptr + batch_idx + 1)
|
||||
group_new_blocks_ptr = (cu_num_new_blocks_ptr +
|
||||
group_id * cu_num_new_blocks_stride)
|
||||
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
|
||||
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
|
||||
num_new_blocks = end_idx - start_idx
|
||||
|
||||
num_blocks_ptr = _load_ptr(num_block_ptrs + group_id, tl.int32)
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
if do_overwrite:
|
||||
dst_start_idx = 0
|
||||
else:
|
||||
dst_start_idx = tl.load(num_blocks_ptr + req_idx)
|
||||
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
|
||||
dst_end_idx = dst_start_idx + num_new_blocks
|
||||
tl.store(num_blocks_ptr + req_idx, dst_end_idx)
|
||||
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
|
||||
|
||||
# Destination
|
||||
block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
|
||||
@ -224,10 +218,11 @@ def _append_block_ids_kernel(
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
|
||||
|
||||
new_block_ids_ptr = _load_ptr(new_block_id_ptrs + group_id, tl.int32)
|
||||
group_new_block_ids_ptr = (new_block_ids_ptr +
|
||||
group_id * new_block_ids_stride)
|
||||
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(new_block_ids_ptr + start_idx + offset,
|
||||
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
|
||||
mask=offset < num_new_blocks)
|
||||
tl.store(buffer_row_ptr + dst_start_idx + offset,
|
||||
block_ids,
|
||||
@ -240,7 +235,8 @@ def _compute_block_tables_kernel(
|
||||
src_block_table_ptrs, # [num_kv_cache_groups]
|
||||
dst_block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
num_blocks_ptrs, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
@ -248,8 +244,8 @@ def _compute_block_tables_kernel(
|
||||
group_id = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
num_blocks_ptr = _load_ptr(num_blocks_ptrs + group_id, tl.int32)
|
||||
num_blocks = tl.load(num_blocks_ptr + req_idx)
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
|
||||
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
|
||||
@ -272,14 +268,15 @@ def _compute_slot_mappings_kernel(
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
page_sizes, # [num_kv_cache_groups]
|
||||
slot_mapping_ptrs, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
PAD_ID: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(1)
|
||||
slot_mapping_ptr = _load_ptr(slot_mapping_ptrs + group_id, tl.int64)
|
||||
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
|
||||
|
||||
if req_idx == tl.num_programs(0) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
|
||||
@ -659,7 +659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc, self.positions[:total_num_scheduled_tokens])
|
||||
query_start_loc, self.positions.gpu[:total_num_scheduled_tokens])
|
||||
|
||||
# Used in the below loop.
|
||||
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
||||
|
||||
@ -109,8 +109,6 @@ class RequestState:
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_cached_reqs))
|
||||
# Used to construct the input batch.
|
||||
self._add_scalar_attr("idx_mapping", torch.int32)
|
||||
|
||||
# Request states.
|
||||
self.req_data: dict[int, RequestData] = {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user