[Core] Flashinfer - Remove advance step size restriction (#10282)
commit b6dde33019
parent 1b886aa104
@@ -88,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor const& t,
   }
 }
 
+/// each thread processes a block per query
 __global__ void advance_step_flashinfer_kernel(
     int num_threads, int num_seqs, int num_queries, int block_size,
     long* input_tokens_ptr, long const* sampled_token_ids_ptr,
@@ -134,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
     int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
     int* block_table_bound_ptr) {
   int idx = blockIdx.x * num_threads + threadIdx.x;
 
   // Update paged_kv_indptr
+  if (idx == 0) {
+    paged_kv_indptr_ptr[idx] = 0;
+  }
   if (idx < num_queries) {
     int sum = 0;
     for (int i = 0; i <= idx; ++i) {
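
Note: the indptr kernel in this hunk builds paged_kv_indptr as a running prefix sum over block_table_bound, so paged_kv_indptr[0] is 0 and paged_kv_indptr[q + 1] is the total number of valid blocks of queries 0..q. A minimal host-side sketch of that computation follows; the helper name is hypothetical and not part of this commit.

#include <vector>

// Host-side reference for what advance_step_flashinfer_indptr_kernel computes:
// a prefix sum of per-query block counts, shifted by one slot.
// block_table_bound[q] = number of valid KV-cache blocks for query q.
std::vector<int> build_paged_kv_indptr_ref(
    const std::vector<int>& block_table_bound, int num_queries) {
  std::vector<int> indptr(num_queries + 1, 0);
  for (int q = 0; q < num_queries; ++q) {
    indptr[q + 1] = indptr[q] + block_table_bound[q];
  }
  return indptr;  // e.g. bounds {3, 1, 2} -> indptr {0, 3, 4, 6}
}
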
@@ -146,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
 }
 
 __global__ void advance_step_flashinfer_indices_kernel(
-    int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
-    int64_t const block_tables_stride, int* paged_kv_indices_ptr,
+    int num_seqs, int num_queries, int const* block_tables_ptr,
+    int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
     int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
-  int idx = blockIdx.x * num_threads + threadIdx.x;
-  int row = idx / block_tables_stride;
-  int col = idx % block_tables_stride;
+  // note: max_num_blocks_per_seq = block_tables.stride(0)
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (row < num_queries && col < block_table_bound_ptr[row]) {
-    paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
-        block_tables_ptr[row * block_tables_stride + col];
+  // when cuda graphs are enabled, paged_kv_indptr tensor
+  // has to be updated for the padded queries
+  // tid represents a query# for paged_kv_indptr tensor
+  if (num_queries < tid && tid <= num_seqs) {
+    paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
   }
 
-  // if cudagraph, fill padded seqs with the last valid seq's indptr
-  if (num_queries < row && row <= num_seqs) {
-    paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
-  }
+  // each thread processes a block_ptr in block_tables
+  // block_tables shape: [num_queries, max_num_blocks_per_seq]
+  // paged_kv_indices is flattened block_tables.
+  for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
+       idx += (gridDim.x * blockDim.x)) {
+    // block_tables-row = paged_kv_indptr[queryNum]
+    int queryNum = idx / max_num_blocks_per_seq;
+    int col = idx % max_num_blocks_per_seq;
+    if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
+      int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
+      int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
+      paged_kv_indices_ptr[indices_arr_idx] =
+          block_tables_ptr[block_tables_idx];
+    }
+  }
 }
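
The rewritten indices kernel above is a grid-stride loop: each thread starts at its global id and advances by gridDim.x * blockDim.x, so a fixed launch of blocks x threads can cover any num_seqs * max_num_blocks_per_seq, which is what lifts the old step-size restriction. Below is a standalone sketch of the same pattern on a toy problem; all names are illustrative and none of this is vLLM or FlashInfer API.

// Pack the valid prefix of each row of a padded [num_rows, row_stride] table
// into a flat output, using per-row output offsets and per-row valid counts.
// Any grid size works: each thread simply strides forward by the total
// number of launched threads until the whole table is covered.
__global__ void pack_rows_kernel(int num_rows, long row_stride,
                                 int const* table, int const* row_bound,
                                 int const* out_offset, int* out) {
  long tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (long idx = tid; idx < (long)num_rows * row_stride;
       idx += (long)gridDim.x * blockDim.x) {
    int row = (int)(idx / row_stride);
    int col = (int)(idx % row_stride);
    if (col < row_bound[row]) {
      out[out_offset[row] + col] = table[row * row_stride + col];
    }
  }
}
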
@@ -247,22 +263,16 @@ void advance_step_flashinfer(
   int threads;
   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
   cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
 
-  int block_tables_stride = block_tables.stride(0);
+  TORCH_CHECK((blocks * threads > num_queries),
+              "multi-step: not enough threads to map to num_queries = ",
+              num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
+              " blocks = ", blocks, " max_threads = ", threads);
   if (logging) {
-    printf("launching kernel with %d blocks\n", blocks);
+    printf("launching kernels with %d blocks and %d threads\n", blocks,
+           threads);
   }
 
-  // TODO(will): support arbitrary block_tables stride
-  if ((blocks * threads) / block_tables.stride(0) < num_queries) {
-    TORCH_CHECK(false,
-                "multi-step: not enough threads to map block_table to"
-                "FlashInfer's paged_kv_indices on GPU. Try reducing the number "
-                "of seqs,",
-                " increasing the block size or take smaller steps.",
-                " num_queries = ", num_queries,
-                " block_tables.stride(0) = ", block_tables.stride(0),
-                " blocks = ", blocks, " max_threads = ", threads);
-  }
-
   advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
       threads, num_seqs, num_queries, block_size,
       reinterpret_cast<long*>(input_tokens.data_ptr()),
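
With the one-thread-per-table-entry mapping gone, the launch above no longer needs the old guard that aborted whenever blocks * threads could not cover num_queries * block_tables.stride(0); only coverage of num_queries itself is still checked. A rough illustration of the arithmetic behind the removed guard, using made-up numbers and a hypothetical helper:

// Old-style feasibility check: one thread per block_tables entry.
// Example: 132 SMs * 1024 threads = 135168 threads; with num_queries = 512
// and block_tables.stride(0) = 512, the old path needed 262144 threads and
// refused to run, while the grid-stride loop simply iterates with the same
// launch configuration.
bool old_step_fits(long blocks, long threads, long num_queries,
                   long block_tables_stride) {
  return (blocks * threads) / block_tables_stride >= num_queries;
}
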
@@ -281,7 +291,7 @@ void advance_step_flashinfer(
       reinterpret_cast<int*>(block_table_bound.data_ptr()));
 
   advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
-      threads, num_seqs, num_queries,
+      num_seqs, num_queries,
       reinterpret_cast<int const*>(block_tables.data_ptr()),
       block_tables.stride(0),
       reinterpret_cast<int*>(paged_kv_indices.data_ptr()),