diff --git a/csrc/ops.h b/csrc/ops.h
index 59ae0937604c0..4cac278c92cbd 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -178,6 +178,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
                           torch::Tensor num_tokens_post_padded, int64_t type,
                           int64_t row, int64_t top_k, int64_t tokens);
 
+torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
+                              torch::Tensor topk_ids, int64_t top_k,
+                              int64_t type, int64_t row, int64_t tokens);
+
 int64_t ggml_moe_get_block_size(int64_t type);
 
 #ifndef USE_ROCM
diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu
index 56b78f1834d15..6c146c3fb6fde 100644
--- a/csrc/quantization/gguf/gguf_kernel.cu
+++ b/csrc/quantization/gguf/gguf_kernel.cu
@@ -13,6 +13,7 @@
 #include "mmvq.cuh"
 #include "mmq.cuh"
 #include "moe.cuh"
+#include "moe_vec.cuh"
 
 // Q8 gemv
 template <typename scalar_t>
@@ -377,6 +378,142 @@ torch::Tensor ggml_moe_a8(torch::Tensor X,  // input
   return Y;
 }
 
+torch::Tensor ggml_moe_a8_vec(torch::Tensor X,  // input
+                              torch::Tensor W,  // expert weights
+                              torch::Tensor topk_ids, int64_t top_k,
+                              int64_t type, int64_t row, int64_t tokens) {
+  int col = X.sizes()[1];
+  const int padded = (col + 512 - 1) / 512 * 512;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
+  at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
+  at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
+    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(),
+                           (void*)quant_X.data_ptr(), col, tokens,
+                           stream);
+    switch (type) {
+      case 2:
+        moe_vec_q4_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 3:
+        moe_vec_q4_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 6:
+        moe_vec_q5_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 7:
+        moe_vec_q5_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 8:
+        moe_vec_q8_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 10:
+        moe_vec_q2_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 11:
+        moe_vec_q3_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 12:
+        moe_vec_q4_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 13:
+        moe_vec_q5_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 14:
+        moe_vec_q6_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 16:
+        moe_vec_iq2_xxs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 17:
+        moe_vec_iq2_xs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 18:
+        moe_vec_iq3_xxs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 19:
+        moe_vec_iq1_s_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 20:
+        moe_vec_iq4_nl_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 21:
+        moe_vec_iq3_s_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 22:
+        moe_vec_iq2_s_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 23:
+        moe_vec_iq4_xs_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+      case 29:
+        moe_vec_iq1_m_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
+            col, row, quant_X.stride(0), stream);
+        break;
+    }
+  });
+  return Y;
+}
+
 int64_t ggml_moe_get_block_size(int64_t type) {
   switch (type) {
     case 2:
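
For intuition on the `quant_X` allocation above: each Q8_1 block covers 32 activation elements and stores 32 int8 quants plus a half2 (scale, sum), i.e. 36 bytes = 9 int32 words, and the hidden size is first padded up to a multiple of 512. A minimal sketch of the sizing arithmetic (the helper name is ours, not part of this PR):

```python
# Hypothetical helper illustrating the Q8_1 activation-buffer sizing used by
# ggml_moe_a8_vec: 32 int8 quants + half2 (scale, sum) = 36 bytes = 9 int32
# words per 32-element block.
def q8_1_buffer_words(hidden_size: int) -> int:
    padded = (hidden_size + 512 - 1) // 512 * 512  # pad to a multiple of 512
    return padded // 32 * 9  # int32 words per token row

assert q8_1_buffer_words(512) == 144  # 16 blocks * 9 words
```
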
diff --git a/csrc/quantization/gguf/moe_vec.cuh b/csrc/quantization/gguf/moe_vec.cuh
new file mode 100644
index 0000000000000..60f65a1bfdcba
--- /dev/null
+++ b/csrc/quantization/gguf/moe_vec.cuh
@@ -0,0 +1,338 @@
+// copied and adapted from
+// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
+template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr,
+          vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void moe_vec_q(const void* __restrict__ vx,
+                                 const void* __restrict__ vy,
+                                 scalar_t* __restrict__ dst,
+                                 const int* topk_ids, const int topk,
+                                 const int ncols, const int nrows,
+                                 const int token_stride) {
+  const auto row = blockIdx.x * blockDim.y + threadIdx.y;
+
+  const auto token = blockIdx.z / topk;
+  const auto expert = (topk_ids)[blockIdx.z];
+
+  if (row >= nrows) {
+    return;
+  }
+
+  const int blocks_per_row = ncols / qk;
+  const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+  // partial sum for each thread
+  float tmp = 0.0f;
+
+  const block_q_t* x =
+      ((const block_q_t*)vx) + expert * nrows * blocks_per_row;
+  const block_q8_1* y =
+      (const block_q8_1*)(((const int*)vy) + token * token_stride);
+
+  for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row;
+       i += blocks_per_warp) {
+    const int ibx = row * blocks_per_row + i;  // x block index
+
+    const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx
+
+    const int iqs =
+        vdr *
+        (threadIdx.x %
+         (qi / vdr));  // x block quant index when casting the quants to int
+
+    tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
+  }
+
+  // sum up partial sums and write back result
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    tmp += VLLM_SHFL_XOR_SYNC(tmp, mask);
+  }
+
+  if (threadIdx.x == 0) {
+    dst[blockIdx.z * nrows + row] = tmp;
+  }
+}
+
+template <typename scalar_t>
+static void moe_vec_q4_0_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ,
+            vec_dot_q4_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q4_1_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ,
+            vec_dot_q4_1_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q5_0_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ,
+            vec_dot_q5_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q5_1_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ,
+            vec_dot_q5_1_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q8_0_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ,
+            vec_dot_q8_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q2_K_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ,
+            vec_dot_q2_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q3_K_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ,
+            vec_dot_q3_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q4_K_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ,
+            vec_dot_q4_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q5_K_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ,
+            vec_dot_q5_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_q6_K_q8_1_cuda(const void* vx, const void* vy,
+                                   scalar_t* dst, const int* topk_ids,
+                                   const int top_k, const int tokens,
+                                   const int ncols, const int nrows,
+                                   const int token_stride,
+                                   cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ,
+            vec_dot_q6_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq2_xxs_q8_1_cuda(const void* vx, const void* vy,
+                                      scalar_t* dst, const int* topk_ids,
+                                      const int top_k, const int tokens,
+                                      const int ncols, const int nrows,
+                                      const int token_stride,
+                                      cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq2_xs_q8_1_cuda(const void* vx, const void* vy,
+                                     scalar_t* dst, const int* topk_ids,
+                                     const int top_k, const int tokens,
+                                     const int ncols, const int nrows,
+                                     const int token_stride,
+                                     cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq2_s_q8_1_cuda(const void* vx, const void* vy,
+                                    scalar_t* dst, const int* topk_ids,
+                                    const int top_k, const int tokens,
+                                    const int ncols, const int nrows,
+                                    const int token_stride,
+                                    cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq3_xxs_q8_1_cuda(const void* vx, const void* vy,
+                                      scalar_t* dst, const int* topk_ids,
+                                      const int top_k, const int tokens,
+                                      const int ncols, const int nrows,
+                                      const int token_stride,
+                                      cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq1_s_q8_1_cuda(const void* vx, const void* vy,
+                                    scalar_t* dst, const int* topk_ids,
+                                    const int top_k, const int tokens,
+                                    const int ncols, const int nrows,
+                                    const int token_stride,
+                                    cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq1_m_q8_1_cuda(const void* vx, const void* vy,
+                                    scalar_t* dst, const int* topk_ids,
+                                    const int top_k, const int tokens,
+                                    const int ncols, const int nrows,
+                                    const int token_stride,
+                                    cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq4_nl_q8_1_cuda(const void* vx, const void* vy,
+                                     scalar_t* dst, const int* topk_ids,
+                                     const int top_k, const int tokens,
+                                     const int ncols, const int nrows,
+                                     const int token_stride,
+                                     cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ,
+            vec_dot_iq4_nl_q8_1><<<block_nums, block_dims, 0, stream>>>(
+      vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq4_xs_q8_1_cuda(const void* vx, const void* vy,
+                                     scalar_t* dst, const int* topk_ids,
+                                     const int top_k, const int tokens,
+                                     const int ncols, const int nrows,
+                                     const int token_stride,
+                                     cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
+
+template <typename scalar_t>
+static void moe_vec_iq3_s_q8_1_cuda(const void* vx, const void* vy,
+                                    scalar_t* dst, const int* topk_ids,
+                                    const int top_k, const int tokens,
+                                    const int ncols, const int nrows,
+                                    const int token_stride,
+                                    cudaStream_t stream) {
+  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+  const dim3 block_nums(block_num_y, 1, tokens * top_k);
+  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+  moe_vec_q<scalar_t, QK_K, QI3_S, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
+      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
+                                              ncols, nrows, token_stride);
+}
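
The kernel's grid mapping may be easier to follow against a dequantized reference: `blockIdx.z` enumerates (token, slot) pairs row-major, each z picks expert `topk_ids[z]`, and one warp reduces one output row. A rough PyTorch equivalent (illustrative only; it ignores quantization error, so real outputs match it only approximately):

```python
import torch

# Dequantized reference for what moe_vec_q computes, for intuition only.
# Each (token, slot) pair selects expert topk_ids[token, slot] and produces
# one GEMV row of the output, in the same row-major order as blockIdx.z.
def moe_vec_reference(X: torch.Tensor,           # [tokens, ncols] activations
                      W: torch.Tensor,           # [experts, nrows, ncols], dequantized
                      topk_ids: torch.Tensor):   # [tokens, top_k] int32
    tokens, top_k = topk_ids.shape
    Y = X.new_empty(tokens * top_k, W.shape[1])
    for t in range(tokens):
        for k in range(top_k):
            Y[t * top_k + k] = W[topk_ids[t, k]].to(X.dtype) @ X[t]
    return Y
```
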
ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec); + ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM diff --git a/tests/kernels/quantization/test_ggml.py b/tests/kernels/quantization/test_ggml.py index cc157da518cbf..73697a6d1242d 100644 --- a/tests/kernels/quantization/test_ggml.py +++ b/tests/kernels/quantization/test_ggml.py @@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type): opcheck(torch.ops._C.ggml_moe_a8, (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, quant_type, qweight.shape[0], 1, x.shape[0])) + + topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32) + + opcheck( + torch.ops._C.ggml_moe_a8_vec, + (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0])) diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py index 4c0fae9d9fd75..6cf88604ec65e 100644 --- a/tests/kernels/quantization/test_gguf.py +++ b/tests/kernels/quantization/test_gguf.py @@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("hidden_size", [512]) @pytest.mark.parametrize("top_k", [4, 8]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize( - "quant_type", - [ - # k-quants - GGMLQuantizationType.Q2_K, - GGMLQuantizationType.Q3_K, - GGMLQuantizationType.Q4_K, - GGMLQuantizationType.Q5_K, - GGMLQuantizationType.Q6_K, - # standard quants - GGMLQuantizationType.Q4_0, - GGMLQuantizationType.Q5_0, - GGMLQuantizationType.Q8_0, - ]) +@pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType, top_k: int): @@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) - topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda") + topk_ids = torch.randint(0, + E, (num_tokens, top_k), + device="cuda", + dtype=torch.int32) tensors = get_gguf_MoE_tensors(hidden_size, quant_type) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 44377ccb2959f..6f0a5f991908f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -497,6 +497,24 @@ if hasattr(torch.ops._C, "ggml_dequantize"): device=W.device) +if hasattr(torch.ops._C, "ggml_moe_a8_vec"): + + @register_fake("_C::ggml_moe_a8_vec") + def _ggml_moe_a8_vec_fake( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=X.dtype, + device=W.device) + + # cutlass def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) @@ -1146,6 +1164,19 @@ def ggml_moe_a8( top_k, tokens) +def ggml_moe_a8_vec( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, + tokens) + + def ggml_moe_get_block_size(quant_type: int) -> int: return torch.ops._C.ggml_moe_get_block_size(quant_type) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 05058dfaa7332..c881524549416 100644 --- 
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 05058dfaa7332..c881524549416 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -145,7 +145,9 @@ def _fused_moe_gguf(
                                 moe_align_block_size)
 
     out_hidden_states = torch.empty_like(x)
-    if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
+    # unless we have decent expert reuse we are better off running the
+    # moe_vec kernel
+    if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
+            and x.shape[0] > 64):
         num_tokens, _ = x.shape
         E, N, _ = w1.shape
         top_k = topk_ids.shape[1]
@@ -163,6 +165,20 @@ def _fused_moe_gguf(
         out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
             topk_weights.view(num_tokens, top_k, 1))
         ops.moe_sum(out, out_hidden_states)
+    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
+        num_tokens, _ = x.shape
+        E, N, _ = w1.shape
+        top_k = topk_ids.shape[1]
+
+        out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N,
+                                  num_tokens)
+        out = act(out)
+
+        out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
+                                  w2.shape[1], num_tokens * top_k)
+        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
+            topk_weights.view(num_tokens, top_k, 1))
+        ops.moe_sum(out, out_hidden_states)
     else:
         logger.warning_once("There is no support for fast MoE kernel "
                             "for current quantization method. "
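
The resulting kernel selection in `_fused_moe_gguf` can be condensed as follows; this is a sketch of the branch logic above, not new behavior (threshold and type sets come straight from the PR):

```python
# Condensed view of the dispatch added to _fused_moe_gguf: the tiled MMQ path
# only wins when enough tokens share experts; otherwise the new
# per-(token, expert) GEMV path is used when both weight types support it.
def pick_moe_path(num_tokens, w1_type, w2_type,
                  MMQ_QUANT_TYPES, MMVQ_QUANT_TYPES):
    if (w1_type in MMQ_QUANT_TYPES and w2_type in MMQ_QUANT_TYPES
            and num_tokens > 64):
        return "ggml_moe_a8"      # matrix-matrix kernel, benefits from expert reuse
    if w1_type in MMVQ_QUANT_TYPES and w2_type in MMVQ_QUANT_TYPES:
        return "ggml_moe_a8_vec"  # vectorized GEMV kernel added by this PR
    return "dequant_fallback"     # slow path, logs a warning
```
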