[Kernel] Fix conflicting macro names for gguf kernels (#15456)

Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
This commit is contained in:
Szymon Ożóg 2025-03-25 14:50:49 +01:00 committed by GitHub
parent 3f04a7fbf2
commit a608160027
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 90 additions and 90 deletions

View File

@ -375,25 +375,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
int64_t ggml_moe_get_block_size(int64_t type) { int64_t ggml_moe_get_block_size(int64_t type) {
switch (type) { switch (type) {
case 2: case 2:
return MMQ_X_Q4_0; return MOE_X_Q4_0;
case 3: case 3:
return MMQ_X_Q4_1; return MOE_X_Q4_1;
case 6: case 6:
return MMQ_X_Q5_0; return MOE_X_Q5_0;
case 7: case 7:
return MMQ_X_Q5_1; return MOE_X_Q5_1;
case 8: case 8:
return MMQ_X_Q8_0; return MOE_X_Q8_0;
case 10: case 10:
return MMQ_X_Q2_K; return MOE_X_Q2_K;
case 11: case 11:
return MMQ_X_Q3_K; return MOE_X_Q3_K;
case 12: case 12:
return MMQ_X_Q4_K; return MOE_X_Q4_K;
case 13: case 13:
return MMQ_X_Q5_K; return MOE_X_Q5_K;
case 14: case 14:
return MMQ_X_Q6_K; return MOE_X_Q6_K;
} }
return 0; return 0;
} }

View File

@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q4_0 64 #define MOE_X_Q4_0 64
#define MMQ_Y_Q4_0 128 #define MOE_Y_Q4_0 128
#define NWARPS_Q4_0 8 #define NWARPS_Q4_0 8
#else #else
#define MMQ_X_Q4_0 4 #define MOE_X_Q4_0 4
#define MMQ_Y_Q4_0 32 #define MOE_Y_Q4_0 32
#define NWARPS_Q4_0 4 #define NWARPS_Q4_0 4
#endif #endif
@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q4_0; const int mmq_x = MOE_X_Q4_0;
const int mmq_y = MMQ_Y_Q4_0; const int mmq_y = MOE_Y_Q4_0;
const int nwarps = NWARPS_Q4_0; const int nwarps = NWARPS_Q4_0;
moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_0; int mmq_x = MOE_X_Q4_0;
int mmq_y = MMQ_Y_Q4_0; int mmq_y = MOE_Y_Q4_0;
int nwarps = NWARPS_Q4_0; int nwarps = NWARPS_Q4_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q4_1 64 #define MOE_X_Q4_1 64
#define MMQ_Y_Q4_1 128 #define MOE_Y_Q4_1 128
#define NWARPS_Q4_1 8 #define NWARPS_Q4_1 8
#else #else
#define MMQ_X_Q4_1 4 #define MOE_X_Q4_1 4
#define MMQ_Y_Q4_1 32 #define MOE_Y_Q4_1 32
#define NWARPS_Q4_1 4 #define NWARPS_Q4_1 4
#endif #endif
@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q4_1; const int mmq_x = MOE_X_Q4_1;
const int mmq_y = MMQ_Y_Q4_1; const int mmq_y = MOE_Y_Q4_1;
const int nwarps = NWARPS_Q4_1; const int nwarps = NWARPS_Q4_1;
moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_1; int mmq_x = MOE_X_Q4_1;
int mmq_y = MMQ_Y_Q4_1; int mmq_y = MOE_Y_Q4_1;
int nwarps = NWARPS_Q4_1; int nwarps = NWARPS_Q4_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q5_0 64 #define MOE_X_Q5_0 64
#define MMQ_Y_Q5_0 128 #define MOE_Y_Q5_0 128
#define NWARPS_Q5_0 8 #define NWARPS_Q5_0 8
#else #else
#define MMQ_X_Q5_0 4 #define MOE_X_Q5_0 4
#define MMQ_Y_Q5_0 32 #define MOE_Y_Q5_0 32
#define NWARPS_Q5_0 4 #define NWARPS_Q5_0 4
#endif #endif
@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q5_0; const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0; const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0; const int nwarps = NWARPS_Q5_0;
moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_0; const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0; const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0; const int nwarps = NWARPS_Q5_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q5_1 64 #define MOE_X_Q5_1 64
#define MMQ_Y_Q5_1 128 #define MOE_Y_Q5_1 128
#define NWARPS_Q5_1 8 #define NWARPS_Q5_1 8
#else #else
#define MMQ_X_Q5_1 4 #define MOE_X_Q5_1 4
#define MMQ_Y_Q5_1 32 #define MOE_Y_Q5_1 32
#define NWARPS_Q5_1 4 #define NWARPS_Q5_1 4
#endif #endif
@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q5_1; const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1; const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1; const int nwarps = NWARPS_Q5_1;
moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_1; const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1; const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1; const int nwarps = NWARPS_Q5_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q8_0 64 #define MOE_X_Q8_0 64
#define MMQ_Y_Q8_0 128 #define MOE_Y_Q8_0 128
#define NWARPS_Q8_0 8 #define NWARPS_Q8_0 8
#else #else
#define MMQ_X_Q8_0 4 #define MOE_X_Q8_0 4
#define MMQ_Y_Q8_0 32 #define MOE_Y_Q8_0 32
#define NWARPS_Q8_0 4 #define NWARPS_Q8_0 4
#endif #endif
@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q8_0; const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0; const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0; const int nwarps = NWARPS_Q8_0;
moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q8_0; const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0; const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0; const int nwarps = NWARPS_Q8_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q2_K 64 #define MOE_X_Q2_K 64
#define MMQ_Y_Q2_K 128 #define MOE_Y_Q2_K 128
#define NWARPS_Q2_K 8 #define NWARPS_Q2_K 8
#else #else
#define MMQ_X_Q2_K 4 #define MOE_X_Q2_K 4
#define MMQ_Y_Q2_K 32 #define MOE_Y_Q2_K 32
#define NWARPS_Q2_K 4 #define NWARPS_Q2_K 4
#endif #endif
@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q2_K; const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K; const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K; const int nwarps = NWARPS_Q2_K;
moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q2_K; const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K; const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K; const int nwarps = NWARPS_Q2_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q3_K 64 #define MOE_X_Q3_K 64
#define MMQ_Y_Q3_K 128 #define MOE_Y_Q3_K 128
#define NWARPS_Q3_K 8 #define NWARPS_Q3_K 8
#else #else
#define MMQ_X_Q3_K 4 #define MOE_X_Q3_K 4
#define MMQ_Y_Q3_K 32 #define MOE_Y_Q3_K 32
#define NWARPS_Q3_K 4 #define NWARPS_Q3_K 4
#endif #endif
@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q3_K; const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K; const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K; const int nwarps = NWARPS_Q3_K;
moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q3_K; const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K; const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K; const int nwarps = NWARPS_Q3_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q4_K 64 #define MOE_X_Q4_K 64
#define MMQ_Y_Q4_K 128 #define MOE_Y_Q4_K 128
#define NWARPS_Q4_K 8 #define NWARPS_Q4_K 8
#else #else
#define MMQ_X_Q4_K 4 #define MOE_X_Q4_K 4
#define MMQ_Y_Q4_K 32 #define MOE_Y_Q4_K 32
#define NWARPS_Q4_K 4 #define NWARPS_Q4_K 4
#endif #endif
@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q4_K; const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K; const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K; const int nwarps = NWARPS_Q4_K;
moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q4_K; const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K; const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K; const int nwarps = NWARPS_Q4_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q5_K 64 #define MOE_X_Q5_K 64
#define MMQ_Y_Q5_K 128 #define MOE_Y_Q5_K 128
#define NWARPS_Q5_K 8 #define NWARPS_Q5_K 8
#else #else
#define MMQ_X_Q5_K 4 #define MOE_X_Q5_K 4
#define MMQ_Y_Q5_K 32 #define MOE_Y_Q5_K 32
#define NWARPS_Q5_K 4 #define NWARPS_Q5_K 4
#endif #endif
@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q5_K; const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K; const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K; const int nwarps = NWARPS_Q5_K;
moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_K; const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K; const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K; const int nwarps = NWARPS_Q5_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q6_K 64 #define MOE_X_Q6_K 64
#define MMQ_Y_Q6_K 128 #define MOE_Y_Q6_K 128
#define NWARPS_Q6_K 8 #define NWARPS_Q6_K 8
#else #else
#define MMQ_X_Q6_K 4 #define MOE_X_Q6_K 4
#define MMQ_Y_Q6_K 32 #define MOE_Y_Q6_K 32
#define NWARPS_Q6_K 4 #define NWARPS_Q6_K 4
#endif #endif
@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q6_K; const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K; const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K; const int nwarps = NWARPS_Q6_K;
moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q6_K; const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K; const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K; const int nwarps = NWARPS_Q6_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;