[Kernel][Amd] Add fp8 kv cache support for rocm custom paged attention (#8577)
Parent: 76515f303b
Commit: 9cc373f390
@@ -18,8 +18,11 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <hip/hip_bf16.h>
+#include "cuda_compat.h"
 
 #include <algorithm>
+#include "../attention/dtype_fp8.cuh"
+#include "../quantization/fp8/amd/quant_utils.cuh"
 
 #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
                            defined(__gfx941__) || defined(__gfx942__))
@@ -38,7 +41,6 @@
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
-#define WARP_SIZE 64
 
 #if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
 
@@ -60,6 +62,8 @@ typedef struct _B16x8 {
   _B16x4 xy[2];
 } _B16x8;
 
+using _B8x8 = uint2;
+
 ////// Non temporal load stores ///////
 
 template <typename T>
@@ -168,17 +172,39 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
   }
 }
 
+template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
+__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
+                                                      const float scale) {
+  union alignas(16) {
+    uint4 u4;
+    _B16x8 u16x8;
+    vllm::bf16_8_t b16x8;
+  } tmp;
+  if constexpr (std::is_same<T, _Float16>::value) {
+    tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
+    return tmp.u16x8;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(
+        input, scale);
+    return tmp.u16x8;
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
 ///////////////////////////////////////
 
 // grid (num_seqs, num_partitions,num_heads/gqa_ratio)
 // block (partition size)
-template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
+template <typename scalar_t, typename cache_t,
+          vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
+          int NUM_THREADS,
           int GQA_RATIO>
 __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
-    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads,
                                            // head_size/x, block_size, x]
-    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads,
                                            // head_size, block_size]
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
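Note on the helper added above: `scaled_convert_b8x8` widens eight packed fp8 cache bytes to the 16-bit query dtype and applies a per-tensor scale via `vllm::fp8::scaled_convert`. The Python sketch below is only an illustration of that element-level semantics and is not part of the commit; the exact fp8 encoding used on ROCm hardware (e4m3 vs. e4m3fnuz) is handled inside quant_utils.cuh, while the sketch assumes PyTorch's float8_e4m3fn dtype for simplicity.

# Illustrative sketch only (assumes PyTorch >= 2.1 for float8 dtypes).
import torch

def scaled_convert_b8x8(packed_fp8: torch.Tensor, scale: float,
                        out_dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """Widen 8 fp8 cache bytes to fp16/bf16 and apply the per-tensor scale."""
    assert packed_fp8.dtype == torch.float8_e4m3fn and packed_fp8.numel() == 8
    return packed_fp8.to(out_dtype) * scale

vals = torch.randn(8, dtype=torch.float16)
quantized = (vals / 2.0).to(torch.float8_e4m3fn)  # pretend k_scale = 2.0
print(scaled_convert_b8x8(quantized, 2.0))        # roughly recovers vals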
@@ -192,10 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     scalar_t* __restrict__ out,    // [num_seqs, num_heads, max_num_partitions,
                                    // head_size]
     scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
-#if 0
-    scalar_t* __restrict__ qk_out,  // [num_heads, num_seqs, max_ctx_blocks,block_size]
-#endif
-    int max_ctx_blocks) {
+    int max_ctx_blocks, float k_scale, float v_scale) {
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const int warpid = threadIdx.x / WARP_SIZE;
   const int laneid = threadIdx.x % WARP_SIZE;
@@ -222,12 +245,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
   constexpr int x = 16 / sizeof(scalar_t);
   constexpr int KHELOOP = HEAD_SIZE / x;
   _B16x8 Klocal[KHELOOP];
+  _B8x8 Klocalb8[KHELOOP];
   constexpr int VHELOOP =
       HEAD_SIZE /
       WARP_SIZE;  // v head_size dimension is distributed across lanes
   constexpr int VTLOOP = 8;  // 16 separate 4xtokens across warp -> 16/2
                              // 8xtokens
   _B16x8 Vlocal[VHELOOP][VTLOOP];
+  _B8x8 Vlocalb8[VHELOOP][VTLOOP];
   floatx4 dout[QHLOOP];
   float qk_max[QHLOOP];
 #pragma unroll
@@ -279,6 +304,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
         (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
     vphysical_blocks[b] = block_table[vblock_idx_ctx];
   }
+
   // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
   const scalar_t* q_ptr =
       q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
@@ -298,18 +324,30 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     Qlocal[QHLOOP - 1].xy[1] = {0};
   }
 
-  const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
+  const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
                           wg_start_kv_head_idx * kv_head_stride;
 
   const int physical_block_offset =
       local_token_idx % BLOCK_SIZE;  // since x=half8, physical_block_offset
                                      // is already cast as _H8
+  if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
     const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
 #pragma unroll
     for (int d = 0; d < KHELOOP; d++) {
       Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
     }
+  } else {
+    constexpr int X = 16 / sizeof(cache_t);
+    const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
+#pragma unroll
+    for (int d = 0; d < KHELOOP; d++) {
+      const int head_elem = d * 8;
+      const int offset1 = head_elem / X;
+      const int offset2 = head_elem % X;
+      const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
+      Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
+    }
+  }
 
   float alibi_slope[QHLOOP];
   if (alibi_slopes != nullptr) {
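The offset arithmetic in the fp8 branch above follows the key-cache layout noted in the kernel signature, [num_blocks, num_kv_heads, head_size/x, block_size, x], where x = 16 / sizeof(cache_t) = 16 for one-byte fp8 entries. The small Python check below (an illustration, not part of the commit) verifies that offset1 * BLOCK_SIZE * X + offset2, applied on top of the per-token offset physical_block_offset * X, lands on the same flat element as indexing the five-dimensional layout directly.

# Sketch: verify the fp8 key-cache addressing against the
# [num_blocks, num_kv_heads, head_size // X, block_size, X] layout.
BLOCK_SIZE, HEAD_SIZE, X = 16, 128, 16  # X = 16 one-byte fp8 elements per group

def flat_index(head_elem: int, token: int) -> int:
    """Flat offset of (head_elem, token) inside one (block, kv_head) slice."""
    return (head_elem // X) * BLOCK_SIZE * X + token * X + head_elem % X

def kernel_index(head_elem: int, token: int) -> int:
    """Offset computed the kernel's way: token*X + offset1*BLOCK_SIZE*X + offset2."""
    offset1, offset2 = head_elem // X, head_elem % X
    return token * X + offset1 * BLOCK_SIZE * X + offset2

assert all(flat_index(e, t) == kernel_index(e, t)
           for e in range(HEAD_SIZE) for t in range(BLOCK_SIZE))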
@@ -322,7 +360,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     }
   }
 
-  const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
+  const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
+  if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
     const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
     // iterate over each v block
 #pragma unroll
@@ -345,6 +384,41 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
         }
       }
     }
+  } else {
+    const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
+    // iterate over each v block
+#pragma unroll
+    for (int b = 0; b < VBLOCKS; b++) {
+      // int32 physical_block_number leads to overflow when multiplied with
+      // kv_block_stride
+      const int64_t vphysical_block_number =
+          static_cast<int64_t>(vphysical_blocks[b]);
+      const _B8x8* v_ptrh8b =
+          v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
+      // iterate over each head elem (within head_size)
+#pragma unroll
+      for (int h = 0; h < VHELOOP; h++) {
+        const int head_size_elem = h * WARP_SIZE + laneid;
+        const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
+        // iterate over all velems within block
+#pragma unroll
+        for (int d = 0; d < BLOCK_SIZE / 8; d++) {
+          // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
+          const _B8x8 Vlocalb8 = v_ptrh8be[d];
+          Vlocal[h][b * BLOCK_SIZE / 8 + d] =
+              scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
+        }
+      }
+    }
+  }
+
+  if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
+#pragma unroll
+    for (int d = 0; d < KHELOOP; d++) {
+      Klocal[d] =
+          scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
+    }
+  }
 
 #pragma unroll
   for (int h = 0; h < QHLOOP; h++) {
@@ -794,13 +868,15 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
 #else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
 
-template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
+template <typename scalar_t, typename cache_t,
+          vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
+          int NUM_THREADS,
           int GQA_RATIO>
 __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
-    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads,
                                            // head_size/x, block_size, x]
-    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads,
                                            // head_size, block_size]
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
@@ -814,10 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     scalar_t* __restrict__ out,    // [num_seqs, num_heads, max_num_partitions,
                                    // head_size]
     scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
-#if 0
-    scalar_t* __restrict__ qk_out,  // [num_heads, num_seqs, max_ctx_blocks,block_size]
-#endif
-    int max_ctx_blocks) {
+    int max_ctx_blocks, float k_scale, float v_scale) {
   UNREACHABLE_CODE
 }
 
@@ -839,26 +912,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 #endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
 
 #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
-  paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
+  paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, \
+                                   NTHR, GQA_RATIO> \
       <<<grid, block, 0, stream>>>( \
           query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
-          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks);
+          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
+          k_scale, v_scale);
 
-template <typename T, int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 256>
+template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
+          int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
 void paged_attention_custom_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, const int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& context_lens,
-    int max_context_len,
-#if 0
-    torch::Tensor& qk_out,
-    torch::Tensor& softmax_out,
-#endif
-    const c10::optional<torch::Tensor>& alibi_slopes) {
-
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    float k_scale, float v_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -878,14 +949,10 @@ void paged_attention_custom_launcher(
   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
+  KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
-#if 0
-  T* qk_out_ptr = reinterpret_cast<T*>(qk_out.data_ptr());
-  T* softmax_out_ptr = reinterpret_cast<T*>(softmax_out.data_ptr());
-#endif
 
   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
   const int max_num_partitions =
@@ -972,32 +1039,32 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \
-  paged_attention_custom_launcher<T, BLK_SIZE, HEAD_SIZE>( \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE>( \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
      num_kv_heads, scale, block_tables, context_lens, max_context_len, \
-      alibi_slopes);
+      alibi_slopes, k_scale, v_scale);
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
   switch (block_size) { \
     case 16: \
-      CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \
+      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
       break; \
     case 32: \
-      CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \
+      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
       break; \
     default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }
 
-#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \
+#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
   switch (head_size) { \
     case 64: \
-      CALL_CUSTOM_LAUNCHER_BLK(T, 64); \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
       break; \
     case 128: \
-      CALL_CUSTOM_LAUNCHER_BLK(T, 128); \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported head size: ", head_size); \
@@ -1020,16 +1087,31 @@ void paged_attention(
     torch::Tensor& context_lens,  // [num_seqs]
     int64_t block_size, int64_t max_context_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype) {
-  assert(kv_cache_dtype == "auto");
+    const std::string& kv_cache_dtype, double k_scale, double v_scale) {
   const int head_size = query.size(2);
+  if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Half) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16);
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
+                                    vllm::Fp8KVCacheDataType::kAuto);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16);
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
+                                    vllm::Fp8KVCacheDataType::kAuto);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
+  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
+    if (query.dtype() == at::ScalarType::Half) {
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
+                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
+                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
+  }
 }
 
 #undef WARP_SIZE
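Summarizing the host-side dispatch in the hunk above: the kv_cache_dtype string now selects both the cache element type and the KV_DTYPE template argument instead of asserting "auto". The Python restatement below is illustrative only; the function name is an assumption, not vLLM API.

# Illustrative mapping of the C++ dispatch above. "auto" keeps the cache in the
# query dtype, while "fp8" / "fp8_e4m3" stores uint8 and dequantizes inside the
# kernel with k_scale / v_scale.
def select_rocm_paged_attention_kernel(kv_cache_dtype: str, query_dtype: str):
    if kv_cache_dtype == "auto":
        return {"cache_t": query_dtype, "kv_dtype": "kAuto"}
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return {"cache_t": "uint8_t", "kv_dtype": "kFp8E4M3"}
    raise ValueError(f"Unsupported KV cache dtype: {kv_cache_dtype}")

print(select_rocm_paged_attention_kernel("fp8", "_Float16"))
# {'cache_t': 'uint8_t', 'kv_dtype': 'kFp8E4M3'}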

@@ -10,4 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
                      torch::Tensor& context_lens, int64_t block_size,
                      int64_t max_context_len,
                      const c10::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype);
+                     const std::string& kv_cache_dtype, double k_scale,
+                     double v_scale);

@@ -26,7 +26,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       "  Tensor context_lens, int block_size,"
       "  int max_context_len,"
       "  Tensor? alibi_slopes,"
-      "  str kv_cache_dtype) -> ()");
+      "  str kv_cache_dtype,"
+      "  float k_scale, float v_scale) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }
 

@@ -31,8 +31,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
-              ] if not is_hip() else [64, 80, 96, 112, 128]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
@@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)
 
 
-@pytest.mark.parametrize("version", ["v1", "v2"])
+@pytest.mark.parametrize(
+    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -137,7 +137,8 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
-    if kv_cache_dtype == "fp8" and head_size % 16:
+    if ((kv_cache_dtype == "fp8" and head_size % 16)
+            or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()
 
     seed_everything(seed)
@@ -206,7 +207,7 @@ def test_paged_attention(
                      kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                     cond=(head_size == HEAD_SIZES[0]))
 
-    elif version == "v2":
+    elif version in ("v2", "rocm"):
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         assert PARTITION_SIZE % block_size == 0
         num_seqs, num_heads, head_size = output.shape
@@ -219,6 +220,7 @@ def test_paged_attention(
             dtype=torch.float32,
         )
         max_logits = torch.empty_like(exp_sums)
+        if version == "v2":
             ops.paged_attention_v2(
                 output,
                 exp_sums,
@@ -240,10 +242,38 @@ def test_paged_attention(
             )
 
             opcheck(torch.ops._C.paged_attention_v2,
-                    (output, exp_sums, max_logits, tmp_output, query, key_cache,
-                     value_cache, num_kv_heads, scale, block_tables, seq_lens,
-                     block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-                     k_scale, v_scale, 0, 0, 0, 64, 0),
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                    cond=(head_size == HEAD_SIZES[0]))
+
+        else:
+            ops.paged_attention_rocm(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._rocm_C.paged_attention,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale),
                     cond=(head_size == HEAD_SIZES[0]))
 
     else:
@@ -328,162 +358,6 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-@pytest.mark.parametrize("version", ["rocm"])
-@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
-@pytest.mark.parametrize("use_alibi", USE_ALIBI)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(not is_hip(), reason="only for rocm")
-def test_paged_attention_rocm(
-    kv_cache_factory,
-    version: str,
-    num_seqs: int,
-    num_heads: Tuple[int, int],
-    head_size: int,
-    use_alibi: bool,
-    block_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: str,
-    seed: int,
-    device: str,
-) -> None:
-    seed_everything(seed)
-    torch.set_default_device(device)
-    scale = float(1.0 / (head_size**0.5))
-    num_query_heads, num_kv_heads = num_heads
-    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
-    query.uniform_(-scale, scale)
-
-    assert num_query_heads % num_kv_heads == 0
-    num_queries_per_kv = num_query_heads // num_kv_heads
-    alibi_slopes = None
-    if use_alibi:
-        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
-
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
-    context_lens[-1] = MAX_SEQ_LEN
-    #context_lens = [8192 for _ in range(num_seqs)]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int)
-    #print('>>> ctx lens', context_lens)
-
-    # Create the block tables.
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
-    block_tables = []
-    for _ in range(num_seqs):
-        block_table = [
-            random.randint(0, NUM_BLOCKS - 1)
-            for _ in range(max_num_blocks_per_seq)
-        ]
-        block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int)
-
-    # Create the KV caches.
-    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
-                                                num_kv_heads, head_size,
-                                                kv_cache_dtype, dtype, seed,
-                                                device)
-    key_cache, value_cache = key_caches[0], value_caches[0]
-
-    # TODO(charlifu) enable fp8 kv cache
-    # Using default kv_scale
-    # kv_scale = 1.0
-
-    # Call the paged attention kernel.
-    output = torch.empty_like(query)
-    PARTITION_SIZE_ROCM = 256
-    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
-                      PARTITION_SIZE_ROCM)
-    assert PARTITION_SIZE_ROCM % block_size == 0
-    num_seqs, num_heads, head_size = output.shape
-    tmp_output = torch.empty(
-        size=(num_seqs, num_heads, num_partitions, head_size),
-        dtype=output.dtype,
-    )
-    exp_sums = torch.empty(
-        size=(num_seqs, num_heads, num_partitions),
-        dtype=torch.float32,
-    )
-    max_logits = torch.empty_like(exp_sums)
-    if version == "rocm":
-        ops.paged_attention_rocm(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            context_lens,
-            block_size,
-            max_context_len,
-            alibi_slopes,
-            kv_cache_dtype,
-        )
-    else:
-        raise AssertionError(f"Unknown version: {version}")
-
-    # Run the reference implementation.
-    if kv_cache_dtype == "fp8":
-        # Convert cache data back to dtype.
-        x = 16 // torch.tensor([], dtype=dtype).element_size()
-        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
-                           block_size, x)
-        dequantized_key_cache = torch.empty(size=key_cache_shape,
-                                            dtype=dtype,
-                                            device=device)
-        ops.convert_fp8(key_cache, dequantized_key_cache)
-        key_cache = dequantized_key_cache
-
-        value_cache_shape = value_cache.shape
-        dequantized_value_cache = torch.empty(size=value_cache_shape,
-                                              dtype=dtype,
-                                              device=device)
-        ops.convert_fp8(value_cache, dequantized_value_cache)
-        value_cache = dequantized_value_cache
-
-    ref_output = torch.empty_like(query)
-    ref_single_query_cached_kv_attention(
-        ref_output,
-        query,
-        num_queries_per_kv,
-        key_cache,
-        value_cache,
-        block_tables,
-        context_lens,
-        scale,
-        alibi_slopes,
-    )
-
-    # NOTE(woosuk): Due to the kernel-level differences in the two
-    # implementations, there is a small numerical difference in the two
-    # outputs. Thus, we use a relaxed tolerance for the test.
-    atol = get_default_atol(output) if is_hip() else 1e-3
-    rtol = get_default_rtol(output) if is_hip() else 1e-5
-
-    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
-    # so we use a relaxed tolerance for the test.
-    atol, rtol = 1e-4, 1e-5
-    if dtype == torch.bfloat16:
-        atol, rtol = 2e-4, 1e-5
-    if use_alibi:
-        if dtype == torch.half:
-            atol, rtol = 5e-4, 1e-5
-        if dtype == torch.bfloat16:
-            atol, rtol = 1e-3, 1e-5
-    if kv_cache_dtype == "fp8":
-        atol, rtol = 1e-2, 1e-5
-    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
-
-
 # TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -491,7 +365,8 @@ def test_paged_attention_rocm(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(is_hip(), reason="skip for rocm")
+@pytest.mark.skipif(is_hip(),
+                    reason="Xformers backend is not supported on ROCm.")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,

@@ -146,12 +146,14 @@ def paged_attention_rocm(
     max_seq_len: int,
     alibi_slopes: Optional[torch.Tensor],
     kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       block_size, max_seq_len, alibi_slopes,
-                                      kv_cache_dtype)
+                                      kv_cache_dtype, k_scale, v_scale)
 
 
 # pos encoding ops
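With the two extra arguments, a call to the updated wrapper looks like the sketch below. It is patterned on the updated test in this change; the tensor shapes are assumed to be prepared exactly as in test_paged_attention, and the helper name is hypothetical.

# Sketch only: wraps the new signature; all tensors are assumed to be laid out
# as in tests/kernels. k_scale / v_scale are per-tensor scales for the fp8 cache.
from typing import Optional
import torch
from vllm import _custom_ops as ops

def run_rocm_paged_attention(output, exp_sums, max_logits, tmp_output, query,
                             key_cache, value_cache, num_kv_heads: int,
                             scale: float, block_tables, seq_lens,
                             block_size: int, max_seq_len: int,
                             alibi_slopes: Optional[torch.Tensor],
                             kv_cache_dtype: str,
                             k_scale: float = 1.0,
                             v_scale: float = 1.0) -> None:
    # The two trailing arguments are the ones added by this change.
    ops.paged_attention_rocm(output, exp_sums, max_logits, tmp_output, query,
                             key_cache, value_cache, num_kv_heads, scale,
                             block_tables, seq_lens, block_size, max_seq_len,
                             alibi_slopes, kv_cache_dtype, k_scale, v_scale)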

@@ -17,8 +17,8 @@ from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
-_PARTITION_SIZE = 256
-ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
+_PARTITION_SIZE_ROCM = 512
+_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
@@ -489,14 +489,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
-            use_custom = use_rocm_custom_paged_attention(
-                decode_query.dtype, head_size, block_size, self.kv_cache_dtype,
-                gqa_ratio, decode_meta.max_decode_seq_len)
+            use_custom = _use_rocm_custom_paged_attention(
+                decode_query.dtype, head_size, block_size, gqa_ratio,
+                decode_meta.max_decode_seq_len)
             if use_custom:
                 max_seq_len = decode_meta.max_decode_seq_len
-                max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
-                                      _PARTITION_SIZE)
-                assert _PARTITION_SIZE % block_size == 0
+                max_num_partitions = (
+                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
+                    _PARTITION_SIZE_ROCM)
+                assert _PARTITION_SIZE_ROCM % block_size == 0
                 tmp_output = torch.empty(
                     size=(num_seqs, num_heads, max_num_partitions, head_size),
                     dtype=output.dtype,
@@ -524,6 +525,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
+                    k_scale,
+                    v_scale,
                 )
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -580,12 +583,11 @@ def _sdpa_attention(
     return output
 
 
-def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
-                                    block_size: int, kv_cache_dtype: str,
-                                    gqa_ratio: int, max_seq_len: int) -> bool:
+def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
+                                     block_size: int, gqa_ratio: int,
+                                     max_seq_len: int) -> bool:
     # rocm custom page attention not support on navi (gfx1*)
-    return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
+    return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and kv_cache_dtype == "auto"
             and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
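End to end, this kernel path is exercised when a model is served on a supported AMD GPU with an fp8 KV cache; as the hunk above shows, the gating helper no longer rejects kv_cache_dtype values other than "auto". The snippet below is a hedged usage sketch of vLLM's offline API; the model name is only an example and a ROCm MI300/MI250 device is assumed.

# Sketch: enable the fp8 KV cache so the ROCm custom paged-attention kernel is
# selected per decode batch by _use_rocm_custom_paged_attention.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf",  # example model (assumption)
          kv_cache_dtype="fp8")              # quantize the KV cache to fp8
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)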