Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)

This commit is contained in:
Tao He 2025-05-13 10:52:47 +08:00 committed by GitHub
parent f6518b2b48
commit 60f7624334
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 2444 additions and 32 deletions

View File

@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
"csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu" "csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu" "csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"

View File

@ -0,0 +1,401 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>

#include <cuda.h>

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
/**
 * Appends block start offsets covering a key range to `block_offset`.
 *
 * The range [range_start, range_end) is first clipped to `kv_seqlen`; then
 * one entry is written for every `block_size` step inside the clipped range,
 * starting at `range_start`.
 *
 * @param block_offset      Output array; entries are written starting at
 *                          index `input_block_count`. The caller must
 *                          guarantee enough capacity (not checked here).
 * @param range_start       First position of the range (inclusive).
 * @param range_end         End of the range (exclusive, before clipping).
 * @param block_size        Step between consecutive block starts.
 * @param input_block_count Number of entries already stored.
 * @param kv_seqlen         Upper bound used to clip the range.
 * @return The number of entries stored after appending.
 *
 * Also callable from the host so the logic can be checked on the CPU.
 */
__host__ __device__ int64_t save_blocks(int* block_offset, int64_t range_start,
                                        int64_t range_end, int64_t block_size,
                                        int64_t input_block_count,
                                        int64_t kv_seqlen) {
  // Nothing to add when the range starts past the sequence; a non-positive
  // step would never terminate, so treat it as "nothing to add" as well.
  if (range_start >= kv_seqlen || block_size <= 0) {
    return input_block_count;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  int64_t current_block_count = input_block_count;
  // Iterate in 64-bit to match the parameter types; only the stored value is
  // narrowed to the 32-bit element type of the output buffer.
  for (int64_t idx = range_start; idx < range_end; idx += block_size) {
    block_offset[current_block_count++] = static_cast<int>(idx);
  }
  return current_block_count;
}
/**
 * Builds, for every row of query blocks, the merged slash block ranges and
 * the list of vertical columns that fall outside those ranges.
 *
 * Launch layout: blockIdx.x = head, blockIdx.y = batch, and
 * blockIdx.z * blockDim.x + threadIdx.x = index of the query row block.
 * Each thread handles one (batch, head, row block) triple on its own; no
 * shared memory or synchronization is used.
 *
 * Preconditions visible from the code:
 *  - vertical_indexes[0] and slash_indexes[0] are read unconditionally, so
 *    NNZ_V >= 1 and NNZ_S >= 1 are required.
 *  - All buffers are indexed as dense row-major arrays with the shapes given
 *    next to each parameter.
 *  - NOTE(review): the scan order suggests vertical indices are expected in
 *    ascending and slash indices in descending order -- confirm with callers.
 *
 * Rows with start_m >= q_seqlen are left untouched.
 */
__global__ void convert_vertical_slash_indexes_kernel(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  // The grid's z extent is rounded up to a multiple of blockDim.x, so the
  // trailing threads may map to row blocks that do not exist in the output
  // buffers; they must not write anything.
  if (start_m >= q_seqlen || block_idx_m >= N_ROWS) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;

  // Move all pointers to the slice owned by this (batch, head, row block).
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];

  // Skip slash entries that lie at or beyond the end of this row block, then
  // turn the first remaining one into the end position of a key range.
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    // No usable slash entry: start with a range placed past the keys this
    // row can use, which save_blocks() later clips against kv_seqlen.
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      // Vertical entries before the current range become stand-alone
      // columns; entries inside the range are already covered by it.
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        // Sentinel past every range so the slash branch runs from now on.
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if no more slash
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // if slash_finished but there are vertical left, save current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          // The next slash range is disjoint from the current one: flush the
          // current range and start a new one.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          // Overlapping/adjacent range: extend the current one.
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
/**
 * Launches convert_vertical_slash_indexes_kernel with 64 threads per block.
 *
 * Grid layout: x = heads, y = batch, z = ceil(N_ROWS / 64). All pointers must
 * be device pointers laid out as described next to each parameter. The
 * kernel is enqueued on PyTorch's current CUDA stream and the call returns
 * without synchronizing; launch-configuration errors are reported through
 * TORCH_CHECK.
 */
void convert_vertical_slash_indexes_64x64(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  // A zero-sized grid dimension is an invalid launch configuration, so an
  // empty problem is handled by simply doing nothing.
  if (BATCH_SIZE <= 0 || N_HEADS <= 0 || N_ROWS <= 0) {
    return;
  }

  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  // Run on the stream PyTorch is currently using so the kernel is ordered
  // with the surrounding tensor operations.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock, 0, stream>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
      block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
      BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "convert_vertical_slash_indexes_kernel launch failed: ",
              cudaGetErrorString(launch_err));
}
/**
 * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
 *
 * This function builds the index of each row of blocks from vertical indices
 * and slash indices. The vertical indices are treated as points, while the
 * slash indices are converted as ranges. The output consists of the merged
 * ranges and separate column indices, where the ranges are represented by
 * block indices.
 *
 * The implementation is referenced from the original MInference repo:
 * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
 *
 * All tensors must be contiguous int32 CUDA tensors on the same device. The
 * four output tensors are written in place; the number of row blocks is
 * ceil(context_size / block_size_M). The kernel is only enqueued; no
 * synchronization is performed here.
 */
void convert_vertical_slash_indexes(
    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,         // [BATCH, ]
    torch::Tensor kv_seqlens,        // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal) {
  // The kernel indexes raw pointers as dense arrays, so reject anything that
  // is not a contiguous CUDA tensor on the same device as q_seqlens.
  for (const torch::Tensor* t :
       {&block_count, &block_offset, &column_count, &column_index, &q_seqlens,
        &kv_seqlens, &vertical_indexes, &slash_indexes}) {
    TORCH_CHECK(t->is_cuda() && t->device() == q_seqlens.device(),
                "convert_vertical_slash_indexes: all tensors must be CUDA "
                "tensors on the same device");
    TORCH_CHECK(t->is_contiguous(),
                "convert_vertical_slash_indexes: all tensors must be "
                "contiguous");
  }
  TORCH_CHECK(vertical_indexes.dim() == 3 && slash_indexes.dim() == 3,
              "vertical_indexes and slash_indexes must be 3-D");
  TORCH_CHECK(block_size_M > 0 && block_size_N > 0,
              "block_size_M and block_size_N must be positive");
  TORCH_CHECK(context_size >= 0, "context_size must be non-negative");

  // Scoped guard: selects the tensors' device and restores the previous one
  // on exit.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(q_seqlens));

  const int64_t batch_size = slash_indexes.size(0);
  const int64_t num_heads = slash_indexes.size(1);
  const int64_t nnz_slash = slash_indexes.size(2);
  const int64_t nnz_vertical = vertical_indexes.size(2);
  const int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
  const int64_t total_rows = batch_size * num_heads * num_rows;
  if (total_rows == 0) {
    return;  // nothing to compute
  }

  TORCH_CHECK(vertical_indexes.size(0) == batch_size &&
                  vertical_indexes.size(1) == num_heads,
              "vertical_indexes and slash_indexes must agree on batch and "
              "head dimensions");
  // The kernel always reads the first vertical and the first slash entry.
  TORCH_CHECK(nnz_vertical > 0 && nnz_slash > 0,
              "vertical_indexes and slash_indexes must not be empty");
  TORCH_CHECK(q_seqlens.numel() >= batch_size &&
                  kv_seqlens.numel() >= batch_size,
              "q_seqlens and kv_seqlens must have one entry per batch item");
  TORCH_CHECK(block_count.numel() >= total_rows &&
                  column_count.numel() >= total_rows &&
                  block_offset.numel() >= total_rows * nnz_slash &&
                  column_index.numel() >= total_rows * nnz_vertical,
              "output tensors are too small for the requested problem size");

  convert_vertical_slash_indexes_64x64(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
      column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
      num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
      causal);
}
/**
 * Variant of the vertical/slash index conversion in which every head uses
 * its own number of vertical and slash entries.
 *
 * Launch layout: blockIdx.x = head, blockIdx.y = batch, and
 * blockIdx.z * blockDim.x + threadIdx.x = index of the query row block.
 * Each thread handles one (batch, head, row block) triple on its own; no
 * shared memory or synchronization is used.
 *
 * The NNZ_V / NNZ_S arguments are the buffer strides used for addressing;
 * the number of entries actually scanned for a head is taken from
 * per_head_vertical_topkv[head_idx] / per_head_slash_topkv[head_idx].
 * NOTE(review): those per-head counts are presumably <= the buffer sizes;
 * this is not checked here.
 *
 * vertical_indexes[0] and slash_indexes[0] are read unconditionally, so both
 * buffers need at least one entry. Rows with start_m >= q_seqlen are left
 * untouched.
 */
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  // The grid's z extent is rounded up to a multiple of blockDim.x, so the
  // trailing threads may map to row blocks that do not exist in the output
  // buffers; they must not write anything.
  if (start_m >= q_seqlen || block_idx_m >= N_ROWS) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;

  // Move all pointers to the slice owned by this (batch, head, row block).
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  // MergeHead: each head has its own top-k counts for NNZ_V and NNZ_S. (The
  // NNZ_V / NNZ_S values used above are the buffer sizes, needed to compute
  // the offsets.)
  NNZ_S = per_head_slash_topkv[head_idx];
  NNZ_V = per_head_vertical_topkv[head_idx];

  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];

  // Skip slash entries that lie at or beyond the end of this row block, then
  // turn the first remaining one into the end position of a key range.
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    // No usable slash entry: start with a range placed past the keys this
    // row can use, which save_blocks() later clips against kv_seqlen.
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      // Vertical entries before the current range become stand-alone
      // columns; entries inside the range are already covered by it.
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        // Sentinel past every range so the slash branch runs from now on.
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if no more slash
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // if slash_finished but there are vertical left, save current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          // The next slash range is disjoint from the current one: flush the
          // current range and start a new one.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          // Overlapping/adjacent range: extend the current one.
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
/**
 * Launches convert_vertical_slash_indexes_kernel_mergehead with 64 threads
 * per block.
 *
 * Grid layout: x = heads, y = batch, z = ceil(N_ROWS / 64). All pointers must
 * be device pointers; the two per-head arrays are indexed by head. The kernel
 * is enqueued on PyTorch's current CUDA stream and the call returns without
 * synchronizing; launch-configuration errors are reported through
 * TORCH_CHECK.
 */
void convert_vertical_slash_indexes_64x64_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* per_head_vertical_topkv, int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  // A zero-sized grid dimension is an invalid launch configuration, so an
  // empty problem is handled by simply doing nothing.
  if (BATCH_SIZE <= 0 || N_HEADS <= 0 || N_ROWS <= 0) {
    return;
  }

  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  // Run on the stream PyTorch is currently using so the kernel is ordered
  // with the surrounding tensor operations.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock, 0,
                                                    stream>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
      per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
      NNZ_V, NNZ_S, causal);
  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "convert_vertical_slash_indexes_kernel_mergehead launch "
              "failed: ",
              cudaGetErrorString(launch_err));
}
/**
 * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
 *
 * Like the above convert_vertical_slash_indexes, but with
 * pre-computed vertical and slash counts.
 *
 * `vertical_indices_count` and `slash_indices_count` hold one entry per head
 * and give the number of vertical / slash indices the kernel scans for that
 * head. All tensors must be contiguous int32 CUDA tensors on the same
 * device. The four output tensors are written in place; the kernel is only
 * enqueued and no synchronization is performed here.
 */
void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,            // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,           // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,               // [BATCH, ]
    torch::Tensor kv_seqlens,              // [BATCH, ]
    torch::Tensor vertical_indexes,        // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,           // [BATCH, N_HEADS, NNZ_S]
    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
    torch::Tensor slash_indices_count,     // [N_HEADS, ]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal) {
  // The kernel indexes raw pointers as dense arrays, so reject anything that
  // is not a contiguous CUDA tensor on the same device as q_seqlens.
  for (const torch::Tensor* t :
       {&block_count, &block_offset, &column_count, &column_index, &q_seqlens,
        &kv_seqlens, &vertical_indexes, &slash_indexes,
        &vertical_indices_count, &slash_indices_count}) {
    TORCH_CHECK(t->is_cuda() && t->device() == q_seqlens.device(),
                "convert_vertical_slash_indexes_mergehead: all tensors must "
                "be CUDA tensors on the same device");
    TORCH_CHECK(t->is_contiguous(),
                "convert_vertical_slash_indexes_mergehead: all tensors must "
                "be contiguous");
  }
  TORCH_CHECK(vertical_indexes.dim() == 3 && slash_indexes.dim() == 3,
              "vertical_indexes and slash_indexes must be 3-D");
  TORCH_CHECK(block_size_M > 0 && block_size_N > 0,
              "block_size_M and block_size_N must be positive");
  TORCH_CHECK(context_size >= 0, "context_size must be non-negative");

  // Scoped guard: selects the tensors' device and restores the previous one
  // on exit.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(q_seqlens));

  const int64_t batch_size = slash_indexes.size(0);
  const int64_t num_heads = slash_indexes.size(1);
  const int64_t nnz_slash = slash_indexes.size(2);
  const int64_t nnz_vertical = vertical_indexes.size(2);
  const int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
  const int64_t total_rows = batch_size * num_heads * num_rows;
  if (total_rows == 0) {
    return;  // nothing to compute
  }

  TORCH_CHECK(vertical_indexes.size(0) == batch_size &&
                  vertical_indexes.size(1) == num_heads,
              "vertical_indexes and slash_indexes must agree on batch and "
              "head dimensions");
  // The kernel always reads the first vertical and the first slash entry.
  TORCH_CHECK(nnz_vertical > 0 && nnz_slash > 0,
              "vertical_indexes and slash_indexes must not be empty");
  TORCH_CHECK(q_seqlens.numel() >= batch_size &&
                  kv_seqlens.numel() >= batch_size,
              "q_seqlens and kv_seqlens must have one entry per batch item");
  // The kernel reads one count per head from each of these tensors.
  TORCH_CHECK(vertical_indices_count.numel() >= num_heads &&
                  slash_indices_count.numel() >= num_heads,
              "vertical_indices_count and slash_indices_count must have one "
              "entry per head");
  TORCH_CHECK(block_count.numel() >= total_rows &&
                  column_count.numel() >= total_rows &&
                  block_offset.numel() >= total_rows * nnz_slash &&
                  column_index.numel() >= total_rows * nnz_vertical,
              "output tensors are too small for the requested problem size");

  convert_vertical_slash_indexes_64x64_mergehead(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      vertical_indices_count.data_ptr<int>(),
      slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
      block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}

View File

@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
// Converts per-head vertical and slash indices into per-row block ranges
// (block_count / block_offset) and stand-alone column lists (column_count /
// column_index). The first four tensors are outputs written in place.
// The number of row blocks is ceil(context_size / block_size_M).
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal);
// Same conversion as convert_vertical_slash_indexes, but the number of
// vertical / slash indices used for each head is read from
// vertical_indices_count / slash_indices_count (one entry per head).
// The first four tensors are outputs written in place.
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
// slash_indices_count: [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal);
#endif #endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,

View File

@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output," " Tensor suffix_output,"
" Tensor suffix_lse) -> ()"); " Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
// Convert vertical/slash sparse indices into per-row block and column
// indices. The first four tensors are outputs written in place. The sixth
// argument is the key/value sequence length tensor, so it is named
// kv_seqlens (not a second q_seqlens) to match the C++ signature.
ops.def(
    "convert_vertical_slash_indexes("
    " Tensor! block_count, Tensor! block_offset, "
    " Tensor! column_count, Tensor! column_index, "
    " Tensor q_seqlens, Tensor kv_seqlens, "
    " Tensor vertical_indexes, Tensor slash_indexes, "
    " int context_size, int block_size_M, int block_size_N, "
    " bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
         &convert_vertical_slash_indexes);
// Same conversion with per-head vertical/slash index counts. The first four
// tensors are outputs written in place. The sixth argument is the key/value
// sequence length tensor, so it is named kv_seqlens (not a second q_seqlens)
// to match the C++ signature.
ops.def(
    "convert_vertical_slash_indexes_mergehead("
    " Tensor! block_count, Tensor! block_offset, "
    " Tensor! column_count, Tensor! column_index, "
    " Tensor q_seqlens, Tensor kv_seqlens, "
    " Tensor vertical_indexes, Tensor slash_indexes, "
    " Tensor vertical_indices_count, Tensor slash_indices_count, "
    " int context_size, int block_size_M, int block_size_N, "
    " bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
         &convert_vertical_slash_indexes_mergehead);
#endif #endif
// Activation ops // Activation ops

View File

@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
import os
from urllib.request import urlopen
from vllm import LLM, SamplingParams
os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
def load_prompt() -> str:
    """Download the 600k test prompt and return it as UTF-8 decoded text.

    The request uses a 5 second timeout; network and decoding errors
    propagate to the caller.
    """
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
    prompt_url = ("https://qianwen-res.oss-cn-beijing.aliyuncs.com"
                  "/Qwen2.5-1M/test-data/600k.txt")
    with urlopen(prompt_url, timeout=5) as response:
        return response.read().decode('utf-8')
# Processing the prompt.
def process_requests(llm: LLM, prompts: list[str]) -> None:
    """Generate a completion for every prompt and print a short summary.

    For each output, prints the number of prompt tokens and the repr of the
    first generated text.
    """
    # Sampling configuration used for all prompts.
    sampling_kwargs = dict(
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        repetition_penalty=1.05,
        detokenize=True,
        max_tokens=256,
    )
    sampling_params = SamplingParams(**sampling_kwargs)

    # Run generation and report each result.
    for output in llm.generate(prompts, sampling_params):
        prompt_token_ids = output.prompt_token_ids
        generated_text = output.outputs[0].text
        print(f"Prompt length: {len(prompt_token_ids)}, "
              f"Generated text: {generated_text!r}")
# Create an LLM.
def initialize_engine() -> LLM:
    """Build the LLM used by this example.

    Uses the Qwen2.5-7B-Instruct-1M checkpoint with a 1,048,576-token context,
    4-way tensor parallelism, eager execution and chunked prefill.
    """
    engine_kwargs = dict(
        model="Qwen/Qwen2.5-7B-Instruct-1M",
        max_model_len=1048576,
        tensor_parallel_size=4,
        enforce_eager=True,
        enable_chunked_prefill=True,
        max_num_batched_tokens=131072,
    )
    return LLM(**engine_kwargs)
def main():
    """Create the engine, fetch the sample prompt and run generation on it."""
    engine = initialize_engine()
    # The engine is created before the prompt is downloaded.
    process_requests(engine, [load_prompt()])


if __name__ == '__main__':
    main()

View File

@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor,
prefix_lse, suffix_output, suffix_lse) prefix_lse, suffix_output, suffix_lse)
def convert_vertical_slash_indexes(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.zeros(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
block_offset = torch.zeros(batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_count = torch.zeros(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_index = torch.zeros(batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
torch.ops._C.convert_vertical_slash_indexes(
block_count, block_offset, column_count, column_index, q_seqlens,
kv_seqlens, vertical_indexes, slash_indexes, context_size,
block_size_M, block_size_N, causal)
return block_count, block_offset, column_count, column_index
def convert_vertical_slash_indexes_mergehead(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
# [N_HEADS] : different head use different number of indices
vertical_indices_count: torch.Tensor,
slash_indices_count: torch.Tensor,
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.empty(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
block_offset = torch.empty(batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_count = torch.empty(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_index = torch.empty(batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
torch.ops._C.convert_vertical_slash_indexes_mergehead(
block_count, block_offset, column_count, column_index, q_seqlens,
kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
slash_indices_count, context_size, block_size_M, block_size_N, causal)
return block_count, block_offset, column_count, column_index
# pos encoding ops # pos encoding ops
def rotary_embedding( def rotary_embedding(
positions: torch.Tensor, positions: torch.Tensor,

File diff suppressed because it is too large Load Diff

View File

@ -929,6 +929,23 @@ class ModelConfig:
"Number of experts in the model must be greater than 0 " "Number of experts in the model must be greater than 0 "
"when expert parallelism is enabled.") "when expert parallelism is enabled.")
def verify_dual_chunk_attention_config(
    self,
    load_config: "LoadConfig",
) -> None:
    """Attach the sparse attention config to the dual chunk attention config.

    Does nothing when the HF config has no ``dual_chunk_attention_config`` or
    when no sparse attention config can be loaded. Otherwise stores the loaded
    config under ``"sparse_attention_config"`` and sets
    ``"sparse_attention_enabled"`` to True unless that key is already present.
    """
    if not hasattr(self.hf_config, "dual_chunk_attention_config"):
        return

    # Try loading the sparse attention config
    from vllm.model_executor.model_loader.weight_utils import (
        get_sparse_attention_config)
    sparse_attn_config = get_sparse_attention_config(self, load_config)
    if not sparse_attn_config:
        return

    dca_config = self.hf_config.dual_chunk_attention_config
    dca_config["sparse_attention_config"] = sparse_attn_config
    # Respect an explicit user/model setting; only default to enabled.
    if "sparse_attention_enabled" not in dca_config:
        dca_config["sparse_attention_enabled"] = True
def verify_async_output_proc(self, parallel_config, speculative_config, def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None: device_config) -> None:
if not self.use_async_output_proc: if not self.use_async_output_proc:
@ -4187,6 +4204,8 @@ class VllmConfig:
self.speculative_config, self.speculative_config,
self.device_config) self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.model_config.verify_dual_chunk_attention_config(
self.load_config)
if self.cache_config is not None: if self.cache_config is not None:
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)

View File

@ -37,8 +37,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build, from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
is_in_ray_actor) GiB_bytes, is_in_doc_build, is_in_ray_actor)
# yapf: enable # yapf: enable
@ -983,6 +983,17 @@ class EngineArgs:
assert self.enable_chunked_prefill is not None assert self.enable_chunked_prefill is not None
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
assert self.enforce_eager, (
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), (
"DualChunkFlashAttention is only supported on CUDA platform.")
assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,

View File

@ -1486,6 +1486,184 @@ class MRotaryEmbedding(RotaryEmbedding):
return updates return updates
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
chunk_size: int,
local_size: int,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.chunk_size = chunk_size
self.local_size = local_size
self.dtype = dtype
self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
q_inter_cache) = self._compute_cos_sin_cache()
self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
self.register_buffer("cos_sin_qc_no_clamp_cache",
qc_no_clamp_cache,
persistent=False)
self.register_buffer("cos_sin_q_inter_cache",
q_inter_cache,
persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
chunk_len = self.chunk_size - self.local_size
q_t = torch.arange(chunk_len, dtype=torch.float)
qc_t = (torch.arange(chunk_len, dtype=torch.float) +
chunk_len).clamp(max=self.chunk_size)
k_t = torch.arange(self.max_position_embeddings,
dtype=torch.float) % chunk_len
# count from chunk_len, no clamp(self.chunk_size) restriction
qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
# count from self.chunk_size for q_inter's rope
q_inter_t = torch.arange(chunk_len,
dtype=torch.float) + self.chunk_size
q_freqs = torch.outer(q_t, inv_freq)
qc_freqs = torch.outer(qc_t, inv_freq)
k_freqs = torch.outer(k_t, inv_freq)
qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.outer(q_inter_t, inv_freq)
q_cos = q_freqs.cos()
q_sin = q_freqs.sin()
qc_cos = qc_freqs.cos()
qc_sin = qc_freqs.sin()
k_cos = k_freqs.cos()
k_sin = k_freqs.sin()
qc_no_clamp_cos = qc_no_clamp_freqs.cos()
qc_no_clamp_sin = qc_no_clamp_freqs.sin()
q_inter_cos = q_inter_freqs.cos()
q_inter_sin = q_inter_freqs.sin()
q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
q_inter_cache = torch.cat((q_inter_cos, q_inter_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
else:
query_pass = None
key_pass = None
positions_with_offsets = (torch.add(positions, offsets)
if offsets is not None else positions)
key = self._apply_rotary_embedding(
self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
chunk_len = self.chunk_size - self.local_size
query = self._apply_rotary_embedding(
self.cos_sin_q_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_succ = self._apply_rotary_embedding(
self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter = self._apply_rotary_embedding(
self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
query_rot, query_pass)
query_succ_critical = self._apply_rotary_embedding(
self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter_critical = self._apply_rotary_embedding(
self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
# merge query into one tensor to simplify the interfaces
query = torch.cat((
query,
query_succ,
query_inter,
query_succ_critical,
query_inter_critical,
),
dim=-1)
return query, key
def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
if self.rotary_dim < self.head_size:
hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
else:
hidden = hidden_rot
return hidden.flatten(-2).squeeze(0)
def extra_repr(self) -> str:
    """Return a comma-separated ``name=value`` summary of this module's
    rotary-embedding settings, in a fixed field order."""
    fields = (
        ("head_size", self.head_size),
        ("rotary_dim", self.rotary_dim),
        ("max_position_embeddings", self.max_position_embeddings),
        ("base", self.base),
        ("is_neox_style", self.is_neox_style),
        ("chunk_size", self.chunk_size),
        ("local_size", self.local_size),
    )
    return ", ".join(f"{name}={value}" for name, value in fields)
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
@ -1498,6 +1676,7 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0, partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
if dtype is None: if dtype is None:
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
@ -1510,14 +1689,35 @@ def get_rope(
rope_scaling_args = tuple(rope_scaling_tuple.items()) rope_scaling_args = tuple(rope_scaling_tuple.items())
else: else:
rope_scaling_args = None rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0: if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor) rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (head_size, rotary_dim, max_position, base, is_neox_style, key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype) rope_scaling_args, dual_chunk_attention_args, dtype)
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
if not rope_scaling: if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype,
**extra_kwargs)
elif not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype) is_neox_style, dtype)
else: else:

View File

@ -217,6 +217,39 @@ def get_quant_config(model_config: ModelConfig,
return quant_cls.from_config(config) return quant_cls.from_config(config)
def get_sparse_attention_config(
    model_config: ModelConfig,
    load_config: LoadConfig,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Load the optional sparse-attention config that ships with a model.

    Args:
        model_config: Supplies ``model`` (a local directory or a hub model
            id) and ``revision``.
        load_config: Supplies ``download_dir``, used as the lock location
            and cache directory when the model is not a local directory.
        sparse_attention_config_filename: Name of the JSON file to look
            for inside the model folder.

    Returns:
        The parsed JSON content, or an empty dict when the model folder
        contains no such file.
    """
    model_name_or_path = model_config.model

    if os.path.isdir(model_name_or_path):
        hf_folder = model_name_or_path
    else:
        # Not a local directory: fetch the model's JSON files from the hub.
        # The lock presumably serializes concurrent downloads of the same
        # model -- see get_lock.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )

    config_file = os.path.join(hf_folder, sparse_attention_config_filename)
    # isfile (not exists): a directory with this name cannot be opened and
    # is treated the same as a missing config.
    if not os.path.isfile(config_file):
        return {}

    # JSON is UTF-8 by specification; do not depend on the locale's
    # default encoding when reading it.
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)
    logger.info("Loaded sparse attention config from %s", config_file)

    return config
def download_weights_from_hf( def download_weights_from_hf(
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str], cache_dir: Optional[str],

View File

@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from typing import Any, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter, extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
@ -99,17 +99,20 @@ class Qwen2MLP(nn.Module):
class Qwen2Attention(nn.Module): class Qwen2Attention(nn.Module):
def __init__(self, def __init__(
hidden_size: int, self,
num_heads: int, hidden_size: int,
num_kv_heads: int, num_heads: int,
max_position: int = 4096 * 32, num_kv_heads: int,
rope_theta: float = 10000, max_position: int = 4096 * 32,
cache_config: Optional[CacheConfig] = None, rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None,
rope_scaling: Optional[Tuple] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", rope_scaling: Optional[Tuple] = None,
attn_type: str = AttentionType.DECODER) -> None: prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str,
Any]] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
@ -131,6 +134,7 @@ class Qwen2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
@ -155,15 +159,21 @@ class Qwen2Attention(nn.Module):
max_position=max_position, max_position=max_position,
base=self.rope_theta, base=self.rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
num_kv_heads=self.num_kv_heads, self.scaling,
cache_config=cache_config, num_kv_heads=self.num_kv_heads,
quant_config=quant_config, cache_config=cache_config,
prefix=f"{prefix}.attn", quant_config=quant_config,
attn_type=attn_type) attn_type=attn_type,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
def forward( def forward(
self, self,
@ -192,6 +202,9 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000) rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen2 uses causal attention as it is a decoder-only model. # By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable # You can override the HF config with `is_causal=False` to enable
@ -213,6 +226,7 @@ class Qwen2DecoderLayer(nn.Module):
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_type=attn_type, attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
) )
self.mlp = Qwen2MLP( self.mlp = Qwen2MLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,

View File

@ -175,6 +175,7 @@ class Qwen2MoeAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -198,6 +199,7 @@ class Qwen2MoeAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
@ -221,14 +223,20 @@ class Qwen2MoeAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
num_kv_heads=self.num_kv_heads, self.scaling,
cache_config=cache_config, num_kv_heads=self.num_kv_heads,
quant_config=quant_config, cache_config=cache_config,
prefix=f"{prefix}.attn") quant_config=quant_config,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
def forward( def forward(
self, self,
@ -256,6 +264,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
8192) 8192)
self.self_attn = Qwen2MoeAttention( self.self_attn = Qwen2MoeAttention(
@ -268,6 +279,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
dual_chunk_attention_config=dual_chunk_attention_config,
) )
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have # Note: Qwen/Qwen2-57B-A14B-Instruct does not have

View File

@ -222,6 +222,10 @@ class CudaPlatformBase(Platform):
elif selected_backend == _Backend.XFORMERS: elif selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.") logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend" return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN: elif selected_backend == _Backend.FLASH_ATTN:
pass pass
elif selected_backend: elif selected_backend:

View File

@ -51,6 +51,7 @@ class _Backend(enum.Enum):
PALLAS_VLLM_V1 = enum.auto() PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto() IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto() NO_ATTENTION = enum.auto()

View File

@ -153,6 +153,7 @@ STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
GB_bytes = 1_000_000_000 GB_bytes = 1_000_000_000

View File

@ -204,6 +204,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = None # type: ignore self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore
self.prompt_lens[0] = 0 # type: ignore
self.query_lens[0] = 0 # type: ignore self.query_lens[0] = 0 # type: ignore
self.context_lens[0] = 0 # type: ignore self.context_lens[0] = 0 # type: ignore
self.curr_sliding_window_blocks[0] = 0 # type: ignore self.curr_sliding_window_blocks[0] = 0 # type: ignore
@ -236,6 +237,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# The original sequence length (before applying sliding window). # The original sequence length (before applying sliding window).
# This is used to compute slot mapping. # This is used to compute slot mapping.
orig_seq_lens: Optional[List[int]] = None, orig_seq_lens: Optional[List[int]] = None,
# This is used in the dual-chunk flash attention backend.
prompt_lens: Optional[List[int]] = None,
# The query length. # The query length.
query_lens: Optional[List[int]] = None, query_lens: Optional[List[int]] = None,
# The number of tokens that are already computed. # The number of tokens that are already computed.
@ -316,6 +319,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)): for seq_id in range(len(self.seq_ids)):
self.orig_seq_lens[seq_id] = 0 self.orig_seq_lens[seq_id] = 0
if prompt_lens:
self.prompt_lens = prompt_lens
else:
for seq_id in range(len(self.seq_ids)):
self.prompt_lens[seq_id] = 0
if query_lens: if query_lens:
self.query_lens = query_lens self.query_lens = query_lens
else: else:
@ -370,6 +379,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = mrope_input_positions or None self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or [] self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or [] self.orig_seq_lens = orig_seq_lens or []
self.prompt_lens = prompt_lens or []
self.query_lens = query_lens or [] self.query_lens = query_lens or []
self.context_lens = context_lens or [] self.context_lens = context_lens or []
self.curr_sliding_window_blocks = \ self.curr_sliding_window_blocks = \
@ -403,6 +413,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = None self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs
self.prompt_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs self.query_lens = [0] * self.n_seqs
self.context_lens = [0] * self.n_seqs self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs self.curr_sliding_window_blocks = [0] * self.n_seqs
@ -552,6 +563,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds inter_data.inputs_embeds = prompt_embeds