[multi-step] add flashinfer backend (#7928)

This commit is contained in:
William Lin 2024-09-12 11:16:22 -07:00 committed by GitHub
parent f2e263b801
commit a6c0f3658d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 371 additions and 84 deletions

View File

@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
int64_t block_size, torch::Tensor& input_tokens,
torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions,
torch::Tensor& seq_lens,
torch::Tensor& slot_mapping,
torch::Tensor& block_tables);
void advance_step_flashinfer(
int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables);
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,

View File

@ -12,12 +12,10 @@ namespace prepare_inputs {
//
template <int const num_threads>
__global__ void advance_step_kernel(int num_seqs, int num_queries,
int block_size, long* input_tokens_ptr,
long const* sampled_token_ids_ptr,
long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr,
__global__ void advance_step_flashattn_kernel(
int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) {
int num_query_blocks = div_ceil(num_queries, num_threads);
@ -79,7 +77,82 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
}
}
void advance_step(int num_seqs, int num_queries, int block_size,
__global__ void advance_step_flashinfer_kernel(
int num_threads, int num_seqs, int num_queries, int block_size,
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr, int64_t const block_tables_stride,
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x < num_query_blocks) {
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
if (cur_query_id < num_queries) {
// Update input_tokens
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
int seq_len = seq_lens_ptr[cur_query_id];
int next_seq_len = seq_len + 1;
int next_input_pos = next_seq_len - 1;
// Update seq_lens
seq_lens_ptr[cur_query_id] = next_seq_len;
// Update input_positions
input_positions_ptr[cur_query_id] = next_input_pos;
int const* seq_block_tables_ptr =
block_tables_ptr + block_tables_stride * cur_query_id;
int block_index = next_input_pos / block_size;
int block_offset = next_input_pos % block_size;
// Update paged_kv_last_page_len
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
int slot_num =
seq_block_tables_ptr[block_index] * block_size + block_offset;
// Update slot_mapping
slot_mapping_ptr[cur_query_id] = slot_num;
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
}
}
}
__global__ void advance_step_flashinfer_indptr_kernel(
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;
// Update paged_kv_indptr
if (idx < num_queries) {
int sum = 0;
for (int i = 0; i <= idx; ++i) {
sum += block_table_bound_ptr[i];
}
paged_kv_indptr_ptr[idx + 1] = sum;
}
}
__global__ void advance_step_flashinfer_indices_kernel(
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const block_tables_stride, int* paged_kv_indices_ptr,
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;
int row = idx / block_tables_stride;
int col = idx % block_tables_stride;
if (row < num_queries && col < block_table_bound_ptr[row]) {
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
block_tables_ptr[row * block_tables_stride + col];
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if (num_queries < row && row <= num_seqs) {
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
}
}
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
@ -88,7 +161,7 @@ void advance_step(int num_seqs, int num_queries, int block_size,
torch::Tensor& block_tables) { // type: int
if (logging) {
printf("advance_step:\n");
printf("advance_step_flashattn:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
@ -108,7 +181,8 @@ void advance_step(int num_seqs, int num_queries, int block_size,
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
advance_step_flashattn_kernel<max_threads>
<<<blocks, max_threads, 0, stream>>>(
num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
@ -119,13 +193,114 @@ void advance_step(int num_seqs, int num_queries, int block_size,
block_tables.stride(0));
}
void advance_step_flashinfer(
int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables, // type: int
torch::Tensor& paged_kv_indices, // type: int
torch::Tensor& paged_kv_indptr, // type: int
torch::Tensor& paged_kv_last_page_len, // type: int
torch::Tensor& block_table_bound) { // type: int
if (logging) {
printf("advance_step_flashinfer:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0));
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
// at::kLong);
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
at::kInt);
verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
int dev = sampled_token_ids.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
int threads;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
if (logging) {
printf("launching kernel with %d blocks\n", blocks);
}
// TODO(will): support arbitrary block_tables stride
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
TORCH_CHECK(false,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,",
" increasing the block size or take smaller steps.",
" num_queries = ", num_queries,
" block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
}
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
reinterpret_cast<long*>(input_positions.data_ptr()),
reinterpret_cast<int*>(seq_lens.data_ptr()),
reinterpret_cast<long*>(slot_mapping.data_ptr()),
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));
advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries,
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries,
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));
}
} // namespace prepare_inputs
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
int64_t block_size, torch::Tensor& input_tokens,
torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions,
torch::Tensor& seq_lens,
torch::Tensor& slot_mapping,
torch::Tensor& block_tables) {
prepare_inputs::advance_step_flashattn(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables);
}
void advance_step_flashinfer(
int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
sampled_token_ids, input_positions, seq_lens,
slot_mapping, block_tables);
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
prepare_inputs::advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}

View File

@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// prepare_inputs advance_step
ops.def(
"advance_step(int num_seqs, int num_queries, int block_size, "
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
"Tensor! input_tokens, Tensor sampled_token_ids, "
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
"Tensor block_tables) -> ()");
ops.impl("advance_step", torch::kCUDA, &advance_step);
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
ops.def(
"advance_step_flashinfer("
" int num_seqs, int num_queries, int block_size,"
" Tensor! input_tokens, Tensor sampled_token_ids,"
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
" Tensor block_tables, Tensor! paged_kv_indices,"
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
" Tensor! block_table_bounds"
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.

View File

@ -1,9 +1,10 @@
# Test the AsyncLLMEngine with multi-step-decoding
from typing import List, Optional
import pytest
from tests.kernels.utils import override_backend_env_variable
from ..models.utils import check_logprobs_close
from ..utils import (completions_with_server_args, get_client_text_generations,
get_client_text_logprob_generations)
@ -33,8 +34,9 @@ DEFAULT_SERVER_ARGS: List[str] = [
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("is_async", [True])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
@pytest.mark.asyncio
async def test_multi_step(
example_prompts,
@ -46,6 +48,8 @@ async def test_multi_step(
num_prompts: int,
is_async: bool,
num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment.
@ -71,6 +75,8 @@ async def test_multi_step(
completions endpoint; `None` -> no logprobs
"""
override_backend_env_variable(monkeypatch, attention_backend)
prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)

View File

@ -161,16 +161,36 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
def advance_step(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor,
slot_mapping: torch.Tensor,
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor) -> None:
"""Advance a step on GPU for existing inputs for a multi-step runner"""
return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping,
block_tables)
return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
block_size, input_tokens,
sampled_token_ids,
input_positions, seq_lens,
slot_mapping, block_tables)
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
block_table_bound: torch.Tensor) -> None:
return torch.ops._C.advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables,
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
block_table_bound)
# quantization ops

View File

@ -83,7 +83,9 @@ class AttentionBackend(ABC):
) -> None:
raise NotImplementedError
def advance_step(self, num_seqs: int, num_queries: int):
def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError

View File

@ -380,7 +380,7 @@ class FlashAttentionMetadata(AttentionMetadata):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step(num_seqs=num_seqs,
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,

View File

@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
class FlashInferBackend(AttentionBackend):
@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None
# used for GPU in-place advance_step
seq_lens_tensor: Optional[torch.Tensor] = None
block_table_bound: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata):
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# We will use flash attention for profiling to
@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
@ -335,7 +344,6 @@ class FlashInferMetadata(AttentionMetadata):
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
@ -343,6 +351,11 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
# handle model warmup path
if self.block_table_bound is not None:
self.block_table_bound = self.block_table_bound.to(self.device)
if self.seq_lens_tensor is not None:
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata):
return self
def advance_step(
self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
):
"""
Update metadata in-place to advance one decode step.
"""
assert num_seqs > 0
assert num_queries > 0
assert model_input.attn_metadata is not None
assert sampled_token_ids is not None
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
# Update GPU tensors
ops.advance_step_flashinfer(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=model_input.input_tokens,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables,
paged_kv_indices=self.paged_kv_indices,
paged_kv_indptr=self.paged_kv_indptr,
paged_kv_last_page_len=self.paged_kv_last_page_len,
block_table_bound=self.block_table_bound)
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []
self.total_blocks = 0
self.is_profile_run: bool = False
def _add_seq_group(
@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
@ -583,6 +639,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
out=query_start_loc[1:])
if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
dtype=torch.int)
@ -591,10 +651,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device="cpu",
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
block_table_bound_tensor = None
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@ -613,6 +678,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
block_table_bound=block_table_bound_tensor,
seq_lens_tensor=seq_lens_tensor,
num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads(

View File

@ -4,13 +4,6 @@ from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
import torch
from vllm.distributed import get_pp_group
@ -36,6 +29,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
def seq_output_builder():
return SequenceOutput(
@ -489,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
raise ValueError(
f"Multi-step not supported for attention backend: "
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.")
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
num_seqs = model_input.num_seqs
num_queries = model_input.num_queries
assert num_seqs > 0
assert num_queries > 0
assert num_seqs >= num_queries
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
assert attn_metadata is not None
attn_metadata.advance_step(
frozen_model_input,
model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
num_seqs, num_queries)
if frozen_model_input.seq_lens is not None:
for i in range(num_queries):
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
sampled_token_ids,
self.block_size,
num_seqs,
num_queries,
)
return model_input