mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:34:27 +08:00
[multi-step] add flashinfer backend (#7928)
This commit is contained in:
parent
f2e263b801
commit
a6c0f3658d
19
csrc/ops.h
19
csrc/ops.h
@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables);
|
||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||
int64_t block_size, torch::Tensor& input_tokens,
|
||||
torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions,
|
||||
torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping,
|
||||
torch::Tensor& block_tables);
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
|
||||
@ -12,13 +12,11 @@ namespace prepare_inputs {
|
||||
|
||||
//
|
||||
template <int const num_threads>
|
||||
__global__ void advance_step_kernel(int num_seqs, int num_queries,
|
||||
int block_size, long* input_tokens_ptr,
|
||||
long const* sampled_token_ids_ptr,
|
||||
long* input_positions_ptr,
|
||||
int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||
int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride) {
|
||||
__global__ void advance_step_flashattn_kernel(
|
||||
int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
|
||||
long const* sampled_token_ids_ptr, long* input_positions_ptr,
|
||||
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride) {
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x >= num_query_blocks) {
|
||||
@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
|
||||
}
|
||||
}
|
||||
|
||||
void advance_step(int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
torch::Tensor& seq_lens, // type: int
|
||||
torch::Tensor& slot_mapping, // type: long
|
||||
torch::Tensor& block_tables) { // type: int
|
||||
__global__ void advance_step_flashinfer_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int block_size,
|
||||
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
|
||||
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||
int const* block_tables_ptr, int64_t const block_tables_stride,
|
||||
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x < num_query_blocks) {
|
||||
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
||||
|
||||
if (cur_query_id < num_queries) {
|
||||
// Update input_tokens
|
||||
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
||||
|
||||
int seq_len = seq_lens_ptr[cur_query_id];
|
||||
int next_seq_len = seq_len + 1;
|
||||
int next_input_pos = next_seq_len - 1;
|
||||
|
||||
// Update seq_lens
|
||||
seq_lens_ptr[cur_query_id] = next_seq_len;
|
||||
// Update input_positions
|
||||
input_positions_ptr[cur_query_id] = next_input_pos;
|
||||
|
||||
int const* seq_block_tables_ptr =
|
||||
block_tables_ptr + block_tables_stride * cur_query_id;
|
||||
|
||||
int block_index = next_input_pos / block_size;
|
||||
int block_offset = next_input_pos % block_size;
|
||||
|
||||
// Update paged_kv_last_page_len
|
||||
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
|
||||
|
||||
int slot_num =
|
||||
seq_block_tables_ptr[block_index] * block_size + block_offset;
|
||||
// Update slot_mapping
|
||||
slot_mapping_ptr[cur_query_id] = slot_num;
|
||||
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void advance_step_flashinfer_indptr_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
|
||||
int* block_table_bound_ptr) {
|
||||
int idx = blockIdx.x * num_threads + threadIdx.x;
|
||||
|
||||
// Update paged_kv_indptr
|
||||
if (idx < num_queries) {
|
||||
int sum = 0;
|
||||
for (int i = 0; i <= idx; ++i) {
|
||||
sum += block_table_bound_ptr[i];
|
||||
}
|
||||
paged_kv_indptr_ptr[idx + 1] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void advance_step_flashinfer_indices_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride, int* paged_kv_indices_ptr,
|
||||
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
|
||||
int idx = blockIdx.x * num_threads + threadIdx.x;
|
||||
int row = idx / block_tables_stride;
|
||||
int col = idx % block_tables_stride;
|
||||
|
||||
if (row < num_queries && col < block_table_bound_ptr[row]) {
|
||||
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
|
||||
block_tables_ptr[row * block_tables_stride + col];
|
||||
}
|
||||
// if cudagraph, fill padded seqs with the last valid seq's indptr
|
||||
if (num_queries < row && row <= num_seqs) {
|
||||
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
|
||||
}
|
||||
}
|
||||
|
||||
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
torch::Tensor& seq_lens, // type: int
|
||||
torch::Tensor& slot_mapping, // type: long
|
||||
torch::Tensor& block_tables) { // type: int
|
||||
|
||||
if (logging) {
|
||||
printf("advance_step:\n");
|
||||
printf("advance_step_flashattn:\n");
|
||||
printf(" num_seqs = %d\n", num_seqs);
|
||||
printf(" num_queries = %d\n", num_queries);
|
||||
printf(" block_size = %d\n", block_size);
|
||||
@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size,
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
|
||||
num_seqs, num_queries, block_size,
|
||||
advance_step_flashattn_kernel<max_threads>
|
||||
<<<blocks, max_threads, 0, stream>>>(
|
||||
num_seqs, num_queries, block_size,
|
||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0));
|
||||
}
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
torch::Tensor& seq_lens, // type: int
|
||||
torch::Tensor& slot_mapping, // type: long
|
||||
torch::Tensor& block_tables, // type: int
|
||||
torch::Tensor& paged_kv_indices, // type: int
|
||||
torch::Tensor& paged_kv_indptr, // type: int
|
||||
torch::Tensor& paged_kv_last_page_len, // type: int
|
||||
torch::Tensor& block_table_bound) { // type: int
|
||||
|
||||
if (logging) {
|
||||
printf("advance_step_flashinfer:\n");
|
||||
printf(" num_seqs = %d\n", num_seqs);
|
||||
printf(" num_queries = %d\n", num_queries);
|
||||
printf(" block_size = %d\n", block_size);
|
||||
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0));
|
||||
}
|
||||
// Verify all tensors
|
||||
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
||||
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
||||
// at::kLong);
|
||||
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
||||
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
||||
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
||||
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
||||
|
||||
verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
|
||||
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
|
||||
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
|
||||
at::kInt);
|
||||
|
||||
verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
|
||||
|
||||
int dev = sampled_token_ids.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
|
||||
int blocks;
|
||||
int threads;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
|
||||
if (logging) {
|
||||
printf("launching kernel with %d blocks\n", blocks);
|
||||
}
|
||||
|
||||
// TODO(will): support arbitrary block_tables stride
|
||||
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
|
||||
TORCH_CHECK(false,
|
||||
"multi-step: not enough threads to map block_table to"
|
||||
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
|
||||
"of seqs,",
|
||||
" increasing the block size or take smaller steps.",
|
||||
" num_queries = ", num_queries,
|
||||
" block_tables.stride(0) = ", block_tables.stride(0),
|
||||
" blocks = ", blocks, " max_threads = ", threads);
|
||||
}
|
||||
|
||||
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries, block_size,
|
||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0));
|
||||
block_tables.stride(0),
|
||||
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
|
||||
advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries,
|
||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
|
||||
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries,
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0),
|
||||
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
|
||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
}
|
||||
|
||||
} // namespace prepare_inputs
|
||||
|
||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
|
||||
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
|
||||
sampled_token_ids, input_positions, seq_lens,
|
||||
slot_mapping, block_tables);
|
||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||
int64_t block_size, torch::Tensor& input_tokens,
|
||||
torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions,
|
||||
torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping,
|
||||
torch::Tensor& block_tables) {
|
||||
prepare_inputs::advance_step_flashattn(
|
||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping, block_tables);
|
||||
}
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
|
||||
prepare_inputs::advance_step_flashinfer(
|
||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
|
||||
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
|
||||
}
|
||||
@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// prepare_inputs advance_step
|
||||
ops.def(
|
||||
"advance_step(int num_seqs, int num_queries, int block_size, "
|
||||
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
|
||||
"Tensor! input_tokens, Tensor sampled_token_ids, "
|
||||
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
|
||||
"Tensor block_tables) -> ()");
|
||||
ops.impl("advance_step", torch::kCUDA, &advance_step);
|
||||
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
|
||||
|
||||
ops.def(
|
||||
"advance_step_flashinfer("
|
||||
" int num_seqs, int num_queries, int block_size,"
|
||||
" Tensor! input_tokens, Tensor sampled_token_ids,"
|
||||
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
|
||||
" Tensor block_tables, Tensor! paged_kv_indices,"
|
||||
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
|
||||
" Tensor! block_table_bounds"
|
||||
") -> ()");
|
||||
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
||||
|
||||
// Layernorm
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
# Test the AsyncLLMEngine with multi-step-decoding
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
|
||||
from ..models.utils import check_logprobs_close
|
||||
from ..utils import (completions_with_server_args, get_client_text_generations,
|
||||
get_client_text_logprob_generations)
|
||||
@ -33,8 +34,9 @@ DEFAULT_SERVER_ARGS: List[str] = [
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs", [None, 5])
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("is_async", [True])
|
||||
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step(
|
||||
example_prompts,
|
||||
@ -46,6 +48,8 @@ async def test_multi_step(
|
||||
num_prompts: int,
|
||||
is_async: bool,
|
||||
num_logprobs: Optional[int],
|
||||
attention_backend: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
|
||||
client/server environment.
|
||||
@ -71,6 +75,8 @@ async def test_multi_step(
|
||||
completions endpoint; `None` -> no logprobs
|
||||
"""
|
||||
|
||||
override_backend_env_variable(monkeypatch, attention_backend)
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
|
||||
@ -161,16 +161,36 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
|
||||
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
|
||||
|
||||
|
||||
def advance_step(num_seqs: int, num_queries: int, block_size: int,
|
||||
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
|
||||
input_positions: torch.Tensor, seq_lens: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
block_tables: torch.Tensor) -> None:
|
||||
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
|
||||
input_tokens: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
input_positions: torch.Tensor,
|
||||
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
block_tables: torch.Tensor) -> None:
|
||||
"""Advance a step on GPU for existing inputs for a multi-step runner"""
|
||||
return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
|
||||
input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping,
|
||||
block_tables)
|
||||
return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
|
||||
block_size, input_tokens,
|
||||
sampled_token_ids,
|
||||
input_positions, seq_lens,
|
||||
slot_mapping, block_tables)
|
||||
|
||||
|
||||
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
|
||||
input_tokens: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
input_positions: torch.Tensor,
|
||||
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
paged_kv_indices: torch.Tensor,
|
||||
paged_kv_indptr: torch.Tensor,
|
||||
paged_kv_last_page_len: torch.Tensor,
|
||||
block_table_bound: torch.Tensor) -> None:
|
||||
|
||||
return torch.ops._C.advance_step_flashinfer(
|
||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping, block_tables,
|
||||
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
|
||||
block_table_bound)
|
||||
|
||||
|
||||
# quantization ops
|
||||
|
||||
@ -83,7 +83,9 @@ class AttentionBackend(ABC):
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def advance_step(self, num_seqs: int, num_queries: int):
|
||||
def advance_step(self, model_input: "ModelRunnerInputBase",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int, num_seqs: int, num_queries: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@ -380,15 +380,15 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
self.seq_lens[i] += 1
|
||||
self.max_decode_seq_len = max(self.seq_lens)
|
||||
|
||||
ops.advance_step(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder(
|
||||
|
||||
@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
||||
make_tensor_with_pad)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# used for GPU in-place advance_step
|
||||
seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
block_table_bound: Optional[torch.Tensor] = None
|
||||
|
||||
# An example for paged_kv_indices, paged_kv_indptr:
|
||||
# request 1, page indices [0, 5, 8]
|
||||
# request 2, page indices [1, 6, 7]
|
||||
@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
assert self.paged_kv_indices is not None
|
||||
assert self.paged_kv_indptr is not None
|
||||
assert self.paged_kv_last_page_len is not None
|
||||
assert self.block_table_bound is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
batch_size = self.query_start_loc.shape[0] - 1
|
||||
assert batch_size >= 0
|
||||
# We will use flash attention for profiling to
|
||||
@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
|
||||
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
||||
self.device)
|
||||
self.block_table_bound = self.block_table_bound.to(self.device)
|
||||
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
|
||||
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
|
||||
self.prefill_wrapper.end_forward()
|
||||
self.prefill_wrapper.begin_forward(
|
||||
@ -335,14 +344,18 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
self.num_qo_heads, self.num_kv_heads, self.head_dim,
|
||||
self.page_size)
|
||||
else:
|
||||
if not self.use_cuda_graph:
|
||||
assert self.paged_kv_indices is not None
|
||||
assert self.paged_kv_indptr is not None
|
||||
assert self.paged_kv_last_page_len is not None
|
||||
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
|
||||
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
|
||||
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
||||
self.device)
|
||||
assert self.paged_kv_indices is not None
|
||||
assert self.paged_kv_indptr is not None
|
||||
assert self.paged_kv_last_page_len is not None
|
||||
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
|
||||
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
|
||||
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
||||
self.device)
|
||||
# handle model warmup path
|
||||
if self.block_table_bound is not None:
|
||||
self.block_table_bound = self.block_table_bound.to(self.device)
|
||||
if self.seq_lens_tensor is not None:
|
||||
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
|
||||
|
||||
assert self.decode_wrapper is not None
|
||||
self.decode_wrapper.end_forward()
|
||||
@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
|
||||
return self
|
||||
|
||||
def advance_step(
|
||||
self,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
|
||||
assert num_seqs > 0
|
||||
assert num_queries > 0
|
||||
assert model_input.attn_metadata is not None
|
||||
assert sampled_token_ids is not None
|
||||
|
||||
# When using cudagraph, the num_seqs is padded to the next captured
|
||||
# batch sized, but num_queries tracks the actual number of requests in
|
||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
|
||||
|
||||
# Update GPU tensors
|
||||
ops.advance_step_flashinfer(
|
||||
num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=model_input.input_tokens,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables,
|
||||
paged_kv_indices=self.paged_kv_indices,
|
||||
paged_kv_indptr=self.paged_kv_indptr,
|
||||
paged_kv_last_page_len=self.paged_kv_last_page_len,
|
||||
block_table_bound=self.block_table_bound)
|
||||
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.paged_kv_indptr: List[int] = [0]
|
||||
# paged_kv_last_page_len is the length of the last page of each request
|
||||
self.paged_kv_last_page_len: List[int] = []
|
||||
|
||||
self.total_blocks = 0
|
||||
self.is_profile_run: bool = False
|
||||
|
||||
def _add_seq_group(
|
||||
@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# block_table_bound is 1 with 1 valid block.
|
||||
# If seq_len = 15, block_size = 16,
|
||||
# block_table_bound is 0 + 1 with 1 valid block.
|
||||
self.total_blocks += len(block_table)
|
||||
block_table_bound = seq_len // self.block_size + 1 \
|
||||
if seq_len % self.block_size != 0 \
|
||||
else seq_len // self.block_size
|
||||
@ -583,6 +639,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
out=query_start_loc[1:])
|
||||
|
||||
if len(self.paged_kv_indptr) > 0:
|
||||
# extend to the maximum number of blocks as returned by the
|
||||
# scheduler
|
||||
self.paged_kv_indices.extend(
|
||||
[0] * (self.total_blocks - len(self.paged_kv_indices)))
|
||||
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
|
||||
device="cpu",
|
||||
dtype=torch.int)
|
||||
@ -591,10 +651,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
dtype=torch.int)
|
||||
paged_kv_last_page_len_tensor = torch.tensor(
|
||||
self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
|
||||
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
|
||||
1,
|
||||
device="cpu",
|
||||
dtype=torch.int)
|
||||
else:
|
||||
paged_kv_indices_tensor = None
|
||||
paged_kv_indptr_tensor = None
|
||||
paged_kv_last_page_len_tensor = None
|
||||
block_table_bound_tensor = None
|
||||
|
||||
if self.runner.kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
@ -613,6 +678,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
paged_kv_indptr=paged_kv_indptr_tensor,
|
||||
paged_kv_indices=paged_kv_indices_tensor,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
|
||||
block_table_bound=block_table_bound_tensor,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
num_qo_heads=self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config),
|
||||
num_kv_heads=self.runner.model_config.get_num_kv_heads(
|
||||
|
||||
@ -4,13 +4,6 @@ from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
try:
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
except ModuleNotFoundError:
|
||||
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
|
||||
from vllm.attention.backends.rocm_flash_attn import (
|
||||
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_pp_group
|
||||
@ -36,6 +29,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
|
||||
|
||||
|
||||
def seq_output_builder():
|
||||
return SequenceOutput(
|
||||
@ -489,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
|
||||
def _advance_step(self, model_input: StatefulModelInput,
|
||||
out: SamplerOutput) -> StatefulModelInput:
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
|
||||
raise ValueError(
|
||||
f"Multi-step not supported for attention backend: "
|
||||
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
|
||||
f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.")
|
||||
|
||||
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
|
||||
num_seqs = model_input.num_seqs
|
||||
num_queries = model_input.num_queries
|
||||
assert num_seqs > 0
|
||||
assert num_queries > 0
|
||||
assert num_seqs >= num_queries
|
||||
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
attn_metadata = frozen_model_input.attn_metadata
|
||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
||||
assert attn_metadata is not None
|
||||
|
||||
attn_metadata.advance_step(
|
||||
frozen_model_input,
|
||||
model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
|
||||
num_seqs, num_queries)
|
||||
|
||||
if frozen_model_input.seq_lens is not None:
|
||||
for i in range(num_queries):
|
||||
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
|
||||
sampled_token_ids,
|
||||
self.block_size,
|
||||
num_seqs,
|
||||
num_queries,
|
||||
)
|
||||
|
||||
return model_input
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user