Fp8 paged attention update (#22222)

Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Signed-off-by: xiao-llm <xiao.yu.dc@outlook.com>
Co-authored-by: Xiao Yu <xiao.yu@metamaterial.com>
Co-authored-by: Xiao Yu <xiao.yu@amd.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
This commit is contained in:
xiao-llm 2025-09-15 10:43:26 -04:00 committed by GitHub
parent 0e219cd50b
commit 01413e0cf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 257 additions and 79 deletions

View File

@ -30,6 +30,10 @@
#define __HIP__GFX9__
#endif
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
#define __HIP__FP8MFMA__
#endif
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
#define __HIP__GFX11__
#endif
@ -51,6 +55,12 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
enum class MFMAType {
F16 = 0,
Fp8 = 1,
Fp4 = 2,
};
#if defined(__HIP__GFX9__)
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
@ -112,6 +122,21 @@ __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA,
}
}
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma16x16x32_instr(const long& inpA,
const long& inpB,
const floatx4& inpC) {
if constexpr (std::is_same<T, __hip_fp8_e4m3>::value) {
return __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(inpA, inpB, inpC, absz,
cbid, blgp);
} else if constexpr (std::is_same<T, __hip_fp8_e5m2>::value) {
return __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(inpA, inpB, inpC, absz,
cbid, blgp);
} else {
static_assert(false, "unsupported 8b dtype");
}
}
template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
if constexpr (std::is_same<T, _Float16>::value) {
@ -256,12 +281,44 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
return ret;
}
typedef union u64_cvt {
half f16x4[4];
int16_t b16x4[4];
_B8x8 b8x8;
_B16x4 b64;
int64_t i64;
} _T8x8;
__device__ __forceinline__ _B8x8 convert_b16x8(const _B16x8& input,
_T8x8& Mtemp) {
_T8x8 Qtmp8x8;
for (int i = 0; i < 2; i++) {
floatx4 q_out = {0, 0, 0, 0};
q_out = gcn_mfma16x16x16_instr<_Float16, 0, 0, 0>(Mtemp.b64, input.xy[i],
q_out);
Qtmp8x8.b16x4[i * 2] =
__builtin_amdgcn_cvt_pk_fp8_f32(q_out[0], q_out[1], 0, false);
Qtmp8x8.b16x4[i * 2 + 1] =
__builtin_amdgcn_cvt_pk_fp8_f32(q_out[2], q_out[3], 0, false);
}
return Qtmp8x8.b8x8;
}
__device__ float warpReduceMax(float val) {
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
val = max(
val, __shfl_down(val, offset, WARP_SIZE)); // Using max() for reduction
}
return val;
}
// grid (num_seqs, num_partitions,num_kv_heads)
// block (256)
// clang-format off
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, MFMAType MFMA_TYPE>
__global__
__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@ -367,6 +424,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
int kphysical_block_number[TLOOP];
#if defined(__HIP__FP8MFMA__)
float q_max = 0;
float q_scale = 1.0;
#endif
// fetch k physical block numbers
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
@ -416,6 +477,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
Qlocal[qkhe_depth][qkratio].xy[i] =
shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO]
[2 * qkratio + i];
#if defined(__HIP__FP8MFMA__)
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto &&
MFMA_TYPE == MFMAType::Fp8) {
scalar_t* qptr =
reinterpret_cast<scalar_t*>(&Qlocal[qkhe_depth][qkratio].xy[i]);
for (int k = 0; k < 4; k++)
q_max = fmax(fabs(to_float<scalar_t>(qptr[k])), q_max);
}
#endif
}
}
}
@ -515,6 +585,14 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
// multiply by k_scale if fp8 kv cache
scale2 *= *k_scale;
#if defined(__HIP__FP8MFMA__)
q_max = warpReduceMax(q_max);
constexpr float FP8_E4M3_SCALE_TARGET = 224.0f;
if constexpr (MFMA_TYPE == MFMAType::Fp8) {
q_scale = q_max > 0 ? FP8_E4M3_SCALE_TARGET / q_max : 1.0f;
scale2 /= q_scale;
}
#endif
}
floatx4 d_out[TLOOP];
@ -534,12 +612,41 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
auto Ktmp = Klocal[token_depth][qkhe_depth];
_B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp);
for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
_B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
_B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
for (int i = 0; i < 2; i++) {
d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
d_out[token_depth]);
if constexpr (MFMA_TYPE == MFMAType::F16) {
_B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
_B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
for (int i = 0; i < 2; i++) {
d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
d_out[token_depth]);
}
} else {
#if defined(__HIP__FP8MFMA__)
_T8x8 Ktmp8x8, Qtmp8x8;
Ktmp8x8.b8x8 = Ktmp8x16.xy[qkratio];
for (int n = 0; n < 2; n++) {
scalar_t* qptr = reinterpret_cast<scalar_t*>(
&Qlocal[qkhe_depth][qkratio].xy[n]);
Qtmp8x8.b16x4[n * 2] =
vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
make_float2(to_float<scalar_t>(qptr[0]),
to_float<scalar_t>(qptr[1])),
q_scale);
Qtmp8x8.b16x4[n * 2 + 1] =
vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
make_float2(to_float<scalar_t>(qptr[2]),
to_float<scalar_t>(qptr[3])),
q_scale);
}
d_out[token_depth] =
gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
Ktmp8x8.i64, Qtmp8x8.i64, d_out[token_depth]);
#else
UNREACHABLE_CODE
#endif
}
}
}
@ -629,17 +736,36 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// disable rtz conversion due to its impact on accuracy.
constexpr bool LOGITS_RTZ_CONVERSION = false;
#if defined(__HIP__FP8MFMA__)
int rowid_8x8 = rowid / 2;
int offset = rowid % 2;
#endif
// write logits to shared mem
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
d_out[token_depth] *= inv_sum_scale;
if constexpr (LOGITS_RTZ_CONVERSION) {
// use rtz conversion for better performance, with negligible impact on
// accuracy
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4_rtz<scalar_t>(d_out[token_depth]);
if constexpr (MFMA_TYPE != MFMAType::Fp8) {
if constexpr (LOGITS_RTZ_CONVERSION) {
// use rtz conversion for better performance, with negligible impact on
// accuracy
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4_rtz<scalar_t>(d_out[token_depth]);
} else {
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4<scalar_t>(d_out[token_depth]);
}
} else {
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4<scalar_t>(d_out[token_depth]);
#if defined(__HIP__FP8MFMA__)
// cast _B16x4* to _B8x8*
_T8x8& logits_8x8 = *reinterpret_cast<_T8x8*>(
&shared_logits[warpid][token_depth][lane16id][rowid_8x8]);
logits_8x8.b16x4[offset * 2] = __builtin_amdgcn_cvt_pk_fp8_f32(
d_out[token_depth][0], d_out[token_depth][1], 0, false);
logits_8x8.b16x4[offset * 2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32(
d_out[token_depth][2], d_out[token_depth][3], 0, false);
#else
UNREACHABLE_CODE
#endif
}
}
@ -692,19 +818,42 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
_B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp);
for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) {
_B8x8 Vtmp8x8 = Vtmp8x16.xy[j];
_B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
const int offset =
rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
j * ELEMS8_ELEMS4_RATIO + i;
const int offset1 = offset % ROWS_PER_WARP;
const int offset2 = offset / ROWS_PER_WARP;
// output format is 16 qheads across 16 lanes, 16 head elems
// spread across 4 rows
tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
Vlocaltmp.xy[i],
shared_logits[vtoken_depth][offset2][lane16id][offset1],
tmp_out);
if constexpr (MFMA_TYPE == MFMAType::F16) {
_B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
const int offset =
rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
j * ELEMS8_ELEMS4_RATIO + i;
const int offset1 = offset % ROWS_PER_WARP;
const int offset2 = offset / ROWS_PER_WARP;
// output format is 16 qheads across 16 lanes, 16 head elems
// spread across 4 rows
tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
Vlocaltmp.xy[i],
shared_logits[vtoken_depth][offset2][lane16id][offset1],
tmp_out);
}
} else {
#if defined(__HIP__FP8MFMA__)
for (int i = 0; i < ELEMS8_ELEMS4_RATIO / 2; i++) {
const int offset =
rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
j * ELEMS8_ELEMS4_RATIO + i;
const int offset1 = (offset % ROWS_PER_WARP) / 2;
const int offset2 = offset / ROWS_PER_WARP;
// output format is 16 qheads across 16 lanes, 16 head elems
// spread across 4 rows
tmp_out = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
reinterpret_cast<_T8x8*>(&Vtmp8x8)->i64,
reinterpret_cast<_T8x8*>(
&shared_logits[vtoken_depth][offset2][lane16id]
[offset1])
->i64,
tmp_out);
}
#else
UNREACHABLE_CODE
#endif
}
}
}
@ -1570,7 +1719,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
// clang-format off
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
MFMAType MFMA_TYPE>
__global__
__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@ -2337,7 +2487,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
// clang-format off
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
MFMAType MFMA_TYPE>
__global__
__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@ -2969,7 +3120,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
int GQA_RATIO>
int GQA_RATIO, MFMAType MFMA_TYPE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@ -3041,7 +3192,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
GQA_RATIO, MFMA_TYPE> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
@ -3069,7 +3220,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
bool ALIBI_ENABLED>
bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@ -3225,7 +3376,7 @@ void paged_attention_custom_launcher(
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
bool ALIBI_ENABLED>
bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
void paged_attention_custom_launcher_navi(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@ -3397,74 +3548,77 @@ void paged_attention_custom_launcher_navi(
}
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED) \
PSIZE, ALIBI_ENABLED, MFMA_TYPE) \
if (!is_navi) { \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE, ALIBI_ENABLED>( \
OUTT, PSIZE, ALIBI_ENABLED, MFMA_TYPE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
} else { \
paged_attention_custom_launcher_navi< \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
paged_attention_custom_launcher_navi<T, KVT, KV_DTYPE, BLK_SIZE, \
HEAD_SIZE, OUTT, PSIZE, \
ALIBI_ENABLED, MFMA_TYPE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
max_seq_len, alibi_slopes, k_scale, v_scale); \
}
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE) \
OUTT, PSIZE, MFMA_TYPE) \
if (alibi_slopes) { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
true); \
true, MFMA_TYPE); \
} else { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
false); \
false, MFMA_TYPE); \
}
#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
MFMA_TYPE) \
if (fp8_out_scale) { \
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
256, MFMA_TYPE); \
}
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
MFMA_TYPE) \
if (fp8_out_scale) { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
uint8_t, 256); \
uint8_t, 256, MFMA_TYPE); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
256, MFMA_TYPE); \
}
#endif
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE, MFMA_TYPE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE, MFMA_TYPE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE, MFMA_TYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
switch (head_size) { \
case 64: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
break; \
case 128: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
break; \
default: \
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
break; \
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE, MFMA_TYPE) \
switch (head_size) { \
case 64: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64, MFMA_TYPE); \
break; \
case 128: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128, MFMA_TYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
break; \
}
bool is_navi_gpu() {
@ -3503,28 +3657,43 @@ void paged_attention(
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale,
const std::optional<torch::Tensor>& fp8_out_scale) {
const std::optional<torch::Tensor>& fp8_out_scale,
const std::string& mfma_type) {
// clang-format on
bool is_navi = is_navi_gpu();
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
vllm::Fp8KVCacheDataType::kAuto);
CALL_CUSTOM_LAUNCHER_BLK_HEAD(
_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto, MFMAType::F16);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
vllm::Fp8KVCacheDataType::kAuto);
vllm::Fp8KVCacheDataType::kAuto,
MFMAType::F16);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (query.dtype() == at::ScalarType::Half) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
if (mfma_type == "fp8") {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3,
MFMAType::Fp8);
} else {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3,
MFMAType::F16);
}
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
if (mfma_type == "fp8") {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3,
MFMAType::Fp8);
} else {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3,
MFMAType::F16);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}

View File

@ -19,4 +19,5 @@ void paged_attention(
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale,
const std::string& mfma_type);

View File

@ -48,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"
" Tensor? fp8_out_scale) -> ()");
" Tensor? fp8_out_scale,"
" str mfma_type) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

View File

@ -117,13 +117,14 @@ def paged_attention_rocm(
k_scale: torch.Tensor,
v_scale: torch.Tensor,
fp8_out_scale: Optional[torch.Tensor] = None,
mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16",
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale, fp8_out_scale)
v_scale, fp8_out_scale, mfma_type)
def mla_decode_kvcache_cpu(

View File

@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
@ -1219,6 +1220,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE":
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
# If set, use the fp8 mfma in rocm paged attention.
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0"))),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
@ -1340,6 +1345,7 @@ def compute_hash() -> str:
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
]
for key in environment_variables_to_hash:
# if this goes out of sync with environment_variables,