[multi-step] add flashinfer backend (#7928)
This commit is contained in: parent f2e263b801, commit a6c0f3658d

Changed file summary: csrc/ops.h (19 changed lines)
csrc/ops.h
@@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_quick(torch::Tensor& out, torch::Tensor& input);
 
-void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
-                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
-                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);
+void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
+                            int64_t block_size, torch::Tensor& input_tokens,
+                            torch::Tensor& sampled_token_ids,
+                            torch::Tensor& input_positions,
+                            torch::Tensor& seq_lens,
+                            torch::Tensor& slot_mapping,
+                            torch::Tensor& block_tables);
+
+void advance_step_flashinfer(
+    int64_t num_seqs, int64_t num_queries, int64_t block_size,
+    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+    torch::Tensor& input_positions, torch::Tensor& seq_lens,
+    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
+    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
+    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
@@ -12,13 +12,11 @@ namespace prepare_inputs {
 
 //
 template <int const num_threads>
-__global__ void advance_step_kernel(int num_seqs, int num_queries,
-                                    int block_size, long* input_tokens_ptr,
-                                    long const* sampled_token_ids_ptr,
-                                    long* input_positions_ptr,
-                                    int* seq_lens_ptr, long* slot_mapping_ptr,
-                                    int const* block_tables_ptr,
-                                    int64_t const block_tables_stride) {
+__global__ void advance_step_flashattn_kernel(
+    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
+    long const* sampled_token_ids_ptr, long* input_positions_ptr,
+    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
+    int64_t const block_tables_stride) {
   int num_query_blocks = div_ceil(num_queries, num_threads);
 
   if (blockIdx.x >= num_query_blocks) {
@@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
   }
 }
 
-void advance_step(int num_seqs, int num_queries, int block_size,
-                  torch::Tensor& input_tokens,       // type: long
-                  torch::Tensor& sampled_token_ids,  // type: long
-                  torch::Tensor& input_positions,    // type: long
-                  torch::Tensor& seq_lens,           // type: int
-                  torch::Tensor& slot_mapping,       // type: long
-                  torch::Tensor& block_tables) {     // type: int
+__global__ void advance_step_flashinfer_kernel(
+    int num_threads, int num_seqs, int num_queries, int block_size,
+    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
+    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
+    int const* block_tables_ptr, int64_t const block_tables_stride,
+    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
+  int num_query_blocks = div_ceil(num_queries, num_threads);
+
+  if (blockIdx.x < num_query_blocks) {
+    int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+    if (cur_query_id < num_queries) {
+      // Update input_tokens
+      input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+      int seq_len = seq_lens_ptr[cur_query_id];
+      int next_seq_len = seq_len + 1;
+      int next_input_pos = next_seq_len - 1;
+
+      // Update seq_lens
+      seq_lens_ptr[cur_query_id] = next_seq_len;
+      // Update input_positions
+      input_positions_ptr[cur_query_id] = next_input_pos;
+
+      int const* seq_block_tables_ptr =
+          block_tables_ptr + block_tables_stride * cur_query_id;
+
+      int block_index = next_input_pos / block_size;
+      int block_offset = next_input_pos % block_size;
+
+      // Update paged_kv_last_page_len
+      paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
+
+      int slot_num =
+          seq_block_tables_ptr[block_index] * block_size + block_offset;
+      // Update slot_mapping
+      slot_mapping_ptr[cur_query_id] = slot_num;
+      block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
+    }
+  }
+}
+
+__global__ void advance_step_flashinfer_indptr_kernel(
+    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
+    int* block_table_bound_ptr) {
+  int idx = blockIdx.x * num_threads + threadIdx.x;
+
+  // Update paged_kv_indptr
+  if (idx < num_queries) {
+    int sum = 0;
+    for (int i = 0; i <= idx; ++i) {
+      sum += block_table_bound_ptr[i];
+    }
+    paged_kv_indptr_ptr[idx + 1] = sum;
+  }
+}
+
+__global__ void advance_step_flashinfer_indices_kernel(
+    int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
+    int64_t const block_tables_stride, int* paged_kv_indices_ptr,
+    int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
+  int idx = blockIdx.x * num_threads + threadIdx.x;
+  int row = idx / block_tables_stride;
+  int col = idx % block_tables_stride;
+
+  if (row < num_queries && col < block_table_bound_ptr[row]) {
+    paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
+        block_tables_ptr[row * block_tables_stride + col];
+  }
+  // if cudagraph, fill padded seqs with the last valid seq's indptr
+  if (num_queries < row && row <= num_seqs) {
+    paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
+  }
+}
+
+void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
+                            torch::Tensor& input_tokens,       // type: long
+                            torch::Tensor& sampled_token_ids,  // type: long
+                            torch::Tensor& input_positions,    // type: long
+                            torch::Tensor& seq_lens,           // type: int
+                            torch::Tensor& slot_mapping,       // type: long
+                            torch::Tensor& block_tables) {     // type: int
   if (logging) {
-    printf("advance_step:\n");
+    printf("advance_step_flashattn:\n");
     printf("  num_seqs = %d\n", num_seqs);
     printf("  num_queries = %d\n", num_queries);
     printf("  block_size = %d\n", block_size);
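Note (not part of the diff): the three kernels above split the FlashInfer bookkeeping into a per-query update, a prefix sum over per-query page counts, and a scatter of block-table entries into the packed index array. A rough Python sketch of the same arithmetic with made-up toy values; the real work is done by the CUDA kernels shown in the hunk:

# Illustrative re-statement of the per-step math; shapes follow the tensors
# named in the diff, values are invented.
block_size = 16
block_tables = [[0, 5, 8], [1, 6, 7]]   # physical page ids, one row per query
seq_lens = [30, 16]                     # lengths before the newly sampled token

paged_kv_last_page_len = []
slot_mapping = []
block_table_bound = []
for q, seq_len in enumerate(seq_lens):
    next_seq_len = seq_len + 1          # one token appended per decode step
    next_pos = next_seq_len - 1         # position of the new token
    block_index, block_offset = divmod(next_pos, block_size)
    paged_kv_last_page_len.append(block_offset + 1)
    slot_mapping.append(block_tables[q][block_index] * block_size + block_offset)
    # number of KV-cache pages the query now occupies (ceiling division)
    block_table_bound.append(-(-next_seq_len // block_size))

# indptr kernel: prefix sum over per-query page counts
paged_kv_indptr = [0]
for bound in block_table_bound:
    paged_kv_indptr.append(paged_kv_indptr[-1] + bound)

# indices kernel: pack the first `bound` entries of each block-table row
paged_kv_indices = []
for q, bound in enumerate(block_table_bound):
    paged_kv_indices.extend(block_tables[q][:bound])

print(paged_kv_indptr)   # [0, 2, 4]
print(paged_kv_indices)  # [0, 5, 1, 6]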
@@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size,
   int blocks;
   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
 
-  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
-      num_seqs, num_queries, block_size,
+  advance_step_flashattn_kernel<max_threads>
+      <<<blocks, max_threads, 0, stream>>>(
+          num_seqs, num_queries, block_size,
+          reinterpret_cast<long*>(input_tokens.data_ptr()),
+          reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+          reinterpret_cast<long*>(input_positions.data_ptr()),
+          reinterpret_cast<int*>(seq_lens.data_ptr()),
+          reinterpret_cast<long*>(slot_mapping.data_ptr()),
+          reinterpret_cast<int const*>(block_tables.data_ptr()),
+          block_tables.stride(0));
+}
+
+void advance_step_flashinfer(
+    int num_seqs, int num_queries, int block_size,
+    torch::Tensor& input_tokens,       // type: long
+    torch::Tensor& sampled_token_ids,  // type: long
+    torch::Tensor& input_positions,    // type: long
+    torch::Tensor& seq_lens,           // type: int
+    torch::Tensor& slot_mapping,       // type: long
+    torch::Tensor& block_tables,       // type: int
+    torch::Tensor& paged_kv_indices,   // type: int
+    torch::Tensor& paged_kv_indptr,    // type: int
+    torch::Tensor& paged_kv_last_page_len,  // type: int
+    torch::Tensor& block_table_bound) {     // type: int
+  if (logging) {
+    printf("advance_step_flashinfer:\n");
+    printf("  num_seqs = %d\n", num_seqs);
+    printf("  num_queries = %d\n", num_queries);
+    printf("  block_size = %d\n", block_size);
+    printf("  block_tables.stride(0) = %d\n", block_tables.stride(0));
+  }
+  // Verify all tensors
+  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+  // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+  // at::kLong);
+  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+  verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
+  verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
+  verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
+                at::kInt);
+
+  verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
+
+  int dev = sampled_token_ids.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+  int blocks;
+  int threads;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+  cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
+  if (logging) {
+    printf("launching kernel with %d blocks\n", blocks);
+  }
+
+  // TODO(will): support arbitrary block_tables stride
+  if ((blocks * threads) / block_tables.stride(0) < num_queries) {
+    TORCH_CHECK(false,
+                "multi-step: not enough threads to map block_table to"
+                "FlashInfer's paged_kv_indices on GPU. Try reducing the number "
+                "of seqs,",
+                " increasing the block size or take smaller steps.",
+                " num_queries = ", num_queries,
+                " block_tables.stride(0) = ", block_tables.stride(0),
+                " blocks = ", blocks, " max_threads = ", threads);
+  }
+
+  advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries, block_size,
       reinterpret_cast<long*>(input_tokens.data_ptr()),
       reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
       reinterpret_cast<long*>(input_positions.data_ptr()),
       reinterpret_cast<int*>(seq_lens.data_ptr()),
       reinterpret_cast<long*>(slot_mapping.data_ptr()),
       reinterpret_cast<int const*>(block_tables.data_ptr()),
-      block_tables.stride(0));
+      block_tables.stride(0),
+      reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+  advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries,
+      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+  advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries,
+      reinterpret_cast<int const*>(block_tables.data_ptr()),
+      block_tables.stride(0),
+      reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
+      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
 }
 
 }  // namespace prepare_inputs
 
-void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
-                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
-                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
-  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
-                               sampled_token_ids, input_positions, seq_lens,
-                               slot_mapping, block_tables);
+void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
+                            int64_t block_size, torch::Tensor& input_tokens,
+                            torch::Tensor& sampled_token_ids,
+                            torch::Tensor& input_positions,
+                            torch::Tensor& seq_lens,
+                            torch::Tensor& slot_mapping,
+                            torch::Tensor& block_tables) {
+  prepare_inputs::advance_step_flashattn(
+      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+      input_positions, seq_lens, slot_mapping, block_tables);
+}
+
+void advance_step_flashinfer(
+    int64_t num_seqs, int64_t num_queries, int64_t block_size,
+    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+    torch::Tensor& input_positions, torch::Tensor& seq_lens,
+    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
+    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
+    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
+  prepare_inputs::advance_step_flashinfer(
+      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
+      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
 }
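Note on the TORCH_CHECK guard above: the indices kernel assigns one thread per (row, col) cell of the padded block_tables view, so a single launch of blocks * threads threads must cover num_queries * block_tables.stride(0) cells. A small, hedged arithmetic sketch with assumed device numbers (the real values come from cudaDeviceGetAttribute):

# Assumed, illustrative numbers; not measured from any particular GPU.
blocks = 108              # multiprocessor count
max_threads = 1024        # max threads per block
block_tables_stride = 64  # padded width of each block_tables row
num_queries = 256

# One thread handles one (row, col) cell, so the guard rejects the launch
# when the grid cannot cover every row of the block table.
can_cover = (blocks * max_threads) // block_tables_stride >= num_queries
print(can_cover)  # True here; False would trigger the TORCH_CHECK error above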
@@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // prepare_inputs advance_step
   ops.def(
-      "advance_step(int num_seqs, int num_queries, int block_size, "
+      "advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
       "Tensor! input_tokens, Tensor sampled_token_ids, "
       "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
      "Tensor block_tables) -> ()");
-  ops.impl("advance_step", torch::kCUDA, &advance_step);
+  ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
+
+  ops.def(
+      "advance_step_flashinfer("
+      "    int num_seqs, int num_queries, int block_size,"
+      "    Tensor! input_tokens, Tensor sampled_token_ids,"
+      "    Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
+      "    Tensor block_tables, Tensor! paged_kv_indices,"
+      "    Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
+      "    Tensor! block_table_bounds"
+      ") -> ()");
+  ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
@@ -1,9 +1,10 @@
 # Test the AsyncLLMEngine with multi-step-decoding
 
 from typing import List, Optional
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
 from ..models.utils import check_logprobs_close
 from ..utils import (completions_with_server_args, get_client_text_generations,
                      get_client_text_logprob_generations)
@@ -33,8 +34,9 @@ DEFAULT_SERVER_ARGS: List[str] = [
 @pytest.mark.parametrize("eager_mode", [False, True])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
-@pytest.mark.parametrize("num_logprobs", [None, 5])
-@pytest.mark.parametrize("is_async", [False, True])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("is_async", [True])
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 @pytest.mark.asyncio
 async def test_multi_step(
     example_prompts,
@@ -46,6 +48,8 @@ async def test_multi_step(
     num_prompts: int,
     is_async: bool,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling in an OpenAI-protocol
     client/server environment.
@@ -71,6 +75,8 @@ async def test_multi_step(
         completions endpoint; `None` -> no logprobs
     """
 
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     prompts = example_prompts
     if len(prompts) < num_prompts:
         prompts = prompts * ((num_prompts // len(prompts)) + 1)
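The test now runs once per attention backend and pins the backend through the environment. A minimal sketch of what the imported helper presumably does (the real implementation lives in tests/kernels/utils and may differ; the variable name matches the VLLM_ATTENTION_BACKEND mentioned in the multi-step runner error later in this commit):

# Hedged sketch, assuming the helper simply pins vLLM's backend selection
# through an environment variable for the duration of the test.
def override_backend_env_variable(monkeypatch, backend_name: str) -> None:
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)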
@@ -161,16 +161,36 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
-def advance_step(num_seqs: int, num_queries: int, block_size: int,
-                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
-                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
-                 slot_mapping: torch.Tensor,
-                 block_tables: torch.Tensor) -> None:
+def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
+                           input_tokens: torch.Tensor,
+                           sampled_token_ids: torch.Tensor,
+                           input_positions: torch.Tensor,
+                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
+                           block_tables: torch.Tensor) -> None:
     """Advance a step on GPU for existing inputs for a multi-step runner"""
-    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
-                                     input_tokens, sampled_token_ids,
-                                     input_positions, seq_lens, slot_mapping,
-                                     block_tables)
+    return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
+                                               block_size, input_tokens,
+                                               sampled_token_ids,
+                                               input_positions, seq_lens,
+                                               slot_mapping, block_tables)
+
+
+def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
+                            input_tokens: torch.Tensor,
+                            sampled_token_ids: torch.Tensor,
+                            input_positions: torch.Tensor,
+                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
+                            block_tables: torch.Tensor,
+                            paged_kv_indices: torch.Tensor,
+                            paged_kv_indptr: torch.Tensor,
+                            paged_kv_last_page_len: torch.Tensor,
+                            block_table_bound: torch.Tensor) -> None:
+
+    return torch.ops._C.advance_step_flashinfer(
+        num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+        input_positions, seq_lens, slot_mapping, block_tables,
+        paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
+        block_table_bound)
 
 
 # quantization ops
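A hedged sketch of the argument shapes and dtypes the new wrapper expects, inferred from the "// type:" comments and verify_tensor calls in the CUDA hunk above. CPU tensors are used here only to illustrate shape and dtype; the real call needs CUDA tensors and a compiled _C extension:

import torch

num_seqs, num_queries = 8, 8
max_blocks_per_seq = 4  # assumed padded width of the block_tables rows

input_tokens = torch.zeros(num_seqs, dtype=torch.long)
sampled_token_ids = torch.zeros(num_queries, 1, dtype=torch.long)
input_positions = torch.zeros(num_seqs, dtype=torch.long)
seq_lens = torch.zeros(num_seqs, dtype=torch.int)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long)
block_tables = torch.zeros(num_seqs, max_blocks_per_seq, dtype=torch.int)

paged_kv_indices = torch.zeros(num_seqs * max_blocks_per_seq, dtype=torch.int)
paged_kv_indptr = torch.zeros(num_seqs + 1, dtype=torch.int)  # verified as num_seqs + 1
paged_kv_last_page_len = torch.zeros(num_seqs, dtype=torch.int)
block_table_bound = torch.zeros(num_seqs, dtype=torch.int)

# On a CUDA build each tensor would live on the GPU and the call would be
# ops.advance_step_flashinfer(num_seqs=num_seqs, num_queries=num_queries, ...)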
@@ -83,7 +83,9 @@ class AttentionBackend(ABC):
     ) -> None:
         raise NotImplementedError
 
-    def advance_step(self, num_seqs: int, num_queries: int):
+    def advance_step(self, model_input: "ModelRunnerInputBase",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int, num_seqs: int, num_queries: int) -> None:
         raise NotImplementedError
 
 
@@ -380,15 +380,15 @@ class FlashAttentionMetadata(AttentionMetadata):
             self.seq_lens[i] += 1
         self.max_decode_seq_len = max(self.seq_lens)
 
-        ops.advance_step(num_seqs=num_seqs,
-                         num_queries=num_queries,
-                         block_size=block_size,
-                         input_tokens=model_input.input_tokens,
-                         sampled_token_ids=sampled_token_ids,
-                         input_positions=model_input.input_positions,
-                         seq_lens=self.seq_lens_tensor,
-                         slot_mapping=self.slot_mapping,
-                         block_tables=self.block_tables)
+        ops.advance_step_flashattn(num_seqs=num_seqs,
+                                   num_queries=num_queries,
+                                   block_size=block_size,
+                                   input_tokens=model_input.input_tokens,
+                                   sampled_token_ids=sampled_token_ids,
+                                   input_positions=model_input.input_positions,
+                                   seq_lens=self.seq_lens_tensor,
+                                   slot_mapping=self.slot_mapping,
+                                   block_tables=self.block_tables)
 
 
 class FlashAttentionMetadataBuilder(
@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
 
 if TYPE_CHECKING:
-    from vllm.worker.model_runner import ModelInputForGPUBuilder
+    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
+                                          ModelInputForGPUWithSamplingMetadata)
 
 
 class FlashInferBackend(AttentionBackend):
@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata):
     query_start_loc: Optional[torch.Tensor] = None
     block_tables: Optional[torch.Tensor] = None
 
+    # used for GPU in-place advance_step
+    seq_lens_tensor: Optional[torch.Tensor] = None
+    block_table_bound: Optional[torch.Tensor] = None
+
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
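Following the example in the comment above (request 1 uses pages [0, 5, 8], request 2 uses [1, 6, 7]), the CSR-style layout consumed here would look like this; a small illustration, not code from the diff:

# Pages of all requests are concatenated; indptr marks where each request's
# pages begin and end, so request i owns indices[indptr[i]:indptr[i + 1]].
paged_kv_indices = [0, 5, 8, 1, 6, 7]
paged_kv_indptr = [0, 3, 6]
assert paged_kv_indices[paged_kv_indptr[0]:paged_kv_indptr[1]] == [0, 5, 8]
assert paged_kv_indices[paged_kv_indptr[1]:paged_kv_indptr[2]] == [1, 6, 7]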
@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata):
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
+            assert self.block_table_bound is not None
+            assert self.seq_lens_tensor is not None
             batch_size = self.query_start_loc.shape[0] - 1
             assert batch_size >= 0
             # We will use flash attention for profiling to
@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata):
             self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
             self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                 self.device)
+            self.block_table_bound = self.block_table_bound.to(self.device)
+            self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
             self.paged_kv_indices = self.paged_kv_indices.to(self.device)
             self.prefill_wrapper.end_forward()
             self.prefill_wrapper.begin_forward(
@@ -335,14 +344,18 @@ class FlashInferMetadata(AttentionMetadata):
                 self.num_qo_heads, self.num_kv_heads, self.head_dim,
                 self.page_size)
         else:
-            if not self.use_cuda_graph:
-                assert self.paged_kv_indices is not None
-                assert self.paged_kv_indptr is not None
-                assert self.paged_kv_last_page_len is not None
-                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
-                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
-                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
-                    self.device)
+            assert self.paged_kv_indices is not None
+            assert self.paged_kv_indptr is not None
+            assert self.paged_kv_last_page_len is not None
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                self.device)
+            # handle model warmup path
+            if self.block_table_bound is not None:
+                self.block_table_bound = self.block_table_bound.to(self.device)
+            if self.seq_lens_tensor is not None:
+                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
 
             assert self.decode_wrapper is not None
             self.decode_wrapper.end_forward()
@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata):
 
         return self
 
+    def advance_step(
+        self,
+        model_input: "ModelInputForGPUWithSamplingMetadata",
+        sampled_token_ids: Optional[torch.Tensor],
+        block_size: int,
+        num_seqs: int,
+        num_queries: int,
+    ):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+
+        assert num_seqs > 0
+        assert num_queries > 0
+        assert model_input.attn_metadata is not None
+        assert sampled_token_ids is not None
+
+        # When using cudagraph, the num_seqs is padded to the next captured
+        # batch sized, but num_queries tracks the actual number of requests in
+        # the batch. For --enforce-eager mode, num_seqs == num_queries
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
+
+        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
+
+        # Update GPU tensors
+        ops.advance_step_flashinfer(
+            num_seqs=num_seqs,
+            num_queries=num_queries,
+            block_size=block_size,
+            input_tokens=model_input.input_tokens,
+            sampled_token_ids=model_input.input_tokens,
+            input_positions=model_input.input_positions,
+            seq_lens=self.seq_lens_tensor,
+            slot_mapping=self.slot_mapping,
+            block_tables=self.block_tables,
+            paged_kv_indices=self.paged_kv_indices,
+            paged_kv_indptr=self.paged_kv_indptr,
+            paged_kv_last_page_len=self.paged_kv_last_page_len,
+            block_table_bound=self.block_table_bound)
+
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
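One subtlety in the method above: under CUDA graphs the decode batch is padded, so num_seqs can exceed num_queries, and the freshly sampled tokens are copied into input_tokens[:num_queries] before the padded-width kernel runs (which is why the op receives input_tokens again as sampled_token_ids). A toy illustration of the padding relationship, with assumed capture sizes:

# Assumed capture sizes for illustration; vLLM pads the decode batch to the
# next captured size when CUDA graphs are enabled.
captured_batch_sizes = [1, 2, 4, 8, 16, 32]

def padded_num_seqs(num_queries: int) -> int:
    return next(size for size in captured_batch_sizes if size >= num_queries)

num_queries = 6                      # real requests in the batch
num_seqs = padded_num_seqs(num_queries)
print(num_seqs)                      # 8, so num_seqs != num_queries here
assert num_seqs >= num_queries       # the cudagraph branch in advance_step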
@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.paged_kv_indptr: List[int] = [0]
         # paged_kv_last_page_len is the length of the last page of each request
         self.paged_kv_last_page_len: List[int] = []
+        self.total_blocks = 0
         self.is_profile_run: bool = False
 
     def _add_seq_group(
@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # block_table_bound is 1 with 1 valid block.
         # If seq_len = 15, block_size = 16,
         # block_table_bound is 0 + 1 with 1 valid block.
+        self.total_blocks += len(block_table)
         block_table_bound = seq_len // self.block_size + 1 \
             if seq_len % self.block_size != 0 \
             else seq_len // self.block_size
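The block_table_bound expression above is simply a ceiling division of seq_len by block_size, matching the div_ceil used on the CUDA side. A quick check with the values from the comments:

def block_table_bound(seq_len: int, block_size: int) -> int:
    # equivalent to the conditional expression in the diff and to div_ceil
    return -(-seq_len // block_size)

assert block_table_bound(16, 16) == 1  # full page, 1 valid block
assert block_table_bound(15, 16) == 1  # 0 + 1, still 1 valid block
assert block_table_bound(17, 16) == 2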
@@ -583,6 +639,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                       out=query_start_loc[1:])
 
         if len(self.paged_kv_indptr) > 0:
+            # extend to the maximum number of blocks as returned by the
+            # scheduler
+            self.paged_kv_indices.extend(
+                [0] * (self.total_blocks - len(self.paged_kv_indices)))
             paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                    device="cpu",
                                                    dtype=torch.int)
@@ -591,10 +651,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                   dtype=torch.int)
             paged_kv_last_page_len_tensor = torch.tensor(
                 self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
+            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
+                                                   1,
+                                                   device="cpu",
+                                                   dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_len_tensor = None
+            block_table_bound_tensor = None
 
         if self.runner.kv_cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -613,6 +678,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indptr=paged_kv_indptr_tensor,
             paged_kv_indices=paged_kv_indices_tensor,
             paged_kv_last_page_len=paged_kv_last_page_len_tensor,
+            block_table_bound=block_table_bound_tensor,
+            seq_lens_tensor=seq_lens_tensor,
             num_qo_heads=self.runner.model_config.get_num_attention_heads(
                 self.runner.parallel_config),
             num_kv_heads=self.runner.model_config.get_num_kv_heads(
@@ -4,13 +4,6 @@ from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                     Union)
 
-try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
-
 import torch
 
 from vllm.distributed import get_pp_group
@@ -36,6 +29,8 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
+MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
+
 
 def seq_output_builder():
     return SequenceOutput(
@@ -489,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
 
     def _advance_step(self, model_input: StatefulModelInput,
                       out: SamplerOutput) -> StatefulModelInput:
-        frozen_model_input = model_input.frozen_model_input
-        assert frozen_model_input is not None
-        assert frozen_model_input.attn_metadata is not None
+        if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
+            raise ValueError(
+                f"Multi-step not supported for attention backend: "
+                f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
+                f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.")
 
+        sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
         num_seqs = model_input.num_seqs
         num_queries = model_input.num_queries
-        assert num_seqs > 0
-        assert num_queries > 0
-        assert num_seqs >= num_queries
+        frozen_model_input = model_input.frozen_model_input
+        assert frozen_model_input is not None
 
         attn_metadata = frozen_model_input.attn_metadata
-        assert isinstance(attn_metadata, FlashAttentionMetadata)
+        assert attn_metadata is not None
+
         attn_metadata.advance_step(
             frozen_model_input,
-            model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
-            num_seqs, num_queries)
-
-        if frozen_model_input.seq_lens is not None:
-            for i in range(num_queries):
-                frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
+            sampled_token_ids,
+            self.block_size,
+            num_seqs,
+            num_queries,
+        )
 
         return model_input
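With the backend gate above in place, a multi-step run that wants FlashInfer has to select that backend explicitly. A hedged usage sketch; the environment variable and the FLASHINFER/FLASH_ATTN values are the ones exercised by the test earlier in this commit, while the engine construction line is only an assumed example:

import os

# Pin the attention backend before constructing the engine; the multi-step
# runner raises ValueError for backends outside MULTI_STEP_ATTENTION_BACKENDS.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # or "FLASH_ATTN"

# from vllm import LLM                    # assumed entry point, not part of this diff
# llm = LLM(model="facebook/opt-125m")    # hypothetical model choice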