Fp8 paged attention update (#22222)

Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Signed-off-by: xiao-llm <xiao.yu.dc@outlook.com>
Co-authored-by: Xiao Yu <xiao.yu@metamaterial.com>
Co-authored-by: Xiao Yu <xiao.yu@amd.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>

commit 01413e0cf5 (parent 0e219cd50b)
@@ -30,6 +30,10 @@
 #define __HIP__GFX9__
 #endif
 
+#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+#define __HIP__FP8MFMA__
+#endif
+
 #if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
 #define __HIP__GFX11__
 #endif
@@ -51,6 +55,12 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 
+enum class MFMAType {
+  F16 = 0,
+  Fp8 = 1,
+  Fp4 = 2,
+};
+
 #if defined(__HIP__GFX9__)
 
 #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
@@ -112,6 +122,21 @@ __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA,
   }
 }
 
+template <typename T, int absz, int cbid, int blgp>
+__device__ __forceinline__ floatx4 gcn_mfma16x16x32_instr(const long& inpA,
+                                                          const long& inpB,
+                                                          const floatx4& inpC) {
+  if constexpr (std::is_same<T, __hip_fp8_e4m3>::value) {
+    return __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(inpA, inpB, inpC, absz,
+                                                      cbid, blgp);
+  } else if constexpr (std::is_same<T, __hip_fp8_e5m2>::value) {
+    return __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(inpA, inpB, inpC, absz,
+                                                      cbid, blgp);
+  } else {
+    static_assert(false, "unsupported 8b dtype");
+  }
+}
+
 template <typename T>
 __device__ __forceinline__ float to_float(const T& inp) {
   if constexpr (std::is_same<T, _Float16>::value) {
@@ -256,12 +281,44 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
   return ret;
 }
 
+typedef union u64_cvt {
+  half f16x4[4];
+  int16_t b16x4[4];
+  _B8x8 b8x8;
+  _B16x4 b64;
+  int64_t i64;
+} _T8x8;
+
+__device__ __forceinline__ _B8x8 convert_b16x8(const _B16x8& input,
+                                               _T8x8& Mtemp) {
+  _T8x8 Qtmp8x8;
+
+  for (int i = 0; i < 2; i++) {
+    floatx4 q_out = {0, 0, 0, 0};
+    q_out = gcn_mfma16x16x16_instr<_Float16, 0, 0, 0>(Mtemp.b64, input.xy[i],
+                                                      q_out);
+    Qtmp8x8.b16x4[i * 2] =
+        __builtin_amdgcn_cvt_pk_fp8_f32(q_out[0], q_out[1], 0, false);
+    Qtmp8x8.b16x4[i * 2 + 1] =
+        __builtin_amdgcn_cvt_pk_fp8_f32(q_out[2], q_out[3], 0, false);
+  }
+  return Qtmp8x8.b8x8;
+}
+
+__device__ float warpReduceMax(float val) {
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    val = max(
+        val, __shfl_down(val, offset, WARP_SIZE));  // Using max() for reduction
+  }
+  return val;
+}
+
 // grid (num_seqs, num_partitions,num_kv_heads)
 // block (256)
 // clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
-          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
+          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
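The new `warpReduceMax` above is a shuffle-down tree reduction: each step halves the offset and takes the max with the value held by the lane `offset` positions higher, so lane 0 ends up with the wavefront maximum. The Python model below is an illustration only, not part of the change; the 64-lane wavefront size is an assumption matching the GFX9 targets.

```python
# Python model of the shuffle-down max reduction used by warpReduceMax
# (illustration only; WARP_SIZE = 64 is an assumption matching the AMD
# wavefront size on the GFX9 targets).
import random

WARP_SIZE = 64


def warp_reduce_max(lane_values: list) -> float:
    vals = list(lane_values)
    offset = WARP_SIZE // 2
    while offset > 0:
        for lane in range(WARP_SIZE):
            src = lane + offset  # __shfl_down reads from a higher lane
            if src < WARP_SIZE:
                vals[lane] = max(vals[lane], vals[src])
        offset //= 2
    return vals[0]  # lane 0 ends up holding the wavefront maximum


lanes = [random.uniform(-3.0, 3.0) for _ in range(WARP_SIZE)]
assert warp_reduce_max(lanes) == max(lanes)
```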
@@ -367,6 +424,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
 
   int kphysical_block_number[TLOOP];
+#if defined(__HIP__FP8MFMA__)
+  float q_max = 0;
+  float q_scale = 1.0;
+#endif
 
   // fetch k physical block numbers
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
@@ -416,6 +477,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
         Qlocal[qkhe_depth][qkratio].xy[i] =
             shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO]
                          [2 * qkratio + i];
+#if defined(__HIP__FP8MFMA__)
+        if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto &&
+                      MFMA_TYPE == MFMAType::Fp8) {
+          scalar_t* qptr =
+              reinterpret_cast<scalar_t*>(&Qlocal[qkhe_depth][qkratio].xy[i]);
+          for (int k = 0; k < 4; k++)
+            q_max = fmax(fabs(to_float<scalar_t>(qptr[k])), q_max);
+        }
+#endif
       }
     }
   }
@@ -515,6 +585,14 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
     // multiply by k_scale if fp8 kv cache
     scale2 *= *k_scale;
+#if defined(__HIP__FP8MFMA__)
+    q_max = warpReduceMax(q_max);
+    constexpr float FP8_E4M3_SCALE_TARGET = 224.0f;
+    if constexpr (MFMA_TYPE == MFMAType::Fp8) {
+      q_scale = q_max > 0 ? FP8_E4M3_SCALE_TARGET / q_max : 1.0f;
+      scale2 /= q_scale;
+    }
+#endif
   }
 
   floatx4 d_out[TLOOP];
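The block above rescales the query dynamically for the FP8 MFMA path: `q_max` is reduced across the warp, `q_scale` maps that maximum onto the FP8 E4M3 saturation target of 224, and `scale2` (the softmax scale already multiplied by `k_scale` on the fp8 KV-cache path) is divided by `q_scale` so the final logits are unchanged. A minimal sketch of that arithmetic, with made-up input values:

```python
# Sketch of the dynamic query-scale bookkeeping on the FP8 MFMA path
# (plain Python floats stand in for the per-warp kernel values).
FP8_E4M3_SCALE_TARGET = 224.0  # saturation target used by the kernel


def fold_q_scale(scale2: float, q_max: float):
    """Return (q_scale, adjusted scale2).

    q * q_scale is what gets quantized to fp8; dividing scale2 by q_scale
    cancels that factor out of the final logits.
    """
    q_scale = FP8_E4M3_SCALE_TARGET / q_max if q_max > 0 else 1.0
    return q_scale, scale2 / q_scale


# hypothetical values, for illustration only
q_scale, scale2 = fold_q_scale(scale2=0.125, q_max=3.5)
print(q_scale, scale2)  # 64.0 0.001953125
```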
@@ -534,12 +612,41 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
       auto Ktmp = Klocal[token_depth][qkhe_depth];
       _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp);
       for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
-        _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
-        _B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
-        for (int i = 0; i < 2; i++) {
-          d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
-              Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
-              d_out[token_depth]);
+        if constexpr (MFMA_TYPE == MFMAType::F16) {
+          _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
+          _B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
+          for (int i = 0; i < 2; i++) {
+            d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
+                Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
+                d_out[token_depth]);
+          }
+        } else {
+#if defined(__HIP__FP8MFMA__)
+          _T8x8 Ktmp8x8, Qtmp8x8;
+          Ktmp8x8.b8x8 = Ktmp8x16.xy[qkratio];
+
+          for (int n = 0; n < 2; n++) {
+            scalar_t* qptr = reinterpret_cast<scalar_t*>(
+                &Qlocal[qkhe_depth][qkratio].xy[n]);
+
+            Qtmp8x8.b16x4[n * 2] =
+                vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
+                    make_float2(to_float<scalar_t>(qptr[0]),
+                                to_float<scalar_t>(qptr[1])),
+                    q_scale);
+            Qtmp8x8.b16x4[n * 2 + 1] =
+                vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
+                    make_float2(to_float<scalar_t>(qptr[2]),
+                                to_float<scalar_t>(qptr[3])),
+                    q_scale);
+          }
+
+          d_out[token_depth] =
+              gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
+                  Ktmp8x8.i64, Qtmp8x8.i64, d_out[token_depth]);
+#else
+          UNREACHABLE_CODE
+#endif
         }
       }
     }
@@ -629,17 +736,36 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // disable rtz conversion due to its impact on accuracy.
   constexpr bool LOGITS_RTZ_CONVERSION = false;
 
+#if defined(__HIP__FP8MFMA__)
+  int rowid_8x8 = rowid / 2;
+  int offset = rowid % 2;
+#endif
+
   // write logits to shared mem
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
     d_out[token_depth] *= inv_sum_scale;
-    if constexpr (LOGITS_RTZ_CONVERSION) {
-      // use rtz conversion for better performance, with negligible impact on
-      // accuracy
-      shared_logits[warpid][token_depth][lane16id][rowid] =
-          from_floatx4_rtz<scalar_t>(d_out[token_depth]);
+    if constexpr (MFMA_TYPE != MFMAType::Fp8) {
+      if constexpr (LOGITS_RTZ_CONVERSION) {
+        // use rtz conversion for better performance, with negligible impact on
+        // accuracy
+        shared_logits[warpid][token_depth][lane16id][rowid] =
+            from_floatx4_rtz<scalar_t>(d_out[token_depth]);
+      } else {
+        shared_logits[warpid][token_depth][lane16id][rowid] =
+            from_floatx4<scalar_t>(d_out[token_depth]);
+      }
     } else {
-      shared_logits[warpid][token_depth][lane16id][rowid] =
-          from_floatx4<scalar_t>(d_out[token_depth]);
+#if defined(__HIP__FP8MFMA__)
+      // cast _B16x4* to _B8x8*
+      _T8x8& logits_8x8 = *reinterpret_cast<_T8x8*>(
+          &shared_logits[warpid][token_depth][lane16id][rowid_8x8]);
+      logits_8x8.b16x4[offset * 2] = __builtin_amdgcn_cvt_pk_fp8_f32(
+          d_out[token_depth][0], d_out[token_depth][1], 0, false);
+      logits_8x8.b16x4[offset * 2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32(
+          d_out[token_depth][2], d_out[token_depth][3], 0, false);
+#else
+      UNREACHABLE_CODE
+#endif
     }
   }
 
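On the Fp8 branch the softmax logits are written back to shared memory already packed: two adjacent rows share one 8-byte `_T8x8` slot (`rowid_8x8 = rowid / 2`, `offset = rowid % 2`), and each `__builtin_amdgcn_cvt_pk_fp8_f32` call fills one 16-bit element with a pair of fp8 values. The Python lines below only trace that index mapping, with no fp8 math; the four-row range is an assumption made for illustration.

```python
# Index-mapping illustration only (assumption: rowid runs over four rows per
# warp tile, as suggested by rowid_8x8 = rowid / 2 and offset = rowid % 2).
for rowid in range(4):
    rowid_8x8, offset = rowid // 2, rowid % 2
    elements = [offset * 2, offset * 2 + 1]  # 16-bit elements written, two fp8s each
    print(f"rowid={rowid} -> shared slot {rowid_8x8}, b16x4 elements {elements}")
```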
@@ -692,19 +818,42 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
       _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp);
       for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) {
         _B8x8 Vtmp8x8 = Vtmp8x16.xy[j];
-        _B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
-        for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
-          const int offset =
-              rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
-              j * ELEMS8_ELEMS4_RATIO + i;
-          const int offset1 = offset % ROWS_PER_WARP;
-          const int offset2 = offset / ROWS_PER_WARP;
-          // output format is 16 qheads across 16 lanes, 16 head elems
-          // spread across 4 rows
-          tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
-              Vlocaltmp.xy[i],
-              shared_logits[vtoken_depth][offset2][lane16id][offset1],
-              tmp_out);
+        if constexpr (MFMA_TYPE == MFMAType::F16) {
+          _B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
+          for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
+            const int offset =
+                rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
+                j * ELEMS8_ELEMS4_RATIO + i;
+            const int offset1 = offset % ROWS_PER_WARP;
+            const int offset2 = offset / ROWS_PER_WARP;
+            // output format is 16 qheads across 16 lanes, 16 head elems
+            // spread across 4 rows
+            tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
+                Vlocaltmp.xy[i],
+                shared_logits[vtoken_depth][offset2][lane16id][offset1],
+                tmp_out);
+          }
+        } else {
+#if defined(__HIP__FP8MFMA__)
+          for (int i = 0; i < ELEMS8_ELEMS4_RATIO / 2; i++) {
+            const int offset =
+                rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
+                j * ELEMS8_ELEMS4_RATIO + i;
+            const int offset1 = (offset % ROWS_PER_WARP) / 2;
+            const int offset2 = offset / ROWS_PER_WARP;
+            // output format is 16 qheads across 16 lanes, 16 head elems
+            // spread across 4 rows
+            tmp_out = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
+                reinterpret_cast<_T8x8*>(&Vtmp8x8)->i64,
+                reinterpret_cast<_T8x8*>(
+                    &shared_logits[vtoken_depth][offset2][lane16id]
+                                  [offset1])
+                    ->i64,
+                tmp_out);
+          }
+#else
+          UNREACHABLE_CODE
+#endif
         }
       }
     }
@@ -1570,7 +1719,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
 // clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
-          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
+          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
+          MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
@@ -2337,7 +2487,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
 // clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
-          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
+          int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
+          MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
@@ -2969,7 +3120,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
           int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
-          int GQA_RATIO>
+          int GQA_RATIO, MFMAType MFMA_TYPE>
 __global__
 __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
@@ -3041,7 +3192,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
   paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
                                           HEAD_SIZE, NTHR, ALIBI_ENABLED, \
-                                          GQA_RATIO> \
+                                          GQA_RATIO, MFMA_TYPE> \
       <<<grid, block, 0, stream>>>( \
           query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
           block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
@@ -3069,7 +3220,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
-          bool ALIBI_ENABLED>
+          bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
 void paged_attention_custom_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@@ -3225,7 +3376,7 @@ void paged_attention_custom_launcher(
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
-          bool ALIBI_ENABLED>
+          bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
 void paged_attention_custom_launcher_navi(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@@ -3397,74 +3548,77 @@ void paged_attention_custom_launcher_navi(
 }
 
 #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
-                             PSIZE, ALIBI_ENABLED) \
+                             PSIZE, ALIBI_ENABLED, MFMA_TYPE) \
   if (!is_navi) { \
     paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
-                                    OUTT, PSIZE, ALIBI_ENABLED>( \
+                                    OUTT, PSIZE, ALIBI_ENABLED, MFMA_TYPE>( \
         out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
         num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
         max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
   } else { \
-    paged_attention_custom_launcher_navi< \
-        T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
+    paged_attention_custom_launcher_navi<T, KVT, KV_DTYPE, BLK_SIZE, \
+                                         HEAD_SIZE, OUTT, PSIZE, \
+                                         ALIBI_ENABLED, MFMA_TYPE>( \
        out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
        num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
        max_seq_len, alibi_slopes, k_scale, v_scale); \
   }
 
 #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
-                                   OUTT, PSIZE) \
+                                   OUTT, PSIZE, MFMA_TYPE) \
   if (alibi_slopes) { \
     CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
-                         true); \
+                         true, MFMA_TYPE); \
   } else { \
     CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
-                         false); \
+                         false, MFMA_TYPE); \
   }
 
 #if defined(__HIPCC__) && defined(__gfx90a__)
-#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                 MFMA_TYPE) \
   if (fp8_out_scale) { \
     TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
   } else { \
    CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
-                               256); \
+                               256, MFMA_TYPE); \
   }
 #else
-#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                 MFMA_TYPE) \
   if (fp8_out_scale) { \
     CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
-                               uint8_t, 256); \
+                               uint8_t, 256, MFMA_TYPE); \
   } else { \
     CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
-                               256); \
+                               256, MFMA_TYPE); \
   }
 #endif
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE, MFMA_TYPE) \
   switch (block_size) { \
     case 16: \
-      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE, MFMA_TYPE); \
       break; \
     case 32: \
-      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE, MFMA_TYPE); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
   }
 
-#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
+#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE, MFMA_TYPE) \
   switch (head_size) { \
     case 64: \
-      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64, MFMA_TYPE); \
      break; \
    case 128: \
-      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
+      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128, MFMA_TYPE); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported head size: ", head_size); \
      break; \
   }
 
 bool is_navi_gpu() {
@@ -3503,28 +3657,43 @@ void paged_attention(
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale,
-    const std::optional<torch::Tensor>& fp8_out_scale) {
+    const std::optional<torch::Tensor>& fp8_out_scale,
+    const std::string& mfma_type) {
   // clang-format on
   bool is_navi = is_navi_gpu();
 
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Half) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
-                                    vllm::Fp8KVCacheDataType::kAuto);
+      CALL_CUSTOM_LAUNCHER_BLK_HEAD(
+          _Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto, MFMAType::F16);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
       CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
-                                    vllm::Fp8KVCacheDataType::kAuto);
+                                    vllm::Fp8KVCacheDataType::kAuto,
+                                    MFMAType::F16);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
   } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
     if (query.dtype() == at::ScalarType::Half) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
-                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+      if (mfma_type == "fp8") {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::Fp8);
+      } else {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::F16);
+      }
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
-                                    vllm::Fp8KVCacheDataType::kFp8E4M3);
+      if (mfma_type == "fp8") {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::Fp8);
+      } else {
+        CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
+                                      vllm::Fp8KVCacheDataType::kFp8E4M3,
+                                      MFMAType::F16);
+      }
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
@@ -19,4 +19,5 @@ void paged_attention(
     const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
     int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
+    torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale,
+    const std::string& mfma_type);
@@ -48,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       " Tensor? alibi_slopes,"
       " str kv_cache_dtype,"
       " Tensor k_scale, Tensor v_scale,"
-      " Tensor? fp8_out_scale) -> ()");
+      " Tensor? fp8_out_scale,"
+      " str mfma_type) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }
 
@@ -117,13 +117,14 @@ def paged_attention_rocm(
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
     fp8_out_scale: Optional[torch.Tensor] = None,
+    mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16",
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       query_start_loc, block_size, max_seq_len,
                                       alibi_slopes, kv_cache_dtype, k_scale,
-                                      v_scale, fp8_out_scale)
+                                      v_scale, fp8_out_scale, mfma_type)
 
 
 def mla_decode_kvcache_cpu(
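With the new keyword argument, callers of `paged_attention_rocm` get the MFMA path derived from `VLLM_ROCM_FP8_MFMA_PAGE_ATTN` by default, or can pass `mfma_type` explicitly. Below is a small sketch of just that selection logic; `select_mfma_type` is a made-up helper for illustration, not a vLLM API, while the real default expression is the one shown in the hunk above.

```python
# Sketch of the mfma_type selection: "fp8" routes to the new FP8 MFMA path,
# "f16" keeps the existing fp16/bf16 MFMA path. select_mfma_type is a
# hypothetical helper, not part of vLLM.
import os
from typing import Optional


def select_mfma_type(explicit: Optional[str] = None) -> str:
    if explicit is not None:
        return explicit
    use_fp8 = bool(int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0")))
    return "fp8" if use_fp8 else "f16"


print(select_mfma_type())       # "f16" unless VLLM_ROCM_FP8_MFMA_PAGE_ATTN=1
print(select_mfma_type("fp8"))  # explicit override
```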
@@ -167,6 +167,7 @@ if TYPE_CHECKING:
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
+    VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
     VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
@@ -1219,6 +1220,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ENABLE_RESPONSES_API_STORE":
     lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
 
+    # If set, use the fp8 mfma in rocm paged attention.
+    "VLLM_ROCM_FP8_MFMA_PAGE_ATTN":
+    lambda: bool(int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0"))),
+
     # Whether to use pytorch symmetric memory for allreduce
     "VLLM_ALLREDUCE_USE_SYMM_MEM":
     lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
@@ -1340,6 +1345,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
         "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
         "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
+        "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,