[Kernel][Amd] Add fp8 kv cache support for rocm custom paged attention (#8577)

Charlie Fu 2024-09-19 12:37:57 -05:00 committed by GitHub
parent 76515f303b
commit 9cc373f390
6 changed files with 245 additions and 282 deletions

View File

@@ -18,8 +18,11 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>
+#include "cuda_compat.h"

#include <algorithm>
+#include "../attention/dtype_fp8.cuh"
+#include "../quantization/fp8/amd/quant_utils.cuh"

#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
    defined(__gfx941__) || defined(__gfx942__))
@@ -38,7 +41,6 @@
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
-#define WARP_SIZE 64

#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
@@ -60,6 +62,8 @@ typedef struct _B16x8 {
  _B16x4 xy[2];
} _B16x8;

+using _B8x8 = uint2;

////// Non temporal load stores ///////
template <typename T>
@@ -168,18 +172,40 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
  }
}

+template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
+__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
+                                                      const float scale) {
+  union alignas(16) {
+    uint4 u4;
+    _B16x8 u16x8;
+    vllm::bf16_8_t b16x8;
+  } tmp;
+  if constexpr (std::is_same<T, _Float16>::value) {
+    tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
+    return tmp.u16x8;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(
+        input, scale);
+    return tmp.u16x8;
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+
///////////////////////////////////////
// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
// block (partition size)
-template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
+template <typename scalar_t, typename cache_t,
+          vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
+          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
-    const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                           // head_size/x, block_size, x]
-    const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                           // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
@@ -192,10 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    scalar_t* __restrict__ out,        // [num_seqs, num_heads, max_num_partitions,
                                       // head_size]
    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
-#if 0
-    scalar_t* __restrict__ qk_out,   // [num_heads, num_seqs, max_ctx_blocks,block_size]
-#endif
-    int max_ctx_blocks) {
+    int max_ctx_blocks, float k_scale, float v_scale) {
  constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;
@@ -222,12 +245,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
  constexpr int x = 16 / sizeof(scalar_t);
  constexpr int KHELOOP = HEAD_SIZE / x;
  _B16x8 Klocal[KHELOOP];
+  _B8x8 Klocalb8[KHELOOP];
  constexpr int VHELOOP =
      HEAD_SIZE /
      WARP_SIZE;  // v head_size dimension is distributed across lanes
  constexpr int VTLOOP = 8;  // 16 separate 4xtokens across warp -> 16/2
                             // 8xtokens
  _B16x8 Vlocal[VHELOOP][VTLOOP];
+  _B8x8 Vlocalb8[VHELOOP][VTLOOP];
  floatx4 dout[QHLOOP];
  float qk_max[QHLOOP];
#pragma unroll
@@ -279,6 +304,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
        (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
    vphysical_blocks[b] = block_table[vblock_idx_ctx];
  }

  // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
  const scalar_t* q_ptr =
      q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
@@ -298,17 +324,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    Qlocal[QHLOOP - 1].xy[1] = {0};
  }

-  const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
+  const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
                          wg_start_kv_head_idx * kv_head_stride;

  const int physical_block_offset =
      local_token_idx % BLOCK_SIZE;  // since x=half8, physical_block_offset
                                     // is already cast as _H8
+  if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
    const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
#pragma unroll
    for (int d = 0; d < KHELOOP; d++) {
      Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
+    }
+  } else {
+    constexpr int X = 16 / sizeof(cache_t);
+    const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
+#pragma unroll
+    for (int d = 0; d < KHELOOP; d++) {
+      const int head_elem = d * 8;
+      const int offset1 = head_elem / X;
+      const int offset2 = head_elem % X;
+      const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
+      Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
+    }
  }

  float alibi_slope[QHLOOP];
@@ -322,28 +360,64 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    }
  }

-  const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
-  const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
-  // iterate over each v block
+  const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
+  if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
+    const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
+    // iterate over each v block
#pragma unroll
    for (int b = 0; b < VBLOCKS; b++) {
      // int32 physical_block_number leads to overflow when multiplied with
      // kv_block_stride
      const int64_t vphysical_block_number =
          static_cast<int64_t>(vphysical_blocks[b]);
      const _B16x8* v_ptrh8b =
          v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
      // iterate over each head elem (within head_size)
#pragma unroll
      for (int h = 0; h < VHELOOP; h++) {
        const int head_size_elem = h * WARP_SIZE + laneid;
        const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
        // iterate over all velems within block
#pragma unroll
        for (int d = 0; d < BLOCK_SIZE / 8; d++) {
          Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
+        }
      }
    }
+  } else {
+    const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
+    // iterate over each v block
+#pragma unroll
+    for (int b = 0; b < VBLOCKS; b++) {
+      // int32 physical_block_number leads to overflow when multiplied with
+      // kv_block_stride
+      const int64_t vphysical_block_number =
+          static_cast<int64_t>(vphysical_blocks[b]);
+      const _B8x8* v_ptrh8b =
+          v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
+      // iterate over each head elem (within head_size)
+#pragma unroll
+      for (int h = 0; h < VHELOOP; h++) {
+        const int head_size_elem = h * WARP_SIZE + laneid;
+        const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
+        // iterate over all velems within block
+#pragma unroll
+        for (int d = 0; d < BLOCK_SIZE / 8; d++) {
+          // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
+          const _B8x8 Vlocalb8 = v_ptrh8be[d];
+          Vlocal[h][b * BLOCK_SIZE / 8 + d] =
+              scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
+        }
+      }
+    }
+  }
+
+  if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
+#pragma unroll
+    for (int d = 0; d < KHELOOP; d++) {
+      Klocal[d] =
+          scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
+    }
  }

#pragma unroll
@@ -794,14 +868,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

#else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

-template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
+template <typename scalar_t, typename cache_t,
+          vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
+          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
-    const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                           // head_size/x, block_size, x]
-    const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                           // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
@@ -814,10 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    scalar_t* __restrict__ out,        // [num_seqs, num_heads, max_num_partitions,
                                       // head_size]
    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
-#if 0
-    scalar_t* __restrict__ qk_out,   // [num_heads, num_seqs, max_ctx_blocks,block_size]
-#endif
-    int max_ctx_blocks) {
+    int max_ctx_blocks, float k_scale, float v_scale) {
  UNREACHABLE_CODE
}
@@ -839,26 +912,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
-  paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
+  paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, \
+                                   NTHR, GQA_RATIO> \
      <<<grid, block, 0, stream>>>( \
          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
-          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks);
+          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
+          k_scale, v_scale);

-template <typename T, int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 256>
+template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
+          int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
void paged_attention_custom_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, const int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& context_lens,
-    int max_context_len,
-#if 0
-    torch::Tensor& qk_out,
-    torch::Tensor& softmax_out,
-#endif
-    const c10::optional<torch::Tensor>& alibi_slopes) {
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    float k_scale, float v_scale) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
@@ -878,14 +949,10 @@ void paged_attention_custom_launcher(
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
+  KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();
-#if 0
-  T* qk_out_ptr = reinterpret_cast<T*>(qk_out.data_ptr());
-  T* softmax_out_ptr = reinterpret_cast<T*>(softmax_out.data_ptr());
-#endif

  const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
  const int max_num_partitions =
@@ -972,32 +1039,32 @@ void paged_attention_custom_launcher(
  }
}

-#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \
-  paged_attention_custom_launcher<T, BLK_SIZE, HEAD_SIZE>( \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE>( \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
      num_kv_heads, scale, block_tables, context_lens, max_context_len, \
-      alibi_slopes);
+      alibi_slopes, k_scale, v_scale);

-#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
  switch (block_size) { \
    case 16: \
-      CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \
+      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
      break; \
    case 32: \
-      CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \
+      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }

-#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \
+#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
  switch (head_size) { \
    case 64: \
-      CALL_CUSTOM_LAUNCHER_BLK(T, 64); \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
      break; \
    case 128: \
-      CALL_CUSTOM_LAUNCHER_BLK(T, 128); \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported head size: ", head_size); \
@@ -1020,15 +1087,30 @@ void paged_attention(
    torch::Tensor& context_lens,  // [num_seqs]
    int64_t block_size, int64_t max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype) {
-  assert(kv_cache_dtype == "auto");
+    const std::string& kv_cache_dtype, double k_scale, double v_scale) {
  const int head_size = query.size(2);
-  if (query.dtype() == at::ScalarType::Half) {
-    CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16);
-  } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16);
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Half) {
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
+                                    vllm::Fp8KVCacheDataType::kAuto);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
+                                    vllm::Fp8KVCacheDataType::kAuto);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
+    if (query.dtype() == at::ScalarType::Half) {
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
+                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
+                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
  } else {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
  }
}
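
Numerically, the new fp8 branch only changes how K and V are materialized: each cached 8-byte group is dequantized with a per-tensor scale before the usual fp16/bf16 QK^T and PV math. The snippet below is a rough host-side sketch of that relationship, not the kernel itself; it assumes a PyTorch build that exposes torch.float8_e4m3fnuz (the fnuz e4m3 flavor, whose 240.0 maximum is an assumption of this sketch) and uses one common amax-based choice of scale.

import torch

# Eight cached elements, mirroring one _B8x8 group handled by scaled_convert_b8x8.
k_ref = torch.randn(8, dtype=torch.float16)

# Quantize with a per-tensor scale; 240.0 is the largest finite magnitude of the
# fnuz e4m3 format (assumption of this sketch, not something the diff states).
k_scale = (k_ref.abs().amax().float() / 240.0).item()
k_fp8 = (k_ref.float() / k_scale).to(torch.float8_e4m3fnuz)  # what the cache stores

# What the kernel does when it loads the fp8 cache: dequantize, then run the
# usual fp16/bf16 attention math on the result.
k_dequant = (k_fp8.to(torch.float32) * k_scale).to(torch.float16)

print((k_dequant - k_ref).abs().max())  # small quantization error

The kernel's scaled_convert_b8x8 performs the same multiply-by-scale step in registers, once per _B8x8 group, so the rest of the attention pipeline is unchanged.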

View File

@@ -10,4 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
                     torch::Tensor& context_lens, int64_t block_size,
                     int64_t max_context_len,
                     const c10::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype);
+                     const std::string& kv_cache_dtype, double k_scale,
+                     double v_scale);

View File

@@ -26,7 +26,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
      " Tensor context_lens, int block_size,"
      " int max_context_len,"
      " Tensor? alibi_slopes,"
-      " str kv_cache_dtype) -> ()");
+      " str kv_cache_dtype,"
+      " float k_scale, float v_scale) -> ()");
  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

View File

@@ -31,8 +31,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
-              ] if not is_hip() else [64, 80, 96, 112, 128]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
@@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention(
        output[i].copy_(out, non_blocking=True)


-@pytest.mark.parametrize("version", ["v1", "v2"])
+@pytest.mark.parametrize(
+    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -137,7 +137,8 @@ def test_paged_attention(
    seed: int,
    device: str,
) -> None:
-    if kv_cache_dtype == "fp8" and head_size % 16:
+    if ((kv_cache_dtype == "fp8" and head_size % 16)
+            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    seed_everything(seed)
@@ -206,7 +207,7 @@ def test_paged_attention(
                         kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

-    elif version == "v2":
+    elif version in ("v2", "rocm"):
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
@@ -219,32 +220,61 @@ def test_paged_attention(
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
-        ops.paged_attention_v2(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            seq_lens,
-            block_size,
-            max_seq_len,
-            alibi_slopes,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
-        )
-
-        opcheck(torch.ops._C.paged_attention_v2,
-                (output, exp_sums, max_logits, tmp_output, query, key_cache,
-                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
-                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-                 k_scale, v_scale, 0, 0, 0, 64, 0),
-                cond=(head_size == HEAD_SIZES[0]))
+        if version == "v2":
+            ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._C.paged_attention_v2,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                    cond=(head_size == HEAD_SIZES[0]))
+
+        else:
+            ops.paged_attention_rocm(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._rocm_C.paged_attention,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale),
+                    cond=(head_size == HEAD_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")
@@ -328,162 +358,6 @@ def ref_multi_query_kv_attention(
    return torch.cat(ref_outputs, dim=0)

@pytest.mark.parametrize("version", ["rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not is_hip(), reason="only for rocm")
def test_paged_attention_rocm(
kv_cache_factory,
version: str,
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: str,
) -> None:
seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
#context_lens = [8192 for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int)
#print('>>> ctx lens', context_lens)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# TODO(charlifu) enable fp8 kv cache
# Using default kv_scale
# kv_scale = 1.0
# Call the paged attention kernel.
output = torch.empty_like(query)
PARTITION_SIZE_ROCM = 256
num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
PARTITION_SIZE_ROCM)
assert PARTITION_SIZE_ROCM % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
)
max_logits = torch.empty_like(exp_sums)
if version == "rocm":
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise AssertionError(f"Unknown version: {version}")
# Run the reference implementation.
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
query,
num_queries_per_kv,
key_cache,
value_cache,
block_tables,
context_lens,
scale,
alibi_slopes,
)
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
atol, rtol = 1e-4, 1e-5
if dtype == torch.bfloat16:
atol, rtol = 2e-4, 1e-5
if use_alibi:
if dtype == torch.half:
atol, rtol = 5e-4, 1e-5
if dtype == torch.bfloat16:
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)

# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -491,7 +365,8 @@ def test_paged_attention_rocm(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(is_hip(), reason="skip for rocm")
+@pytest.mark.skipif(is_hip(),
+                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
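
The new "rocm" branch of test_paged_attention drives the custom kernel exactly like the v2 path: it sizes the partition workspace from the maximum sequence length and forwards the two cache scales. Below is a condensed, self-contained sketch of that call sequence with toy shapes; it assumes a ROCm build of vLLM with the _rocm_C extension compiled, and the cache layouts follow the comments in the kernel signature above.

import torch
from vllm import _custom_ops as ops

num_seqs, num_heads, num_kv_heads, head_size, block_size = 2, 8, 8, 128, 16
num_blocks, max_seq_len, partition_size = 32, 512, 512
scale = head_size**-0.5

query = torch.randn(num_seqs, num_heads, head_size,
                    dtype=torch.half, device="cuda")
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // 8, block_size, 8,
                        dtype=torch.half, device="cuda")
value_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          dtype=torch.half, device="cuda")
block_tables = torch.zeros(num_seqs, max_seq_len // block_size,
                           dtype=torch.int, device="cuda")
seq_lens = torch.full((num_seqs, ), max_seq_len, dtype=torch.int, device="cuda")

# Partition-sized workspace, as in the test's shared v2/rocm branch.
num_partitions = (max_seq_len + partition_size - 1) // partition_size
output = torch.empty_like(query)
tmp_output = torch.empty(num_seqs, num_heads, num_partitions, head_size,
                         dtype=output.dtype, device="cuda")
exp_sums = torch.empty(num_seqs, num_heads, num_partitions,
                       dtype=torch.float32, device="cuda")
max_logits = torch.empty_like(exp_sums)

# k_scale / v_scale are the two arguments added by this commit; 1.0 is the
# appropriate value for an unquantized ("auto") cache.
ops.paged_attention_rocm(output, exp_sums, max_logits, tmp_output, query,
                         key_cache, value_cache, num_kv_heads, scale,
                         block_tables, seq_lens, block_size, max_seq_len,
                         None, "auto", 1.0, 1.0)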

View File

@@ -146,12 +146,14 @@ def paged_attention_rocm(
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
) -> None:
    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                      key_cache, value_cache, num_kv_heads,
                                      scale, block_tables, seq_lens,
                                      block_size, max_seq_len, alibi_slopes,
-                                      kv_cache_dtype)
+                                      kv_cache_dtype, k_scale, v_scale)


# pos encoding ops
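
The two new arguments are plain Python floats. With kv_cache_dtype="auto" they are simply 1.0; for an fp8 cache they are typically derived from the observed dynamic range of K and V so that dequantization (stored fp8 value times scale, as in scaled_convert_b8x8) reproduces the original magnitudes. Below is a minimal sketch of one common amax-based convention; the 240.0 limit is a property of the fnuz e4m3 format and an assumption of this sketch rather than something this diff prescribes.

import torch

FP8_E4M3_FNUZ_MAX = 240.0  # largest finite magnitude of the fnuz e4m3 format


def amax_scale(x: torch.Tensor) -> float:
    """Per-tensor scale so that (x / scale) fits into the fp8 range."""
    return max(x.abs().amax().item() / FP8_E4M3_FNUZ_MAX, 1e-12)


key, value = torch.randn(1024, 8, 128), torch.randn(1024, 8, 128)
k_scale, v_scale = amax_scale(key), amax_scale(value)
# These are the two floats now forwarded by paged_attention_rocm(); with an
# unquantized ("auto") cache both stay at 1.0.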

View File

@@ -17,8 +17,8 @@ from vllm.platforms import current_platform

logger = init_logger(__name__)

-_PARTITION_SIZE = 256
-ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
+_PARTITION_SIZE_ROCM = 512
+_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName


class ROCmFlashAttentionBackend(AttentionBackend):
@@ -489,14 +489,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
-            use_custom = use_rocm_custom_paged_attention(
-                decode_query.dtype, head_size, block_size, self.kv_cache_dtype,
-                gqa_ratio, decode_meta.max_decode_seq_len)
+            use_custom = _use_rocm_custom_paged_attention(
+                decode_query.dtype, head_size, block_size, gqa_ratio,
+                decode_meta.max_decode_seq_len)
            if use_custom:
                max_seq_len = decode_meta.max_decode_seq_len
-                max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
-                                      _PARTITION_SIZE)
-                assert _PARTITION_SIZE % block_size == 0
+                max_num_partitions = (
+                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
+                    _PARTITION_SIZE_ROCM)
+                assert _PARTITION_SIZE_ROCM % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
@@ -524,6 +525,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
+                    k_scale,
+                    v_scale,
                )
            else:
                output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -580,12 +583,11 @@ def _sdpa_attention(
    return output


-def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
-                                    block_size: int, kv_cache_dtype: str,
-                                    gqa_ratio: int, max_seq_len: int) -> bool:
+def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
+                                     block_size: int, gqa_ratio: int,
+                                     max_seq_len: int) -> bool:
    # rocm custom page attention not support on navi (gfx1*)
-    return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
+    return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
-            and kv_cache_dtype == "auto"
            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
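
With the kv_cache_dtype == "auto" condition dropped, the custom kernel is now also eligible when the cache is fp8-quantized, and the workspace is sized for 512-token partitions instead of 256. Here is a standalone sketch of the resulting gate and workspace math; the helper mirrors _use_rocm_custom_paged_attention rather than importing it, and the constants are copied from the diff.

import torch

_PARTITION_SIZE_ROCM = 512


def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                    block_size: int, gqa_ratio: int,
                                    max_seq_len: int,
                                    on_navi: bool = False) -> bool:
    # Same conditions as the backend: fp16/bf16 queries, head size 64 or 128,
    # block size 16 or 32, GQA ratio 1..16, and sequences up to 32k tokens.
    return (not on_navi and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128) and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16 and max_seq_len <= 32768)


if use_rocm_custom_paged_attention(torch.half, 128, 16, 8, 4096):
    # Workspace is sized by 512-token partitions (256 before this commit).
    max_num_partitions = ((4096 + _PARTITION_SIZE_ROCM - 1)
                          // _PARTITION_SIZE_ROCM)
    assert _PARTITION_SIZE_ROCM % 16 == 0
    print(max_num_partitions)  # -> 8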