diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index e3a0e15f5304..dac9df6048f2 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -30,6 +30,10 @@
   #define __HIP__GFX9__
 #endif
 
+#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__FP8MFMA__
+#endif
+
 #if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
   #define __HIP__GFX11__
 #endif
@@ -51,6 +55,12 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 
+enum class MFMAType {
+  F16 = 0,
+  Fp8 = 1,
+  Fp4 = 2,
+};
+
 #if defined(__HIP__GFX9__)
 
   #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
@@ -112,6 +122,21 @@ __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA,
   }
 }
 
+template <typename T, int absz, int cbid, int blgp>
+__device__ __forceinline__ floatx4 gcn_mfma16x16x32_instr(const long& inpA,
+                                                          const long& inpB,
+                                                          const floatx4& inpC) {
+  if constexpr (std::is_same<T, __hip_fp8_e4m3>::value) {
+    return __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(inpA, inpB, inpC, absz,
+                                                      cbid, blgp);
+  } else if constexpr (std::is_same<T, __hip_fp8_e5m2>::value) {
+    return __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(inpA, inpB, inpC, absz,
+                                                      cbid, blgp);
+  } else {
+    static_assert(false, "unsupported 8b dtype");
+  }
+}
+
 template <typename T>
 __device__ __forceinline__ float to_float(const T& inp) {
   if constexpr (std::is_same<T, _Float16>::value) {
@@ -256,12 +281,44 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
   return ret;
 }
 
+typedef union u64_cvt {
+  half f16x4[4];
+  int16_t b16x4[4];
+  _B8x8 b8x8;
+  _B16x4 b64;
+  int64_t i64;
+} _T8x8;
+
+__device__ __forceinline__ _B8x8 convert_b16x8(const _B16x8& input,
+                                               _T8x8& Mtemp) {
+  _T8x8 Qtmp8x8;
+
+  for (int i = 0; i < 2; i++) {
+    floatx4 q_out = {0, 0, 0, 0};
+    q_out = gcn_mfma16x16x16_instr<_Float16, 0, 0, 0>(Mtemp.b64, input.xy[i],
+                                                      q_out);
+    Qtmp8x8.b16x4[i * 2] =
+        __builtin_amdgcn_cvt_pk_fp8_f32(q_out[0], q_out[1], 0, false);
+    Qtmp8x8.b16x4[i * 2 + 1] =
+        __builtin_amdgcn_cvt_pk_fp8_f32(q_out[2], q_out[3], 0, false);
+  }
+  return Qtmp8x8.b8x8;
+}
+
+__device__ float warpReduceMax(float val) {
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    val = max(
+        val, __shfl_down(val, offset, WARP_SIZE));  // Using max() for reduction
+  }
+  return val;
+}
+
 // grid (num_seqs, num_partitions,num_kv_heads)
 // block (256)
 // clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
-          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
+          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
+          MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
@@ -367,6 +424,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
 
   int kphysical_block_number[TLOOP];
+  #if defined(__HIP__FP8MFMA__)
+  float q_max = 0;
+  float q_scale = 1.0;
+  #endif
 
   // fetch k physical block numbers
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
@@ -416,6 +477,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
         Qlocal[qkhe_depth][qkratio].xy[i] =
             shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO]
                          [2 * qkratio + i];
+  #if defined(__HIP__FP8MFMA__)
+        if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto &&
+                      MFMA_TYPE == MFMAType::Fp8) {
+          scalar_t* qptr =
+              reinterpret_cast<scalar_t*>(&Qlocal[qkhe_depth][qkratio].xy[i]);
+          for (int k = 0; k < 4; k++)
+            q_max = fmax(fabs(to_float(qptr[k])), q_max);
+        }
+  #endif
       }
     }
   }
@@ -515,6 +585,14 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
       // multiply by k_scale if fp8 kv cache
       scale2 *= *k_scale;
+  #if defined(__HIP__FP8MFMA__)
+      q_max = warpReduceMax(q_max);
+      constexpr float FP8_E4M3_SCALE_TARGET = 224.0f;
+      if constexpr (MFMA_TYPE == MFMAType::Fp8) {
+        q_scale = q_max > 0 ? FP8_E4M3_SCALE_TARGET / q_max : 1.0f;
+        scale2 /= q_scale;
+      }
+  #endif
     }
 
     floatx4 d_out[TLOOP];
@@ -534,12 +612,41 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
         auto Ktmp = Klocal[token_depth][qkhe_depth];
         _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp);
         for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
-          _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
-          _B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
-          for (int i = 0; i < 2; i++) {
-            d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
-                Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
-                d_out[token_depth]);
+          if constexpr (MFMA_TYPE == MFMAType::F16) {
+            _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
+            _B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
+            for (int i = 0; i < 2; i++) {
+              d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
+                  Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
+                  d_out[token_depth]);
+            }
+          } else {
+  #if defined(__HIP__FP8MFMA__)
+            _T8x8 Ktmp8x8, Qtmp8x8;
+            Ktmp8x8.b8x8 = Ktmp8x16.xy[qkratio];
+
+            for (int n = 0; n < 2; n++) {
+              scalar_t* qptr = reinterpret_cast<scalar_t*>(
+                  &Qlocal[qkhe_depth][qkratio].xy[n]);
+
+              Qtmp8x8.b16x4[n * 2] =
+                  vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
+                      make_float2(to_float(qptr[0]),
+                                  to_float(qptr[1])),
+                      q_scale);
+              Qtmp8x8.b16x4[n * 2 + 1] =
+                  vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
+                      make_float2(to_float(qptr[2]),
+                                  to_float(qptr[3])),
+                      q_scale);
+            }
+
+            d_out[token_depth] =
+                gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
+                    Ktmp8x8.i64, Qtmp8x8.i64, d_out[token_depth]);
+  #else
+            UNREACHABLE_CODE
+  #endif
           }
         }
       }
@@ -629,17 +736,36 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // disable rtz conversion due to its impact on accuracy.
   constexpr bool LOGITS_RTZ_CONVERSION = false;
 
+  #if defined(__HIP__FP8MFMA__)
+  int rowid_8x8 = rowid / 2;
+  int offset = rowid % 2;
+  #endif
+
   // write logits to shared mem
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
     d_out[token_depth] *= inv_sum_scale;
-    if constexpr (LOGITS_RTZ_CONVERSION) {
-      // use rtz conversion for better performance, with negligible impact on
-      // accuracy
-      shared_logits[warpid][token_depth][lane16id][rowid] =
-          from_floatx4_rtz<scalar_t>(d_out[token_depth]);
+    if constexpr (MFMA_TYPE != MFMAType::Fp8) {
+      if constexpr (LOGITS_RTZ_CONVERSION) {
+        // use rtz conversion for better performance, with negligible impact on
+        // accuracy
+        shared_logits[warpid][token_depth][lane16id][rowid] =
+            from_floatx4_rtz<scalar_t>(d_out[token_depth]);
+      } else {
+        shared_logits[warpid][token_depth][lane16id][rowid] =
+            from_floatx4<scalar_t>(d_out[token_depth]);
+      }
     } else {
-      shared_logits[warpid][token_depth][lane16id][rowid] =
-          from_floatx4<scalar_t>(d_out[token_depth]);
+  #if defined(__HIP__FP8MFMA__)
+      // cast _B16x4* to _B8x8*
+      _T8x8& logits_8x8 = *reinterpret_cast<_T8x8*>(
+          &shared_logits[warpid][token_depth][lane16id][rowid_8x8]);
+      logits_8x8.b16x4[offset * 2] = __builtin_amdgcn_cvt_pk_fp8_f32(
+          d_out[token_depth][0], d_out[token_depth][1], 0, false);
+      logits_8x8.b16x4[offset * 2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32(
+          d_out[token_depth][2], d_out[token_depth][3], 0, false);
+  #else
+      UNREACHABLE_CODE
+  #endif
     }
   }
 
@@ -692,19 +818,42 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
             _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp);
             for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) {
               _B8x8 Vtmp8x8 = Vtmp8x16.xy[j];
-              _B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
-              for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
-                const int offset =
-                    rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
-                    j * ELEMS8_ELEMS4_RATIO + i;
-                const int offset1 = offset % ROWS_PER_WARP;
-                const int offset2 = offset / ROWS_PER_WARP;
-                // output format is 16 qheads across 16 lanes, 16 head elems
-                // spread across 4 rows
-                tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
-                    Vlocaltmp.xy[i],
-                    shared_logits[vtoken_depth][offset2][lane16id][offset1],
-                    tmp_out);
+              if constexpr (MFMA_TYPE == MFMAType::F16) {
+                _B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
+                for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
+                  const int offset =
+                      rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
+                      j * ELEMS8_ELEMS4_RATIO + i;
+                  const int offset1 = offset % ROWS_PER_WARP;
+                  const int offset2 = offset / ROWS_PER_WARP;
+                  // output format is 16 qheads across 16 lanes, 16 head elems
+                  // spread across 4 rows
+                  tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
+                      Vlocaltmp.xy[i],
+                      shared_logits[vtoken_depth][offset2][lane16id][offset1],
+                      tmp_out);
+                }
+              } else {
+  #if defined(__HIP__FP8MFMA__)
+                for (int i = 0; i < ELEMS8_ELEMS4_RATIO / 2; i++) {
+                  const int offset =
+                      rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
+                      j * ELEMS8_ELEMS4_RATIO + i;
+                  const int offset1 = (offset % ROWS_PER_WARP) / 2;
+                  const int offset2 = offset / ROWS_PER_WARP;
+                  // output format is 16 qheads across 16 lanes, 16 head elems
+                  // spread across 4 rows
+                  tmp_out = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
+                      reinterpret_cast<_T8x8*>(&Vtmp8x8)->i64,
+                      reinterpret_cast<_T8x8*>(
+                          &shared_logits[vtoken_depth][offset2][lane16id]
+                                        [offset1])
+                          ->i64,
+                      tmp_out);
+                }
+  #else
+                UNREACHABLE_CODE
+  #endif
               }
             }
           }
@@ -1570,7 +1719,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
 // clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
-          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
+          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
+          MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
@@ -2337,7 +2487,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
 // clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
-          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
+          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
+          MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
@@ -2969,7 +3120,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
           int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
-          int GQA_RATIO>
+          int GQA_RATIO, MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
@@ -3041,7 +3192,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO)                             \
   paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
                                           HEAD_SIZE, NTHR, ALIBI_ENABLED,     \
-                                          GQA_RATIO>                          \
+                                          GQA_RATIO, MFMA_TYPE>               \
       <<<grid, block, 0, stream>>>(                                           \
           query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,     \
           block_tables_ptr, seq_lens_ptr, query_start_loc_ptr,                \
@@ -3069,7 +3220,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE,
-          bool ALIBI_ENABLED>
+          bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
 void paged_attention_custom_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@@ -3225,7 +3376,7 @@ void paged_attention_custom_launcher(
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE,
-          bool ALIBI_ENABLED>
+          bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
 void paged_attention_custom_launcher_navi(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@@ -3397,74 +3548,77 @@ void paged_attention_custom_launcher_navi(
 }
 
 #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT,     \
-                             PSIZE, ALIBI_ENABLED)                            \
+                             PSIZE, ALIBI_ENABLED, MFMA_TYPE)                 \
   if (!is_navi) {                                                             \
     paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,    \
-                                    OUTT, PSIZE, ALIBI_ENABLED>(              \
+                                    OUTT, PSIZE, ALIBI_ENABLED, MFMA_TYPE>(   \
         out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
         num_kv_heads, scale, block_tables, seq_lens, query_start_loc,         \
         max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);          \
   } else {                                                                    \
-    paged_attention_custom_launcher_navi<                                     \
-        T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>(   \
+    paged_attention_custom_launcher_navi<T, KVT, KV_DTYPE, BLK_SIZE,          \
+                                         HEAD_SIZE, OUTT, PSIZE,              \
+                                         ALIBI_ENABLED, MFMA_TYPE>(           \
        out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,     \
        num_kv_heads, scale, block_tables, seq_lens, query_start_loc,          \
        max_seq_len, alibi_slopes, k_scale, v_scale);                          \
  }
 
 #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
-                                   OUTT, PSIZE)                               \
+                                   OUTT, PSIZE, MFMA_TYPE)                    \
   if (alibi_slopes) {                                                         \
     CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,  \
-                         true);                                               \
+                         true, MFMA_TYPE);                                    \
   } else {                                                                    \
     CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,  \
-                         false);                                              \
+                         false, MFMA_TYPE);                                   \
  }
 
 #if defined(__HIPCC__) && defined(__gfx90a__)
-  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)     \
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
+                                   MFMA_TYPE)                                 \
     if (fp8_out_scale) {                                                      \
       TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");             \
     } else {                                                                  \
       CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,    \
-                                 256);                                        \
+                                 256, MFMA_TYPE);                             \
     }
 #else
-  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)     \
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
+                                   MFMA_TYPE)                                 \
     if (fp8_out_scale) {                                                      \
       CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,       \
-                                 uint8_t, 256);                               \
+                                 uint8_t, 256, MFMA_TYPE);                    \
     } else {                                                                  \
       CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,    \
-                                 256);                                        \
+                                 256, MFMA_TYPE);                             \
     }
 #endif
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                 \
-  switch (block_size) {                                                       \
-    case 16:                                                                  \
-      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);              \
-      break;                                                                  \
-    case 32:                                                                  \
-      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);              \
-      break;                                                                  \
-    default:                                                                  \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);             \
-      break;                                                                  \
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE, MFMA_TYPE)      \
+  switch (block_size) {                                                       \
+    case 16:                                                                  \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE, MFMA_TYPE);   \
+      break;                                                                  \
+    case 32:                                                                  \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE, MFMA_TYPE);   \
+      break;                                                                  \
+    default:                                                                  \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);             \
+      break;                                                                  \
   }
 
-#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE)                       \
-  switch (head_size) {                                                        \
-    case 64:                                                                  \
-      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64);                         \
-      break;                                                                  \
-    case 128:                                                                 \
-      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128);                        \
-      break;                                                                  \
-    default:                                                                  \
-      TORCH_CHECK(false, "Unsupported head size: ", head_size);               \
-      break;                                                                  \
+#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE, MFMA_TYPE)            \
+  switch (head_size) {                                                        \
+    case 64:                                                                  \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64, MFMA_TYPE);              \
+      break;                                                                  \
+    case 128:                                                                 \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128, MFMA_TYPE);             \
+      break;                                                                  \
+    default:                                                                  \
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);               \
+      break;                                                                  \
   }
 
 bool is_navi_gpu() {
@@ -3503,28 +3657,43 @@ void paged_attention(
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale,
-    const std::optional<torch::Tensor>& fp8_out_scale) {
+    const std::optional<torch::Tensor>& fp8_out_scale,
+    const std::string& mfma_type) {
   // clang-format on
   bool is_navi = is_navi_gpu();
-
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Half) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
-                                    vllm::Fp8KVCacheDataType::kAuto);
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(
+          _Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto, MFMAType::F16);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
       CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
-                                    vllm::Fp8KVCacheDataType::kAuto);
+                                    vllm::Fp8KVCacheDataType::kAuto,
+                                    MFMAType::F16);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    if (query.dtype() == at::ScalarType::Half) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
-                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+      if (mfma_type == "fp8") {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::Fp8);
+      } else {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::F16);
+      }
    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
-                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+      if (mfma_type == "fp8") {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::Fp8);
+      } else {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::F16);
+      }
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h
index 34dcc9401aae..b6ee2656746c 100644
--- a/csrc/rocm/ops.h
+++ b/csrc/rocm/ops.h
@@ -19,4 +19,5 @@ void paged_attention(
     const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
     int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
+    torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale,
+    const std::string& mfma_type);
diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp
index 66bdc448da3c..c0c4daef64f0 100644
--- a/csrc/rocm/torch_bindings.cpp
+++ b/csrc/rocm/torch_bindings.cpp
@@ -48,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       "                Tensor? alibi_slopes,"
       "                str kv_cache_dtype,"
       "                Tensor k_scale, Tensor v_scale,"
-      "                Tensor? fp8_out_scale) -> ()");
+      "                Tensor? fp8_out_scale,"
+      "                str mfma_type) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }
 
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 93b4f87ed260..f3e5fe73d4e2 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -117,13 +117,14 @@ def paged_attention_rocm(
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
     fp8_out_scale: Optional[torch.Tensor] = None,
+    mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16",
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       query_start_loc, block_size, max_seq_len,
                                       alibi_slopes, kv_cache_dtype, k_scale,
-                                      v_scale, fp8_out_scale)
+                                      v_scale, fp8_out_scale, mfma_type)
 
 
 def mla_decode_kvcache_cpu(
diff --git a/vllm/envs.py b/vllm/envs.py
index bb10c7cc2ac2..7ea18064c944 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -167,6 +167,7 @@ if TYPE_CHECKING:
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
+    VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
     VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
@@ -1219,6 +1220,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ENABLE_RESPONSES_API_STORE":
     lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
 
+    # If set, use the fp8 mfma in rocm paged attention.
+    "VLLM_ROCM_FP8_MFMA_PAGE_ATTN":
+    lambda: bool(int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0"))),
+
     # Whether to use pytorch symmetric memory for allreduce
     "VLLM_ALLREDUCE_USE_SYMM_MEM":
     lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
@@ -1340,6 +1345,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
         "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
         "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
+        "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,
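
Reviewer note (not part of the patch): the FP8 MFMA path is opt-in. Setting VLLM_ROCM_FP8_MFMA_PAGE_ATTN=1 makes paged_attention_rocm default mfma_type to "fp8", which selects MFMAType::Fp8 in the kernel launch. The snippet below is a host-side NumPy sketch of the query-scaling algebra the kernel relies on: Q is scaled by q_scale = 224 / max|Q| so it fits the FP8 E4M3 finite range before quantization, and the softmax scale is divided by q_scale (as scale2 /= q_scale does), leaving the logits unchanged up to FP8 rounding error. NumPy, the variable names, and the omission of actual FP8 rounding are illustrative assumptions only.

import numpy as np

rng = np.random.default_rng(0)
head_size = 128
softmax_scale = 1.0 / np.sqrt(head_size)

q = rng.standard_normal(head_size).astype(np.float32)
k = rng.standard_normal(head_size).astype(np.float32)

# Mirrors FP8_E4M3_SCALE_TARGET and the q_scale computation added in the kernel.
FP8_E4M3_SCALE_TARGET = 224.0
q_max = float(np.abs(q).max())
q_scale = FP8_E4M3_SCALE_TARGET / q_max if q_max > 0 else 1.0

# Reference logit with the unscaled query.
ref = softmax_scale * float(np.dot(q, k))

# Scaled path: the query is multiplied by q_scale before conversion to FP8
# (rounding omitted in this sketch), and 1 / q_scale is folded back into the
# softmax scale.
logit = (softmax_scale / q_scale) * float(np.dot(q * q_scale, k))

assert np.isclose(ref, logit)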