mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:54:56 +08:00
[Kernel][Amd] Add fp8 kv cache support for rocm custom paged attention (#8577)
This commit is contained in:
parent
76515f303b
commit
9cc373f390
@ -18,8 +18,11 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include "../attention/dtype_fp8.cuh"
|
||||
#include "../quantization/fp8/amd/quant_utils.cuh"
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
|
||||
defined(__gfx941__) || defined(__gfx942__))
|
||||
@ -38,7 +41,6 @@
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
#define WARP_SIZE 64
|
||||
|
||||
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
|
||||
|
||||
@ -60,6 +62,8 @@ typedef struct _B16x8 {
|
||||
_B16x4 xy[2];
|
||||
} _B16x8;
|
||||
|
||||
using _B8x8 = uint2;
|
||||
|
||||
////// Non temporal load stores ///////
|
||||
|
||||
template <typename T>
|
||||
@ -168,18 +172,40 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
|
||||
__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
|
||||
const float scale) {
|
||||
union alignas(16) {
|
||||
uint4 u4;
|
||||
_B16x8 u16x8;
|
||||
vllm::bf16_8_t b16x8;
|
||||
} tmp;
|
||||
if constexpr (std::is_same<T, _Float16>::value) {
|
||||
tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
|
||||
return tmp.u16x8;
|
||||
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
|
||||
tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(
|
||||
input, scale);
|
||||
return tmp.u16x8;
|
||||
} else {
|
||||
static_assert(false, "unsupported 16b dtype");
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////
|
||||
|
||||
// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
|
||||
// block (partition size)
|
||||
template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
|
||||
template <typename scalar_t, typename cache_t,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
|
||||
int NUM_THREADS,
|
||||
int GQA_RATIO>
|
||||
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -192,10 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
#if 0
|
||||
scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size]
|
||||
#endif
|
||||
int max_ctx_blocks) {
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
@ -222,12 +245,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
constexpr int KHELOOP = HEAD_SIZE / x;
|
||||
_B16x8 Klocal[KHELOOP];
|
||||
_B8x8 Klocalb8[KHELOOP];
|
||||
constexpr int VHELOOP =
|
||||
HEAD_SIZE /
|
||||
WARP_SIZE; // v head_size dimension is distributed across lanes
|
||||
constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2
|
||||
// 8xtokens
|
||||
_B16x8 Vlocal[VHELOOP][VTLOOP];
|
||||
_B8x8 Vlocalb8[VHELOOP][VTLOOP];
|
||||
floatx4 dout[QHLOOP];
|
||||
float qk_max[QHLOOP];
|
||||
#pragma unroll
|
||||
@ -279,6 +304,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
|
||||
vphysical_blocks[b] = block_table[vblock_idx_ctx];
|
||||
}
|
||||
|
||||
// each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
|
||||
const scalar_t* q_ptr =
|
||||
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
|
||||
@ -298,17 +324,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
Qlocal[QHLOOP - 1].xy[1] = {0};
|
||||
}
|
||||
|
||||
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
|
||||
wg_start_kv_head_idx * kv_head_stride;
|
||||
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
|
||||
wg_start_kv_head_idx * kv_head_stride;
|
||||
|
||||
const int physical_block_offset =
|
||||
local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset
|
||||
// is already cast as _H8
|
||||
|
||||
const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
|
||||
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
|
||||
const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
|
||||
#pragma unroll
|
||||
for (int d = 0; d < KHELOOP; d++) {
|
||||
Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
|
||||
for (int d = 0; d < KHELOOP; d++) {
|
||||
Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
|
||||
}
|
||||
} else {
|
||||
constexpr int X = 16 / sizeof(cache_t);
|
||||
const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
|
||||
#pragma unroll
|
||||
for (int d = 0; d < KHELOOP; d++) {
|
||||
const int head_elem = d * 8;
|
||||
const int offset1 = head_elem / X;
|
||||
const int offset2 = head_elem % X;
|
||||
const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
|
||||
Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
|
||||
}
|
||||
}
|
||||
|
||||
float alibi_slope[QHLOOP];
|
||||
@ -322,28 +360,64 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
|
||||
const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
|
||||
// iterate over each v block
|
||||
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
|
||||
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
|
||||
const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
|
||||
// iterate over each v block
|
||||
#pragma unroll
|
||||
for (int b = 0; b < VBLOCKS; b++) {
|
||||
// int32 physical_block_number leads to overflow when multiplied with
|
||||
// kv_block_stride
|
||||
const int64_t vphysical_block_number =
|
||||
static_cast<int64_t>(vphysical_blocks[b]);
|
||||
const _B16x8* v_ptrh8b =
|
||||
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
|
||||
// iterate over each head elem (within head_size)
|
||||
for (int b = 0; b < VBLOCKS; b++) {
|
||||
// int32 physical_block_number leads to overflow when multiplied with
|
||||
// kv_block_stride
|
||||
const int64_t vphysical_block_number =
|
||||
static_cast<int64_t>(vphysical_blocks[b]);
|
||||
const _B16x8* v_ptrh8b =
|
||||
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
|
||||
// iterate over each head elem (within head_size)
|
||||
#pragma unroll
|
||||
for (int h = 0; h < VHELOOP; h++) {
|
||||
const int head_size_elem = h * WARP_SIZE + laneid;
|
||||
const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
|
||||
// iterate over all velems within block
|
||||
for (int h = 0; h < VHELOOP; h++) {
|
||||
const int head_size_elem = h * WARP_SIZE + laneid;
|
||||
const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
|
||||
// iterate over all velems within block
|
||||
#pragma unroll
|
||||
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
|
||||
Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
|
||||
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
|
||||
Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
|
||||
// iterate over each v block
|
||||
#pragma unroll
|
||||
for (int b = 0; b < VBLOCKS; b++) {
|
||||
// int32 physical_block_number leads to overflow when multiplied with
|
||||
// kv_block_stride
|
||||
const int64_t vphysical_block_number =
|
||||
static_cast<int64_t>(vphysical_blocks[b]);
|
||||
const _B8x8* v_ptrh8b =
|
||||
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
|
||||
// iterate over each head elem (within head_size)
|
||||
#pragma unroll
|
||||
for (int h = 0; h < VHELOOP; h++) {
|
||||
const int head_size_elem = h * WARP_SIZE + laneid;
|
||||
const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
|
||||
// iterate over all velems within block
|
||||
#pragma unroll
|
||||
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
|
||||
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
|
||||
const _B8x8 Vlocalb8 = v_ptrh8be[d];
|
||||
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
|
||||
#pragma unroll
|
||||
for (int d = 0; d < KHELOOP; d++) {
|
||||
Klocal[d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@ -794,14 +868,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
|
||||
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
|
||||
template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
|
||||
template <typename scalar_t, typename cache_t,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
|
||||
int NUM_THREADS,
|
||||
int GQA_RATIO>
|
||||
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -814,10 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
#if 0
|
||||
scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size]
|
||||
#endif
|
||||
int max_ctx_blocks) {
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
|
||||
@ -839,26 +912,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
|
||||
#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
|
||||
paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
|
||||
paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, \
|
||||
NTHR, GQA_RATIO> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks);
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
|
||||
k_scale, v_scale);
|
||||
|
||||
template <typename T, int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 256>
|
||||
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
|
||||
void paged_attention_custom_launcher(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, const int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
int max_context_len,
|
||||
#if 0
|
||||
torch::Tensor& qk_out,
|
||||
torch::Tensor& softmax_out,
|
||||
#endif
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
|
||||
int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float k_scale, float v_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -878,14 +949,10 @@ void paged_attention_custom_launcher(
|
||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
|
||||
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
#if 0
|
||||
T* qk_out_ptr = reinterpret_cast<T*>(qk_out.data_ptr());
|
||||
T* softmax_out_ptr = reinterpret_cast<T*>(softmax_out.data_ptr());
|
||||
#endif
|
||||
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
|
||||
const int max_num_partitions =
|
||||
@ -972,32 +1039,32 @@ void paged_attention_custom_launcher(
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \
|
||||
paged_attention_custom_launcher<T, BLK_SIZE, HEAD_SIZE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
|
||||
alibi_slopes);
|
||||
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
|
||||
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
|
||||
alibi_slopes, k_scale, v_scale);
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \
|
||||
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \
|
||||
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
|
||||
switch (head_size) { \
|
||||
case 64: \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, 64); \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
|
||||
break; \
|
||||
case 128: \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, 128); \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
|
||||
@ -1020,19 +1087,34 @@ void paged_attention(
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
assert(kv_cache_dtype == "auto");
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
|
||||
const int head_size = query.size(2);
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16);
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
|
||||
vllm::Fp8KVCacheDataType::kAuto);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
|
||||
vllm::Fp8KVCacheDataType::kAuto);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
#undef DIVIDE_ROUND_UP
|
||||
@ -10,4 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& context_lens, int64_t block_size,
|
||||
int64_t max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale);
|
||||
|
||||
@ -26,7 +26,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
" Tensor context_lens, int block_size,"
|
||||
" int max_context_len,"
|
||||
" Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype) -> ()");
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
|
||||
}
|
||||
|
||||
|
||||
@ -31,8 +31,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
|
||||
# FlashAttention forward only supports head dimension at most 128
|
||||
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
|
||||
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
|
||||
] if not is_hip() else [64, 80, 96, 112, 128]
|
||||
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention(
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version", ["v1", "v2"])
|
||||
@pytest.mark.parametrize(
|
||||
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@ -137,7 +137,8 @@ def test_paged_attention(
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||
if ((kv_cache_dtype == "fp8" and head_size % 16)
|
||||
or (version == "rocm" and head_size not in (64, 128))):
|
||||
pytest.skip()
|
||||
|
||||
seed_everything(seed)
|
||||
@ -206,7 +207,7 @@ def test_paged_attention(
|
||||
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
|
||||
elif version == "v2":
|
||||
elif version in ("v2", "rocm"):
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
@ -219,32 +220,61 @@ def test_paged_attention(
|
||||
dtype=torch.float32,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
if version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.paged_attention_v2,
|
||||
(output, exp_sums, max_logits, tmp_output, query, key_cache,
|
||||
value_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||
block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
|
||||
k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
opcheck(torch.ops._C.paged_attention_v2,
|
||||
(output, exp_sums, max_logits, tmp_output, query,
|
||||
key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||
seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
|
||||
else:
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._rocm_C.paged_attention,
|
||||
(output, exp_sums, max_logits, tmp_output, query,
|
||||
key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||
seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
@ -328,162 +358,6 @@ def ref_multi_query_kv_attention(
|
||||
return torch.cat(ref_outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version", ["rocm"])
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(not is_hip(), reason="only for rocm")
|
||||
def test_paged_attention_rocm(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
context_lens[-1] = MAX_SEQ_LEN
|
||||
#context_lens = [8192 for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int)
|
||||
#print('>>> ctx lens', context_lens)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# TODO(charlifu) enable fp8 kv cache
|
||||
# Using default kv_scale
|
||||
# kv_scale = 1.0
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
|
||||
PARTITION_SIZE_ROCM)
|
||||
assert PARTITION_SIZE_ROCM % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
if version == "rocm":
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
# Run the reference implementation.
|
||||
if kv_cache_dtype == "fp8":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(key_cache, dequantized_key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(value_cache, dequantized_value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
num_queries_per_kv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
scale,
|
||||
alibi_slopes,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
atol = get_default_atol(output) if is_hip() else 1e-3
|
||||
rtol = get_default_rtol(output) if is_hip() else 1e-5
|
||||
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
atol, rtol = 1e-4, 1e-5
|
||||
if dtype == torch.bfloat16:
|
||||
atol, rtol = 2e-4, 1e-5
|
||||
if use_alibi:
|
||||
if dtype == torch.half:
|
||||
atol, rtol = 5e-4, 1e-5
|
||||
if dtype == torch.bfloat16:
|
||||
atol, rtol = 1e-3, 1e-5
|
||||
if kv_cache_dtype == "fp8":
|
||||
atol, rtol = 1e-2, 1e-5
|
||||
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@ -491,7 +365,8 @@ def test_paged_attention_rocm(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(is_hip(), reason="skip for rocm")
|
||||
@pytest.mark.skipif(is_hip(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
|
||||
@ -146,12 +146,14 @@ def paged_attention_rocm(
|
||||
max_seq_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> None:
|
||||
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
|
||||
key_cache, value_cache, num_kv_heads,
|
||||
scale, block_tables, seq_lens,
|
||||
block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype)
|
||||
kv_cache_dtype, k_scale, v_scale)
|
||||
|
||||
|
||||
# pos encoding ops
|
||||
|
||||
@ -17,8 +17,8 @@ from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PARTITION_SIZE = 256
|
||||
ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
_PARTITION_SIZE_ROCM = 512
|
||||
_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
|
||||
|
||||
class ROCmFlashAttentionBackend(AttentionBackend):
|
||||
@ -489,14 +489,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
num_seqs, num_heads, head_size = decode_query.shape
|
||||
block_size = value_cache.shape[3]
|
||||
gqa_ratio = num_heads // self.num_kv_heads
|
||||
use_custom = use_rocm_custom_paged_attention(
|
||||
decode_query.dtype, head_size, block_size, self.kv_cache_dtype,
|
||||
gqa_ratio, decode_meta.max_decode_seq_len)
|
||||
use_custom = _use_rocm_custom_paged_attention(
|
||||
decode_query.dtype, head_size, block_size, gqa_ratio,
|
||||
decode_meta.max_decode_seq_len)
|
||||
if use_custom:
|
||||
max_seq_len = decode_meta.max_decode_seq_len
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
max_num_partitions = (
|
||||
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
_PARTITION_SIZE_ROCM)
|
||||
assert _PARTITION_SIZE_ROCM % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
@ -524,6 +525,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
@ -580,12 +583,11 @@ def _sdpa_attention(
|
||||
return output
|
||||
|
||||
|
||||
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
|
||||
block_size: int, kv_cache_dtype: str,
|
||||
gqa_ratio: int, max_seq_len: int) -> bool:
|
||||
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
|
||||
block_size: int, gqa_ratio: int,
|
||||
max_seq_len: int) -> bool:
|
||||
# rocm custom page attention not support on navi (gfx1*)
|
||||
return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and kv_cache_dtype == "auto"
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user