Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-27 10:25:15 +08:00)
[V0 Deprecation] Remove advance_step (#22969)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
  parent 74f441f4b5
  commit 1c859a1387
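This change removes the V0 multi-step `advance_step` path end to end: the CUDA kernels under csrc/prepare_inputs, their torch bindings and Python wrappers, and the `advance_step` methods on the V0 attention-backend metadata classes. As orientation for the diff below, here is a minimal plain-Python sketch (illustrative only, not vLLM code) of the per-query bookkeeping those kernels performed on each decode step:

def advance_one_step(query: dict, sampled_token: int, block_size: int = 16) -> dict:
    """Mirror of the per-query update in the removed advance_step kernels (illustrative)."""
    query["input_token"] = sampled_token        # feed the freshly sampled token back in
    query["seq_len"] += 1                       # the sequence grew by one token
    pos = query["seq_len"] - 1                  # position of the new input token
    query["input_position"] = pos
    block, offset = divmod(pos, block_size)     # locate the token inside the paged KV cache
    query["slot"] = query["block_table"][block] * block_size + offset
    return query

# Example: a request at length 37 whose logical blocks map to physical pages [7, 3, 9].
print(advance_one_step({"seq_len": 37, "block_table": [7, 3, 9]}, sampled_token=42))
# -> slot == 9 * 16 + 5 == 149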
@@ -249,7 +249,6 @@ set(VLLM_EXT_SRC
  "csrc/quantization/gguf/gguf_kernel.cu"
  "csrc/quantization/activation_kernels.cu"
  "csrc/cuda_utils_kernels.cu"
  "csrc/prepare_inputs/advance_step.cu"
  "csrc/custom_all_reduce.cu"
  "csrc/torch_bindings.cpp")
csrc/ops.h
@@ -145,22 +145,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);

void gelu_quick(torch::Tensor& out, torch::Tensor& input);

void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables);

void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
                        torch::Tensor const& q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
@@ -1,336 +0,0 @@
/*
 * The goal of this GPU kernel is to advance input tensors on the GPU directly
 * PR: https://github.com/vllm-project/vllm/pull/6338
 * Current restrictions:
 *     1. Specialized for DraftModelRunner
 *     2. Supports flash_attn only
 */

#include "advance_step.cuh"

namespace prepare_inputs {

//
template <int const num_threads>
__global__ void advance_step_flashattn_kernel(
    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
    long const* sampled_token_ids_ptr, long* input_positions_ptr,
    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
    int64_t const block_tables_stride) {
  int const n_pad = num_seqs - num_queries;
  if (n_pad && blockIdx.x == 0) {
    // Handle cuda graph padding
    int const offset = num_queries;
    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
      input_tokens_ptr[offset + i] = 0;
      input_positions_ptr[offset + i] = 0;
      slot_mapping_ptr[offset + i] = -1;
    }
  }

  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

  if (cur_query_id >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

  int seq_len = seq_lens_ptr[cur_query_id];
  int next_seq_len = seq_len + 1;
  int next_input_pos = next_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[cur_query_id] = next_seq_len;
  // Update input_positions
  input_positions_ptr[cur_query_id] = next_input_pos;

  int const* seq_block_tables_ptr =
      block_tables_ptr + block_tables_stride * cur_query_id;

  int block_index = next_input_pos / block_size;
  int block_offset = next_input_pos % block_size;

  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
  // Update slot_mapping
  slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor const& t,
                          int64_t const size_0, int64_t const size_1,
                          c10::ScalarType const type) {
  bool size_0_cond = true;
  if (size_0 != -1) {
    size_0_cond = t.size(0) == size_0;
  }

  bool size_1_cond = true;
  if (size_1 != -1) {
    size_1_cond = t.size(1) == size_1;
  }

  bool is_contiguous = t.is_contiguous();
  bool same_type = t.dtype() == type;

  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
  if (!pass) {
    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
                " is not as expected: shape = [", size_0, ", ", size_1,
                "], type = ", type);
  }
}

/// each thread processes a block per query
__global__ void advance_step_flashinfer_kernel(
    int num_threads, int num_seqs, int num_queries, int block_size,
    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
    int const* block_tables_ptr, int64_t const block_tables_stride,
    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
  int const n_pad = num_seqs - num_queries;
  if (n_pad && blockIdx.x == 0) {
    // Handle cuda graph padding
    int const offset = num_queries;
    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
      input_tokens_ptr[offset + i] = 0;
      input_positions_ptr[offset + i] = 0;
      slot_mapping_ptr[offset + i] = -1;
    }
  }
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x < num_query_blocks) {
    int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

    if (cur_query_id < num_queries) {
      // Update input_tokens
      input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

      int seq_len = seq_lens_ptr[cur_query_id];
      int next_seq_len = seq_len + 1;
      int next_input_pos = next_seq_len - 1;

      // Update seq_lens
      seq_lens_ptr[cur_query_id] = next_seq_len;
      // Update input_positions
      input_positions_ptr[cur_query_id] = next_input_pos;

      int const* seq_block_tables_ptr =
          block_tables_ptr + block_tables_stride * cur_query_id;

      int block_index = next_input_pos / block_size;
      int block_offset = next_input_pos % block_size;

      // Update paged_kv_last_page_len
      paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;

      int slot_num =
          seq_block_tables_ptr[block_index] * block_size + block_offset;
      // Update slot_mapping
      slot_mapping_ptr[cur_query_id] = slot_num;
      block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
    }
  }
}

__global__ void advance_step_flashinfer_indptr_kernel(
    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
    int* block_table_bound_ptr) {
  int idx = blockIdx.x * num_threads + threadIdx.x;
  // Update paged_kv_indptr
  if (idx == 0) {
    paged_kv_indptr_ptr[idx] = 0;
  }
  if (idx < num_queries) {
    int sum = 0;
    for (int i = 0; i <= idx; ++i) {
      sum += block_table_bound_ptr[i];
    }
    paged_kv_indptr_ptr[idx + 1] = sum;
  }
}

__global__ void advance_step_flashinfer_indices_kernel(
    int num_seqs, int num_queries, int const* block_tables_ptr,
    int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
    int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
  // note: max_num_blocks_per_seq = block_tables.stride(0)
  int tid = blockIdx.x * blockDim.x + threadIdx.x;

  // when cuda graphs are enabled, paged_kv_indptr tensor
  // has to be updated for the padded queries
  // tid represents a query# for paged_kv_indptr tensor
  if (num_queries < tid && tid <= num_seqs) {
    paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
  }

  // each thread processes a block_ptr in block_tables
  // block_tables shape: [num_queries, max_num_blocks_per_seq]
  // paged_kv_indices is flattened block_tables.
  for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
       idx += (gridDim.x * blockDim.x)) {
    // block_tables-row = paged_kv_indptr[queryNum]
    int queryNum = idx / max_num_blocks_per_seq;
    int col = idx % max_num_blocks_per_seq;
    if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
      int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
      int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
      paged_kv_indices_ptr[indices_arr_idx] =
          block_tables_ptr[block_tables_idx];
    }
  }
}

void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
                            torch::Tensor& input_tokens,       // type: long
                            torch::Tensor& sampled_token_ids,  // type: long
                            torch::Tensor& input_positions,    // type: long
                            torch::Tensor& seq_lens,           // type: int
                            torch::Tensor& slot_mapping,       // type: long
                            torch::Tensor& block_tables) {     // type: int

  if (logging) {
    printf("advance_step_flashattn:\n");
    printf(" num_seqs = %d\n", num_seqs);
    printf(" num_queries = %d\n", num_queries);
    printf(" block_size = %d\n", block_size);
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
                at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  advance_step_flashattn_kernel<max_threads>
      <<<blocks, max_threads, 0, stream>>>(
          num_seqs, num_queries, block_size,
          reinterpret_cast<long*>(input_tokens.data_ptr()),
          reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
          reinterpret_cast<long*>(input_positions.data_ptr()),
          reinterpret_cast<int*>(seq_lens.data_ptr()),
          reinterpret_cast<long*>(slot_mapping.data_ptr()),
          reinterpret_cast<int const*>(block_tables.data_ptr()),
          block_tables.stride(0));
}

void advance_step_flashinfer(
    int num_seqs, int num_queries, int block_size,
    torch::Tensor& input_tokens,            // type: long
    torch::Tensor& sampled_token_ids,       // type: long
    torch::Tensor& input_positions,         // type: long
    torch::Tensor& seq_lens,                // type: int
    torch::Tensor& slot_mapping,            // type: long
    torch::Tensor& block_tables,            // type: int
    torch::Tensor& paged_kv_indices,        // type: int
    torch::Tensor& paged_kv_indptr,         // type: int
    torch::Tensor& paged_kv_last_page_len,  // type: int
    torch::Tensor& block_table_bound) {     // type: int

  if (logging) {
    printf("advance_step_flashinfer:\n");
    printf(" num_seqs = %d\n", num_seqs);
    printf(" num_queries = %d\n", num_queries);
    printf(" block_size = %d\n", block_size);
    printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0));
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
  //               at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
  verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
  verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
                at::kInt);

  verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  int threads;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
  cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);

  TORCH_CHECK((blocks * threads > num_queries),
              "multi-step: not enough threads to map to num_queries = ",
              num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
              " blocks = ", blocks, " max_threads = ", threads);
  if (logging) {
    printf("launching kernels with %d blocks and %d threads\n", blocks,
           threads);
  }
  advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
      threads, num_seqs, num_queries, block_size,
      reinterpret_cast<long*>(input_tokens.data_ptr()),
      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
      reinterpret_cast<long*>(input_positions.data_ptr()),
      reinterpret_cast<int*>(seq_lens.data_ptr()),
      reinterpret_cast<long*>(slot_mapping.data_ptr()),
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0),
      reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));

  advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
      threads, num_seqs, num_queries,
      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));

  advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
      num_seqs, num_queries,
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0),
      reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));
}

}  // namespace prepare_inputs

void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables) {
  prepare_inputs::advance_step_flashattn(
      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables);
}

void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
  prepare_inputs::advance_step_flashinfer(
      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}
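The three flashinfer kernels above rebuild FlashInfer's paged-KV metadata after each step: per-query page counts (block_table_bound), their prefix sum (paged_kv_indptr), the flattened page ids (paged_kv_indices), and the fill level of each query's last page. The following plain-Python sketch (not part of the diff; the numbers are made up) reproduces the same computation:

# Rebuild FlashInfer paged-KV metadata for 3 queries, block_size = 16.
# block_tables rows list the physical page ids of each query; -1 marks unused slots.
block_size = 16
next_seq_lens = [38, 17, 50]
block_tables = [[7, 3, 9, -1], [2, 5, -1, -1], [4, 8, 1, 6]]

def div_ceil(a, b):
    return (a + b - 1) // b

# block_table_bound: number of pages actually in use per query
bounds = [div_ceil(s, block_size) for s in next_seq_lens]           # [3, 2, 4]

# paged_kv_indptr: exclusive prefix sum over bounds
indptr = [0]
for b in bounds:
    indptr.append(indptr[-1] + b)                                    # [0, 3, 5, 9]

# paged_kv_indices: block tables flattened, truncated to each row's bound
indices = [bt[c] for bt, b in zip(block_tables, bounds) for c in range(b)]
# -> [7, 3, 9, 2, 5, 4, 8, 1, 6]

# paged_kv_last_page_len: tokens used in the last page of each query
last_page_len = [(s - 1) % block_size + 1 for s in next_seq_lens]    # [6, 1, 2]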
@@ -1,19 +0,0 @@
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace prepare_inputs {

static constexpr int max_threads = 256;
static constexpr bool logging = false;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

}  // namespace prepare_inputs
@@ -142,25 +142,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

  // prepare_inputs advance_step
  ops.def(
      "advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
      "Tensor! input_tokens, Tensor sampled_token_ids, "
      "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
      "Tensor block_tables) -> ()");
  ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);

  ops.def(
      "advance_step_flashinfer("
      " int num_seqs, int num_queries, int block_size,"
      " Tensor! input_tokens, Tensor sampled_token_ids,"
      " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
      " Tensor block_tables, Tensor! paged_kv_indices,"
      " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
      " Tensor! block_table_bounds"
      ") -> ()");
  ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
@@ -319,38 +319,6 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
                                             repetition_penalties)


def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner"""
    return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
                                               block_size, input_tokens,
                                               sampled_token_ids,
                                               input_positions, seq_lens,
                                               slot_mapping, block_tables)


def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:

    return torch.ops._C.advance_step_flashinfer(
        num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
        input_positions, seq_lens, slot_mapping, block_tables,
        paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
        block_table_bound)


# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
        input: torch.Tensor,
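For reference, a minimal sketch (not part of this commit, and assuming a CUDA build and device) of the tensor contract the removed wrapper expected; the shapes and dtypes follow the `verify_tensor` checks in the deleted kernel file above, and the sizes are made up:

import torch

num_seqs, num_queries, block_size, max_blocks = 4, 3, 16, 8
device = "cuda"

input_tokens = torch.zeros(num_seqs, dtype=torch.long, device=device)
sampled_token_ids = torch.zeros(num_queries, 1, dtype=torch.long, device=device)
input_positions = torch.zeros(num_seqs, dtype=torch.long, device=device)
seq_lens = torch.full((num_seqs,), 32, dtype=torch.int, device=device)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device=device)
block_tables = torch.zeros(num_seqs, max_blocks, dtype=torch.int, device=device)

# Only on builds that predate this commit does the custom op still exist.
if hasattr(torch.ops._C, "advance_step_flashattn"):
    torch.ops._C.advance_step_flashattn(num_seqs, num_queries, block_size,
                                        input_tokens, sampled_token_ids,
                                        input_positions, seq_lens,
                                        slot_mapping, block_tables)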
@@ -101,11 +101,6 @@ class AttentionBackend(ABC):
    ) -> None:
        raise NotImplementedError

    def advance_step(self, model_input: "ModelRunnerInputBase",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int) -> None:
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)
@@ -35,8 +35,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                  flash_attn_with_kvcache)

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
    from vllm.worker.model_runner import ModelInputForGPUBuilder

logger = init_logger(__name__)

@@ -326,79 +325,6 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata):
                cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class DifferentialFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):
@@ -32,8 +32,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                  flash_attn_with_kvcache)

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
    from vllm.worker.model_runner import ModelInputForGPUBuilder

logger = init_logger(__name__)

@@ -309,79 +308,6 @@ class FlashAttentionMetadata(AttentionMetadata):
                cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
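The `num_seqs`/`num_queries` split that `advance_step` asserts on above exists because CUDA-graph capture pads the batch to a fixed size. A small illustrative sketch (plain Python, not vLLM code; the treatment of padded entries follows the deleted kernel, which zeroes them and sets slot_mapping to -1):

num_queries, num_seqs = 3, 4              # 3 live requests padded to a captured batch of 4

seq_lens = [12, 40, 7, 0]                 # index 3 is padding
input_tokens = [101, 7, 56, 0]
input_positions = [11, 39, 6, 0]
slot_mapping = [388, 1205, 99, -1]        # -1: padded entry, no KV-cache slot

# Live entries advance exactly as in the kernels ...
for i in range(num_queries):
    seq_lens[i] += 1
    input_positions[i] = seq_lens[i] - 1

# ... while padded entries are reset so the captured graph can run over the full
# batch without touching live cache slots.
for i in range(num_queries, num_seqs):
    input_tokens[i] = 0
    input_positions[i] = 0
    slot_mapping[i] = -1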
@@ -51,8 +51,7 @@ from vllm.utils.flashinfer import use_trtllm_attention
logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
    from vllm.worker.model_runner import ModelInputForGPUBuilder


class FlashInferBackend(AttentionBackend):

@@ -428,7 +427,7 @@ class FlashInferMetadata(AttentionMetadata):
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # used for GPU in-place advance_step
    # used for GPU operations
    seq_lens_tensor: Optional[torch.Tensor] = None
    block_table_bound: Optional[torch.Tensor] = None

@@ -587,66 +586,6 @@ class FlashInferMetadata(AttentionMetadata):
            return None
        return self

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            # Flashinfer doesn't support speculative decoding + chunked-prefill
            # + multi-step scheduling yet.
            assert self.decode_query_len == 1
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens_tensor is not None

        assert num_seqs > 0
        assert num_queries > 0
        assert model_input.attn_metadata is not None
        assert sampled_token_ids is not None

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

        # Update GPU tensors
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=model_input.input_tokens,
            sampled_token_ids=model_input.input_tokens,
            input_positions=model_input.input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            block_table_bound=self.block_table_bound)


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
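One subtlety in `FlashInferMetadata.advance_step` above: the freshly sampled ids are first written into `model_input.input_tokens` on the GPU, and that same tensor is then handed to the op as its `sampled_token_ids` argument (note `sampled_token_ids=model_input.input_tokens` in the call), so the kernel reads consistent values and only needs to zero the padded tail. A tiny sketch of that first, host-side step (illustrative tensors only, not vLLM code):

import torch

num_queries, num_seqs = 3, 4
input_tokens = torch.zeros(num_seqs, dtype=torch.long)    # padded to the graph batch
sampled_token_ids = torch.tensor([[11], [22], [33]])      # shape [num_queries, 1]

# The line from the diff: scatter the sampled ids into the (padded) input_tokens buffer.
input_tokens[:num_queries] = sampled_token_ids.flatten()
print(input_tokens.tolist())                              # [11, 22, 33, 0]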
@@ -3,7 +3,7 @@

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type

import torch

@@ -18,9 +18,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class FlashMLABackend(MLACommonBackend):

@@ -62,16 +59,6 @@ class FlashMLAMetadata(MLACommonMetadata):
                self.decode_num_splits
        return decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        raise NotImplementedError(
            "advance_step is not implemented for FlashMLA")


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
@@ -234,8 +234,7 @@ except ImportError:
    flash_attn_varlen_func = None

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
    from vllm.worker.model_runner import ModelInputForGPUBuilder

is_hip = current_platform.is_rocm()

@@ -631,90 +630,6 @@ class MLACommonMetadata(AttentionMetadata):
            is_profile_run=self.is_profile_run)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        self._ops_advance_step(num_seqs=num_seqs,
                               num_queries=num_queries,
                               block_size=block_size,
                               input_tokens=model_input.input_tokens,
                               sampled_token_ids=sampled_token_ids,
                               input_positions=model_input.input_positions)

    def _ops_advance_step(self, num_seqs: int, num_queries: int,
                          block_size: int, input_tokens: torch.Tensor,
                          sampled_token_ids: torch.Tensor,
                          input_positions: torch.Tensor) -> None:
        # here we use advance_step_flashinfo to update the paged_kv_* tensors
        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
    """
@@ -15,8 +15,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
    from vllm.worker.model_runner import (ModelInputForGPUBuilder)
from vllm.utils import async_tensor_h2d

# Placeholder attention backend for models like Mamba and pooling models that

@@ -201,65 +200,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
        )
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert not turn_prefills_into_decodes, \
            ("Multi-Step + Chunked-Prefill is not supported for attention-free"
             "models. turn_prefills_into_decodes is a "
             "Multi-Step + Chunked-Prefill specific parameter.")

        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # Update sequences, masking off entries greater than num_queries
        device = self.seq_lens_tensor.device
        mask = torch.arange(self.seq_lens_tensor.size(0),
                            device=device) < num_queries
        self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
        if sampled_token_ids is not None:
            model_input.input_tokens.masked_scatter_(
                mask, sampled_token_ids[:num_queries])


class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, Type, Union

import torch

import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,

@@ -107,26 +106,6 @@ class AiterMLAMetadata(MLACommonMetadata):

        return self._cached_decode_metadata

    def _ops_advance_step(self, num_seqs: int, num_queries: int,
                          block_size: int, input_tokens: torch.Tensor,
                          sampled_token_ids: torch.Tensor,
                          input_positions: torch.Tensor) -> None:

        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=input_tokens,
            sampled_token_ids=sampled_token_ids,
            input_positions=input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_lens=self.paged_kv_last_page_lens,
            block_table_bound=self.block_table_bound)


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
@@ -4,7 +4,7 @@

import itertools
from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type

import torch

@@ -23,9 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 256

@@ -261,69 +258,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """

        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with rocm_flash_attn yet."
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
@@ -762,8 +762,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        has Prefills (if any). The rest of the steps are guaranteed to be all
        decodes. In this case, we set up the padding as if all the sequences
        are decodes so we may run all steps except the first step in CUDA graph
        mode. The padding is accounted for in the multi-step `advance_step`
        family of functions.
        mode.

        Args:
            num_seqs (int): Number of sequences scheduled to run.