mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:37:12 +08:00
Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)
This commit is contained in:
parent
f6518b2b48
commit
60f7624334
@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/attention/vertical_slash_index.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
|
||||
401
csrc/attention/vertical_slash_index.cu
Normal file
401
csrc/attention/vertical_slash_index.cu
Normal file
@ -0,0 +1,401 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
// Append block start offsets covering [range_start, range_end) (stepped by
// block_size, clipped to kv_seqlen) to `block_offset`, starting at slot
// `input_block_count`. Returns the updated block count.
// NOTE(review): offsets are stored as 32-bit ints, so kv positions are
// assumed to fit in int — TODO confirm for >2^31 contexts.
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
                               int64_t range_end, int64_t block_size,
                               int64_t input_block_count, int64_t kv_seqlen) {
  if (range_start >= kv_seqlen) {
    return input_block_count;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  int64_t current_block_count = input_block_count;
  // Use int64_t for the loop variable: the bounds are int64_t and a 32-bit
  // `int` would silently narrow (UB on overflow) for very long sequences.
  for (int64_t idx = range_start; idx < range_end; idx += block_size) {
    block_offset[current_block_count++] = idx;
  }
  return current_block_count;
}
|
||||
|
||||
// Converts per-(batch, head) vertical and slash sparse-attention indices into
// a per-row-block layout: merged KV block ranges (from slash diagonals) plus
// separate column indices (from verticals). One thread handles one row block
// of BLOCK_SIZE_M query rows. Implements Algorithm 4 of
// https://arxiv.org/abs/2407.02490 (see the host-side doc comment below).
// Grid layout expected: x -> head, y -> batch, z -> group of row blocks.
__global__ void convert_vertical_slash_indexes_kernel(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  // Row blocks past the query length have no work; note their count outputs
  // are left untouched (callers must pre-initialize them).
  if (start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  // Advance input/output pointers to this (batch, head) slice and row block.
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];
  // Skip slash diagonals that fall entirely outside this row block, then map
  // the first relevant slash index to a KV-column end position for this
  // block. The (kv_seqlen - q_seqlen) term shifts the diagonal for the
  // causal/intra case; clamping to BLOCK_SIZE_M keeps the range in bounds.
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  // Current candidate KV range covered by the active slash diagonal.
  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    // No slash touches this block: seed a sentinel range past the valid KV
    // region so the sweep below only collects verticals.
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  // Merge-sweep over vertical points and slash ranges in ascending KV order:
  // a vertical inside the current range is absorbed by it, one before the
  // range is emitted as a standalone column index, and overlapping/adjacent
  // slash ranges are merged before being flushed via save_blocks().
  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      if (v_idx < range_start) {
        // Vertical not covered by the current slash range.
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        // Out of verticals: sentinel value beyond any reachable range.
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        // Advance to the next slash diagonal, mapped to a KV end position.
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if no more slash
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          // Flush the final pending range and finish this row block.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // if slash_finished but there are vertical left, save current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          // New slash is disjoint from the current range: flush and reopen.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          // Overlapping/adjacent slash: extend the current range.
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  // Publish per-row counts (offsets/indices were written above).
  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
|
||||
|
||||
// Launches convert_vertical_slash_indexes_kernel with 64 threads per CTA.
// Each thread processes one row block of BLOCK_SIZE_M query rows.
void convert_vertical_slash_indexes_64x64(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  constexpr int kThreadsPerBlock = 64;
  // Ceil-divide the row blocks across the z dimension of the grid.
  const int64_t row_groups =
      (N_ROWS + kThreadsPerBlock - 1) / kThreadsPerBlock;
  const dim3 block_dim(kThreadsPerBlock);
  // Grid layout: x -> heads, y -> batch, z -> groups of row blocks.
  const dim3 grid_dim(N_HEADS, BATCH_SIZE, row_groups);
  convert_vertical_slash_indexes_kernel<<<grid_dim, block_dim>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
      block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
      BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}
|
||||
|
||||
/**
 * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
 *
 * This function builds the index of each row of blocks from vertical indices
 * and slash indices. The vertical indices are treated as points, while the
 * slash indices are converted as ranges. The output consists of the merged
 * ranges and separate column indices, where the ranges are represented by
 * block indices.
 *
 * The implementation is referenced from the original MInference repo:
 * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
 */
void convert_vertical_slash_indexes(
    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,         // [BATCH, ]
    torch::Tensor kv_seqlens,        // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal) {
  // The kernel reads/writes these tensors through raw int pointers, so
  // validate device and dtype up front instead of corrupting memory later.
  TORCH_CHECK(q_seqlens.is_cuda() && kv_seqlens.is_cuda() &&
                  vertical_indexes.is_cuda() && slash_indexes.is_cuda(),
              "convert_vertical_slash_indexes: inputs must be CUDA tensors");
  TORCH_CHECK(vertical_indexes.scalar_type() == torch::kInt32 &&
                  slash_indexes.scalar_type() == torch::kInt32,
              "convert_vertical_slash_indexes: index tensors must be int32");
  cudaSetDevice(q_seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  // Number of BLOCK_SIZE_M row blocks covering the context (ceil-div).
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  convert_vertical_slash_indexes_64x64(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
      column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
      num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
      causal);
}
|
||||
|
||||
// Same merge-sweep as convert_vertical_slash_indexes_kernel, but each head
// uses its own effective number of vertical/slash indices taken from
// per_head_vertical_topkv / per_head_slash_topkv; the NNZ_V/NNZ_S parameters
// are only the buffer sizes used to compute memory offsets.
// Grid layout expected: x -> head, y -> batch, z -> group of row blocks.
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  // Row blocks past the query length have no work; their count outputs are
  // left untouched (callers must pre-initialize them).
  if (start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  // Advance pointers to this (batch, head) slice and row block. Offsets use
  // the buffer-size NNZ_V/NNZ_S, not the per-head effective counts.
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  // MergeHead: each head has its own effective top-k NNZ_V/NNZ_S (the NNZ_V,
  // NNZ_S parameters above are buffer sizes, used only to compute offsets).
  NNZ_S = per_head_slash_topkv[head_idx];
  NNZ_V = per_head_vertical_topkv[head_idx];

  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];
  // Skip slash diagonals outside this row block, then map the first relevant
  // slash index to a KV end position ((kv_seqlen - q_seqlen) aligns the
  // diagonal in the causal/intra case).
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  // Current candidate KV range covered by the active slash diagonal.
  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    // No slash touches this block: seed a sentinel range past the valid KV
    // region so only verticals are collected below.
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  // Merge-sweep over vertical points and slash ranges in ascending KV order
  // (see convert_vertical_slash_indexes_kernel for details).
  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      if (v_idx < range_start) {
        // Vertical not covered by the current slash range.
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        // Out of verticals: sentinel value beyond any reachable range.
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        // Advance to the next slash diagonal, mapped to a KV end position.
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if no more slash
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          // Flush the final pending range and finish this row block.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // if slash_finished but there are vertical left, save current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          // New slash is disjoint from the current range: flush and reopen.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          // Overlapping/adjacent slash: extend the current range.
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  // Publish per-row counts (offsets/indices were written above).
  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
|
||||
|
||||
// Launches convert_vertical_slash_indexes_kernel_mergehead with 64 threads
// per CTA. Each thread processes one row block of BLOCK_SIZE_M query rows.
void convert_vertical_slash_indexes_64x64_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* per_head_vertical_topkv, int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  constexpr int kThreadsPerBlock = 64;
  // Ceil-divide the row blocks across the z dimension of the grid.
  const int64_t row_groups =
      (N_ROWS + kThreadsPerBlock - 1) / kThreadsPerBlock;
  const dim3 block_dim(kThreadsPerBlock);
  // Grid layout: x -> heads, y -> batch, z -> groups of row blocks.
  const dim3 grid_dim(N_HEADS, BATCH_SIZE, row_groups);
  convert_vertical_slash_indexes_kernel_mergehead<<<grid_dim, block_dim>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
      per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
      NNZ_V, NNZ_S, causal);
}
|
||||
|
||||
/**
 * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
 *
 * Like the above convert_vertical_slash_indexes, but with
 * pre-computed vertical and slash counts (one effective top-k count per
 * head; the index buffers' NNZ dims are only capacities).
 */
void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,            // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,           // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,               // [BATCH, ]
    torch::Tensor kv_seqlens,              // [BATCH, ]
    torch::Tensor vertical_indexes,        // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,           // [BATCH, N_HEADS, NNZ_S]
    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
    torch::Tensor slash_indices_count,     // [N_HEADS, ]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal) {
  // The kernel reads/writes these tensors through raw int pointers, so
  // validate device and dtype up front instead of corrupting memory later.
  TORCH_CHECK(q_seqlens.is_cuda() && kv_seqlens.is_cuda() &&
                  vertical_indexes.is_cuda() && slash_indexes.is_cuda(),
              "convert_vertical_slash_indexes_mergehead: inputs must be CUDA "
              "tensors");
  TORCH_CHECK(vertical_indices_count.scalar_type() == torch::kInt32 &&
                  slash_indices_count.scalar_type() == torch::kInt32,
              "convert_vertical_slash_indexes_mergehead: per-head count "
              "tensors must be int32");
  cudaSetDevice(q_seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  // Number of BLOCK_SIZE_M row blocks covering the context (ceil-div).
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  convert_vertical_slash_indexes_64x64_mergehead(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      vertical_indices_count.data_ptr<int>(),
      slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
      block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}
|
||||
25
csrc/ops.h
25
csrc/ops.h
@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
|
||||
bool causal);
|
||||
|
||||
void convert_vertical_slash_indexes_mergehead(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
||||
torch::Tensor slash_indices_count, int64_t context_size,
|
||||
int64_t block_size_M, int64_t block_size_N, bool causal);
|
||||
#endif
|
||||
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
|
||||
@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
|
||||
ops.def(
|
||||
"convert_vertical_slash_indexes("
|
||||
" Tensor! block_count, Tensor! block_offset, "
|
||||
" Tensor! column_count, Tensor! column_index, "
|
||||
" Tensor q_seqlens, Tensor q_seqlens, "
|
||||
" Tensor vertical_indexes, Tensor slash_indexes, "
|
||||
" int context_size, int block_size_M, int block_size_N, "
|
||||
" bool causal) -> ()");
|
||||
ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
|
||||
&convert_vertical_slash_indexes);
|
||||
|
||||
ops.def(
|
||||
"convert_vertical_slash_indexes_mergehead("
|
||||
" Tensor! block_count, Tensor! block_offset, "
|
||||
" Tensor! column_count, Tensor! column_index, "
|
||||
" Tensor q_seqlens, Tensor q_seqlens, "
|
||||
" Tensor vertical_indexes, Tensor slash_indexes, "
|
||||
" Tensor vertical_indices_count, Tensor slash_indices_count, "
|
||||
" int context_size, int block_size_M, int block_size_N, "
|
||||
" bool causal) -> ()");
|
||||
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
|
||||
&convert_vertical_slash_indexes_mergehead);
|
||||
#endif
|
||||
|
||||
// Activation ops
|
||||
|
||||
66
examples/offline_inference/qwen_1m.py
Normal file
66
examples/offline_inference/qwen_1m.py
Normal file
@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
|
||||
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
|
||||
|
||||
|
||||
def load_prompt() -> str:
    """Download and return the 600k-token sample prompt for the demo."""
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
    url = ("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
           "/Qwen2.5-1M/test-data/600k.txt")
    with urlopen(url, timeout=5) as response:
        return response.read().decode('utf-8')
|
||||
|
||||
|
||||
# Processing the prompt.
def process_requests(llm: LLM, prompts: list[str]) -> None:
    """Generate a completion for each prompt and print the results."""
    # Sampling configuration recommended for Qwen2.5 models.
    params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        repetition_penalty=1.05,
        detokenize=True,
        max_tokens=256,
    )
    # Generate texts from the prompts and report each result.
    for output in llm.generate(prompts, params):
        n_prompt_tokens = len(output.prompt_token_ids)
        text = output.outputs[0].text
        print(f"Prompt length: {n_prompt_tokens}, "
              f"Generated text: {text!r}")
|
||||
|
||||
|
||||
# Create an LLM.
def initialize_engine() -> LLM:
    """Build an LLM configured for 1M-token context with chunked prefill."""
    return LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",
               max_model_len=1048576,
               tensor_parallel_size=4,
               enforce_eager=True,
               enable_chunked_prefill=True,
               max_num_batched_tokens=131072)
|
||||
|
||||
|
||||
def main():
    """Entry point: build the engine, fetch the prompt, run one generation."""
    engine = initialize_engine()
    process_requests(engine, [load_prompt()])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor,
|
||||
prefix_lse, suffix_output, suffix_lse)
|
||||
|
||||
|
||||
def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert vertical/slash sparse-attention indices into per-row-block
    ranges and column indices via the _C CUDA op.

    Returns (block_count, block_offset, column_count, column_index),
    allocated here (zero-initialized) and filled in place by the op.
    """
    batch_size, num_heads, nnz_slash = slash_indexes.shape
    nnz_vertical = vertical_indexes.size(-1)
    # Number of block_size_M row blocks covering the context (ceil-div).
    num_rows = (context_size + block_size_M - 1) // block_size_M

    # All outputs share the dtype/device of the seqlen tensor. Zeros (not
    # empty) so rows the kernel skips hold deterministic counts.
    opts = dict(dtype=q_seqlens.dtype, device=q_seqlens.device)
    block_count = torch.zeros(batch_size, num_heads, num_rows, **opts)
    block_offset = torch.zeros(batch_size, num_heads, num_rows, nnz_slash,
                               **opts)
    column_count = torch.zeros(batch_size, num_heads, num_rows, **opts)
    column_index = torch.zeros(batch_size, num_heads, num_rows, nnz_vertical,
                               **opts)

    torch.ops._C.convert_vertical_slash_indexes(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, context_size,
        block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    # [N_HEADS] : different head use different number of indices
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Like convert_vertical_slash_indexes, but each head uses its own
    effective number of vertical/slash indices (vertical_indices_count /
    slash_indices_count); the NNZ dims of the index buffers are only
    capacities.

    Returns (block_count, block_offset, column_count, column_index),
    allocated here and filled in place by the _C CUDA op.
    """
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    # Number of block_size_M row blocks covering the context (ceil-div).
    num_rows = (context_size + block_size_M - 1) // block_size_M

    # Use zeros (not empty): the kernel skips row blocks past the query
    # length, so with empty those rows would hold nondeterministic garbage
    # counts. This also matches convert_vertical_slash_indexes above.
    block_count = torch.zeros(batch_size,
                              num_heads,
                              num_rows,
                              dtype=q_seqlens.dtype,
                              device=q_seqlens.device)
    block_offset = torch.zeros(batch_size,
                               num_heads,
                               num_rows,
                               nnz_slash,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)
    column_count = torch.zeros(batch_size,
                               num_heads,
                               num_rows,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)
    column_index = torch.zeros(batch_size,
                               num_heads,
                               num_rows,
                               nnz_vertical,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)

    torch.ops._C.convert_vertical_slash_indexes_mergehead(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
        slash_indices_count, context_size, block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
# pos encoding ops
|
||||
def rotary_embedding(
|
||||
positions: torch.Tensor,
|
||||
|
||||
1494
vllm/attention/backends/dual_chunk_flash_attn.py
Normal file
1494
vllm/attention/backends/dual_chunk_flash_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -929,6 +929,23 @@ class ModelConfig:
|
||||
"Number of experts in the model must be greater than 0 "
|
||||
"when expert parallelism is enabled.")
|
||||
|
||||
    def verify_dual_chunk_attention_config(
        self,
        load_config: "LoadConfig",
    ) -> None:
        """If the HF config declares dual-chunk attention, try to load a
        sparse-attention config from the model files and attach it (plus a
        default ``sparse_attention_enabled=True``) to
        ``hf_config.dual_chunk_attention_config``. No-op otherwise.
        """
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            from vllm.model_executor.model_loader.weight_utils import (
                get_sparse_attention_config)
            sparse_attn_config = get_sparse_attention_config(self, load_config)
            if sparse_attn_config:
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_config"] = sparse_attn_config
                # Only set the default when the user has not already
                # provided an explicit enable/disable flag.
                if "sparse_attention_enabled" not in \
                        self.hf_config.dual_chunk_attention_config:
                    self.hf_config.dual_chunk_attention_config[
                        "sparse_attention_enabled"] = True
|
||||
|
||||
def verify_async_output_proc(self, parallel_config, speculative_config,
|
||||
device_config) -> None:
|
||||
if not self.use_async_output_proc:
|
||||
@ -4187,6 +4204,8 @@ class VllmConfig:
|
||||
self.speculative_config,
|
||||
self.device_config)
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.model_config.verify_dual_chunk_attention_config(
|
||||
self.load_config)
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
@ -37,8 +37,8 @@ from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
|
||||
is_in_ray_actor)
|
||||
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
||||
GiB_bytes, is_in_doc_build, is_in_ray_actor)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@ -983,6 +983,17 @@ class EngineArgs:
|
||||
|
||||
assert self.enable_chunked_prefill is not None
|
||||
|
||||
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
|
||||
assert self.enforce_eager, (
|
||||
"Cuda graph is not supported with DualChunkFlashAttention. "
|
||||
"To run the model in eager mode, set 'enforce_eager=True' "
|
||||
"or use '--enforce-eager' in the CLI.")
|
||||
assert current_platform.is_cuda(), (
|
||||
"DualChunkFlashAttention is only supported on CUDA platform.")
|
||||
assert not use_v1, (
|
||||
"DualChunkFlashAttention is not supported on V1 engine. "
|
||||
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
|
||||
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
|
||||
@ -1486,6 +1486,184 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
return updates
|
||||
|
||||
|
||||
@CustomOp.register("dual_chunk_rotary_embedding")
|
||||
class DualChunkRotaryEmbedding(CustomOp):
|
||||
"""Rotary positional embedding for Dual Chunk Attention."""
|
||||
|
||||
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        """Precompute and register the five cos/sin caches used by dual-chunk
        attention (query, chunked query, key, unclamped chunked query, and
        inter-chunk query variants).

        Args:
            head_size: attention head dimension.
            rotary_dim: number of dimensions rotary embedding is applied to.
            max_position_embeddings: maximum position for the key cache.
            base: rotary frequency base.
            is_neox_style: rotation layout flag (stored; not used here).
            dtype: dtype the caches are cast to.
            chunk_size: dual-chunk attention chunk length.
            local_size: local-attention window inside a chunk.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        # Caches are built directly on the current CUDA device.
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache,
         q_inter_cache) = self._compute_cos_sin_cache()

        # Non-persistent buffers: moved with the module but excluded from
        # state_dict (they are recomputed from the constructor args).
        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer("cos_sin_qc_no_clamp_cache",
                             qc_no_clamp_cache,
                             persistent=False)
        self.register_buffer("cos_sin_q_inter_cache",
                             q_inter_cache,
                             persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
|
||||
# However, we use `torch.arange(..., dtype=torch.float)` instead to
|
||||
# avoid numerical issues with large base values (e.g., 10000000).
|
||||
# This may cause a slight numerical difference between the HF
|
||||
# implementation and ours.
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
# use CPU to compute the cache and then move it to GPU. However, we
|
||||
# create the cache on GPU for faster initialization. This may cause
|
||||
# a slight numerical difference between the HF implementation and ours.
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
chunk_len = self.chunk_size - self.local_size
|
||||
q_t = torch.arange(chunk_len, dtype=torch.float)
|
||||
qc_t = (torch.arange(chunk_len, dtype=torch.float) +
|
||||
chunk_len).clamp(max=self.chunk_size)
|
||||
k_t = torch.arange(self.max_position_embeddings,
|
||||
dtype=torch.float) % chunk_len
|
||||
|
||||
# count from chunk_len, no clamp(self.chunk_size) restriction
|
||||
qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
|
||||
# count from self.chunk_size for q_inter's rope
|
||||
q_inter_t = torch.arange(chunk_len,
|
||||
dtype=torch.float) + self.chunk_size
|
||||
|
||||
q_freqs = torch.outer(q_t, inv_freq)
|
||||
qc_freqs = torch.outer(qc_t, inv_freq)
|
||||
k_freqs = torch.outer(k_t, inv_freq)
|
||||
qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
|
||||
q_inter_freqs = torch.outer(q_inter_t, inv_freq)
|
||||
|
||||
q_cos = q_freqs.cos()
|
||||
q_sin = q_freqs.sin()
|
||||
qc_cos = qc_freqs.cos()
|
||||
qc_sin = qc_freqs.sin()
|
||||
k_cos = k_freqs.cos()
|
||||
k_sin = k_freqs.sin()
|
||||
|
||||
qc_no_clamp_cos = qc_no_clamp_freqs.cos()
|
||||
qc_no_clamp_sin = qc_no_clamp_freqs.sin()
|
||||
q_inter_cos = q_inter_freqs.cos()
|
||||
q_inter_sin = q_inter_freqs.sin()
|
||||
|
||||
q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype,
|
||||
device=self.device)
|
||||
qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype,
|
||||
device=self.device)
|
||||
k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype,
|
||||
device=self.device)
|
||||
qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin),
|
||||
dim=-1).to(dtype=self.dtype,
|
||||
device=self.device)
|
||||
q_inter_cache = torch.cat((q_inter_cos, q_inter_sin),
|
||||
dim=-1).to(dtype=self.dtype,
|
||||
device=self.device)
|
||||
return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
|
||||
|
||||
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply dual-chunk rotary embeddings to the query and key.

    The key is rotated with the wrapped key cache. The query is rotated
    five ways (intra-chunk, successor, inter-chunk, and the two
    "critical" variants) and the results are concatenated along the
    last dimension, so the returned query's last dim is 5x the input's.

    Args:
        positions: Token positions; indexes the cos/sin caches.
        query: Query tensor; last dim is split into heads of
            ``self.head_size``.
        key: Key tensor; same layout as ``query``.
        offsets: Optional per-token position offsets added to
            ``positions`` before cache lookup.

    Returns:
        ``(query, key)`` with rotary embeddings applied; ``query`` is
        the concatenation of the five rotated variants.
    """
    query = query.view(*query.shape[:-1], -1, self.head_size)
    key = key.view(*key.shape[:-1], -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]
    else:
        query_pass = None
        key_pass = None

    positions_with_offsets = (torch.add(positions, offsets)
                              if offsets is not None else positions)
    key = self._apply_rotary_embedding(
        self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
    chunk_len = self.chunk_size - self.local_size
    # Intra-chunk query: positions wrapped into the current chunk.
    query = self._apply_rotary_embedding(
        self.cos_sin_q_cache[positions_with_offsets % chunk_len],
        query_rot, query_pass)
    # Successor-chunk query (positions clamped at chunk_size in cache).
    query_succ = self._apply_rotary_embedding(
        self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
        query_rot, query_pass)
    # Inter-chunk query: every token uses the last clamped cache entry.
    query_inter = self._apply_rotary_embedding(
        self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
        query_rot, query_pass)
    query_succ_critical = self._apply_rotary_embedding(
        self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
        query_rot, query_pass)
    query_inter_critical = self._apply_rotary_embedding(
        self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
        query_rot, query_pass)

    # merge query into one tensor to simplify the interfaces
    query = torch.cat((
        query,
        query_succ,
        query_inter,
        query_succ_critical,
        query_inter_critical,
    ),
                      dim=-1)
    return query, key
|
||||
|
||||
def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
    """Rotate ``hidden_rot`` with the given cos/sin cache entries.

    Args:
        cos_sin: Cache rows with cos in the first half of the last dim
            and sin in the second half (as built by
            ``_compute_cos_sin_cache``).
        hidden_rot: The rotary slice (first ``rotary_dim`` elements per
            head) of the query or key.
        hidden_pass: The non-rotary remainder per head, or ``None`` when
            ``rotary_dim == head_size``.

    Returns:
        The rotated (and re-flattened) hidden states.
    """
    cos, sin = cos_sin.chunk(2, dim=-1)
    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:
        # GPT-J style interleaves the rotary pairs element-wise.
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
    hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

    if self.rotary_dim < self.head_size:
        # Re-attach the pass-through (non-rotated) portion of each head.
        hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
    else:
        hidden = hidden_rot
    # Flatten heads back into the hidden dim; squeeze(0) drops a
    # leading singleton batch dim if present.
    return hidden.flatten(-2).squeeze(0)
|
||||
|
||||
def extra_repr(self) -> str:
    """Return the extra info shown in this module's ``repr``."""
    s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
    s += f", max_position_embeddings={self.max_position_embeddings}"
    s += f", base={self.base}, is_neox_style={self.is_neox_style}"
    s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
    return s
|
||||
|
||||
|
||||
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||
|
||||
|
||||
@ -1498,6 +1676,7 @@ def get_rope(
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
@ -1510,14 +1689,35 @@ def get_rope(
|
||||
rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
|
||||
if dual_chunk_attention_config is not None:
|
||||
dual_chunk_attention_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v
|
||||
for k, v in dual_chunk_attention_config.items()
|
||||
if k != "sparse_attention_config"
|
||||
}
|
||||
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
|
||||
else:
|
||||
dual_chunk_attention_args = None
|
||||
|
||||
if partial_rotary_factor < 1.0:
|
||||
rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
rope_scaling_args, dtype)
|
||||
rope_scaling_args, dual_chunk_attention_args, dtype)
|
||||
if key in _ROPE_DICT:
|
||||
return _ROPE_DICT[key]
|
||||
|
||||
if not rope_scaling:
|
||||
if dual_chunk_attention_config is not None:
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in dual_chunk_attention_config.items()
|
||||
if k in ("chunk_size", "local_size")
|
||||
}
|
||||
rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style, dtype,
|
||||
**extra_kwargs)
|
||||
elif not rope_scaling:
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, dtype)
|
||||
else:
|
||||
|
||||
@ -217,6 +217,39 @@ def get_quant_config(model_config: ModelConfig,
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def get_sparse_attention_config(
    model_config: ModelConfig,
    load_config: LoadConfig,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Load the sparse attention config JSON for a model, if present.

    For remote models, the JSON config files are first downloaded (or
    reused from the cache) via ``snapshot_download``; local model
    directories are read in place.

    Args:
        model_config: Provides the model path/name and revision.
        load_config: Provides the download directory used for caching
            and download locking.
        sparse_attention_config_filename: Name of the config file to
            look for inside the model folder.

    Returns:
        The parsed config dict, or ``{}`` when no config file exists.
    """
    model_name_or_path = model_config.model
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )
    else:
        hf_folder = model_name_or_path

    config_file = os.path.join(hf_folder, sparse_attention_config_filename)
    if not os.path.exists(config_file):
        # Missing file is not an error: sparse attention is optional.
        return {}

    # Load the sparse attention config.
    with open(config_file) as f:
        config = json.load(f)
    logger.info("Loaded sparse attention config from %s", config_file)

    return config
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
from typing import Any, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
is_pp_missing_parameter,
|
||||
extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -99,17 +99,20 @@ class Qwen2MLP(nn.Module):
|
||||
|
||||
class Qwen2Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[Tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[Tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str,
|
||||
Any]] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -131,6 +134,7 @@ class Qwen2Attention(nn.Module):
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@ -155,15 +159,21 @@ class Qwen2Attention(nn.Module):
|
||||
max_position=max_position,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=attn_type)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
**{
|
||||
"layer_idx": extract_layer_index(prefix),
|
||||
"dual_chunk_attention_config": dual_chunk_attention_config,
|
||||
} if dual_chunk_attention_config else {})
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -192,6 +202,9 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
|
||||
# By default, Qwen2 uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
@ -213,6 +226,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
@ -175,6 +175,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -198,6 +199,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@ -221,14 +223,20 @@ class Qwen2MoeAttention(nn.Module):
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
**{
|
||||
"layer_idx": extract_layer_index(prefix),
|
||||
"dual_chunk_attention_config": dual_chunk_attention_config,
|
||||
} if dual_chunk_attention_config else {})
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -256,6 +264,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = Qwen2MoeAttention(
|
||||
@ -268,6 +279,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
|
||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||
|
||||
@ -222,6 +222,10 @@ class CudaPlatformBase(Platform):
|
||||
elif selected_backend == _Backend.XFORMERS:
|
||||
logger.info("Using XFormers backend.")
|
||||
return "vllm.attention.backends.xformers.XFormersBackend"
|
||||
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
|
||||
logger.info("Using DualChunkFlashAttention backend.")
|
||||
return ("vllm.attention.backends.dual_chunk_flash_attn."
|
||||
"DualChunkFlashAttentionBackend")
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
pass
|
||||
elif selected_backend:
|
||||
|
||||
@ -51,6 +51,7 @@ class _Backend(enum.Enum):
|
||||
PALLAS_VLLM_V1 = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
|
||||
DUAL_CHUNK_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
|
||||
|
||||
|
||||
@ -153,6 +153,7 @@ STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
|
||||
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
|
||||
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
|
||||
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
GB_bytes = 1_000_000_000
|
||||
|
||||
@ -204,6 +204,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.mrope_input_positions = None # type: ignore
|
||||
self.seq_lens[0] = 0 # type: ignore
|
||||
self.orig_seq_lens[0] = 0 # type: ignore
|
||||
self.prompt_lens[0] = 0 # type: ignore
|
||||
self.query_lens[0] = 0 # type: ignore
|
||||
self.context_lens[0] = 0 # type: ignore
|
||||
self.curr_sliding_window_blocks[0] = 0 # type: ignore
|
||||
@ -236,6 +237,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
# The original sequence length (before applying sliding window).
|
||||
# This is used to compute slot mapping.
|
||||
orig_seq_lens: Optional[List[int]] = None,
|
||||
# This is used in the dual-chunk flash attention backend.
|
||||
prompt_lens: Optional[List[int]] = None,
|
||||
# The query length.
|
||||
query_lens: Optional[List[int]] = None,
|
||||
# The number of tokens that are already computed.
|
||||
@ -316,6 +319,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.orig_seq_lens[seq_id] = 0
|
||||
|
||||
if prompt_lens:
|
||||
self.prompt_lens = prompt_lens
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.prompt_lens[seq_id] = 0
|
||||
|
||||
if query_lens:
|
||||
self.query_lens = query_lens
|
||||
else:
|
||||
@ -370,6 +379,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.mrope_input_positions = mrope_input_positions or None
|
||||
self.seq_lens = seq_lens or []
|
||||
self.orig_seq_lens = orig_seq_lens or []
|
||||
self.prompt_lens = prompt_lens or []
|
||||
self.query_lens = query_lens or []
|
||||
self.context_lens = context_lens or []
|
||||
self.curr_sliding_window_blocks = \
|
||||
@ -403,6 +413,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.mrope_input_positions = None
|
||||
self.seq_lens = [0] * self.n_seqs
|
||||
self.orig_seq_lens = [0] * self.n_seqs
|
||||
self.prompt_lens = [0] * self.n_seqs
|
||||
self.query_lens = [0] * self.n_seqs
|
||||
self.context_lens = [0] * self.n_seqs
|
||||
self.curr_sliding_window_blocks = [0] * self.n_seqs
|
||||
@ -552,6 +563,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
|
||||
inter_data.seq_lens[seq_idx] = seq_len
|
||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
|
||||
inter_data.context_lens[seq_idx] = context_len
|
||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||
inter_data.inputs_embeds = prompt_embeds
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user