diff --git a/csrc/ops.h b/csrc/ops.h index 13fbbe41286d..724d7c92b826 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -151,6 +151,14 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); +torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_padded, int64_t type, + int64_t row, int64_t top_k, int64_t tokens); + +int64_t ggml_moe_get_block_size(int64_t type); + #ifndef USE_ROCM void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& B, torch::Tensor const& A_sf, diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 1150bd8f2258..46b716bbd98d 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -12,6 +12,7 @@ #include "dequantize.cuh" #include "mmvq.cuh" #include "mmq.cuh" +#include "moe.cuh" // Q8 gemv template @@ -59,10 +60,14 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx, const int64_t kx_padded = (kx + 512 - 1) / 512 * 512; const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, ky, 1); - const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1 - <<>>(x, vy, kx, kx_padded); + constexpr int MAX_BLOCK_SIZE = 65535; + for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) { + const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off; + const dim3 num_blocks(block_num_x, num_blocks_y, 1); + const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<>>( + &x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded); + } } torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight @@ -263,3 +268,132 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight }); return Y; } + +torch::Tensor ggml_moe_a8(torch::Tensor X, // input + torch::Tensor W, // expert weights + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_padded, int64_t type, + int64_t row, int64_t top_k, int64_t tokens) { + int col = X.sizes()[1]; + int padded = (col + 512 - 1) / 512 * 512; + const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); + auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); + at::Tensor Y = torch::empty({tokens * top_k, row}, options); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); + at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); + VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] { + quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), + col, tokens, stream); + switch (type) { + case 2: + ggml_moe_q4_0_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 3: + ggml_moe_q4_1_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 6: + ggml_moe_q5_0_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 7: + ggml_moe_q5_1_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 8: + ggml_moe_q8_0_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 10: + ggml_moe_q2_K_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 11: + ggml_moe_q3_K_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 12: + ggml_moe_q4_K_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 13: + ggml_moe_q5_K_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + case 14: + ggml_moe_q6_K_q8_1_cuda( + (void*)quant_X.data_ptr(), (void*)W.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(), + (int*)expert_ids.data_ptr(), + (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, + tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream); + break; + } + }); + return Y; +} + +int64_t ggml_moe_get_block_size(int64_t type) { + switch (type) { + case 2: + return MMQ_X_Q4_0; + case 3: + return MMQ_X_Q4_1; + case 6: + return MMQ_X_Q5_0; + case 7: + return MMQ_X_Q5_1; + case 8: + return MMQ_X_Q8_0; + case 10: + return MMQ_X_Q2_K; + case 11: + return MMQ_X_Q3_K; + case 12: + return MMQ_X_Q4_K; + case 13: + return MMQ_X_Q5_K; + case 14: + return MMQ_X_Q6_K; + } + return 0; +} diff --git a/csrc/quantization/gguf/moe.cuh b/csrc/quantization/gguf/moe.cuh new file mode 100644 index 000000000000..e499f53a2acd --- /dev/null +++ b/csrc/quantization/gguf/moe.cuh @@ -0,0 +1,739 @@ +#include + +/* Adapted from ./csrc/quantization/gguf/mmq.cuh + based on ./vllm/model_executor/layers/fused_moe/fused_moe.py */ +template +static __device__ __forceinline__ void moe_q( + const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* __restrict__ sorted_token_ids, + const int* __restrict__ expert_ids, + const int* __restrict__ num_tokens_post_padded, const int exp_stride, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, + const int nrows_dst, const int top_k) { + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE_GGUF / qi; + + const int ncols_dst = ncols_y * top_k; + + const int row_dst_0 = blockIdx.x * mmq_y; + const int& row_x_0 = row_dst_0; + + const int col_dst_0 = blockIdx.y * mmq_x; + + int token_offs[mmq_x / nwarps]; + for (int i = 0; i < mmq_x; i += nwarps) { + token_offs[i / nwarps] = sorted_token_ids[col_dst_0 + threadIdx.y + i]; + } + + const int exp_idx = expert_ids[blockIdx.y]; + if (exp_idx > 255 || exp_idx < 0) return; + if (blockIdx.y * mmq_x > num_tokens_post_padded[0]) return; + + const block_q_t* x = (const block_q_t*)((char*)vx + exp_idx * exp_stride); + const block_q8_1* y = (const block_q8_1*)(vy); + + int* tile_x_ql = nullptr; + half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); + + __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF / QI8_1]; + + float sum[mmq_y / WARP_SIZE_GGUF][mmq_x / nwarps] = {{0.0f}}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, + tile_x_qh, tile_x_sc, threadIdx.y, nrows_x - row_x_0 - 1, + threadIdx.x, blocks_per_row_x); + + const int n_per_r = ((qk * blocks_per_warp) / qr); +#pragma unroll + for (int ir = 0; ir < qr && ib0 * qk + ir * n_per_r < ncols_x; ++ir) { + const int kqs = ir * WARP_SIZE_GGUF + threadIdx.x; + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i = 0; i < mmq_x; i += nwarps) { + const int col_y_eff = token_offs[i / nwarps] / top_k; + const int block_x = ib0 * (qk / QK8_1) + kbxd; + if (col_y_eff < ncols_y && block_x < blocks_per_col_y) { + const block_q8_1* by0 = &y[col_y_eff * blocks_per_col_y + block_x]; + const int index_y = + (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF; + tile_y_qs[index_y] = + get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + } + } + + if (threadIdx.x < n_per_r / QK8_1) { + const int kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1); + const int col_y_eff = token_offs[threadIdx.y] / top_k; + const int block_x = + ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby; + + if (col_y_eff < ncols_y && block_x < blocks_per_col_y) { + const half2* dsi_src = &y[col_y_eff * blocks_per_col_y + block_x].ds; + half2* dsi_dst = + &tile_y_ds[threadIdx.y * (WARP_SIZE_GGUF / QI8_1) + kby]; + + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float* dfi_dst = (float*)dsi_dst; + *dfi_dst = __low2float(*dsi_src); + } + } + } + __syncthreads(); + + // #pragma unroll // unrolling this loop causes too much register pressure + for (int k = ir * WARP_SIZE_GGUF / qr; k < (ir + 1) * WARP_SIZE_GGUF / qr; + k += vdr) { +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { + sum[i / WARP_SIZE_GGUF][j / nwarps] += + vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, + tile_y_ds, threadIdx.x + i, threadIdx.y + j, k); + } + } + } + __syncthreads(); + } + } + +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { + const int col_dst = token_offs[j / nwarps]; + if (col_dst >= ncols_dst) { + return; + } + +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { + const int row_dst = row_dst_0 + threadIdx.x + i; + if (row_dst >= nrows_dst) { + continue; + } + dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE_GGUF][j / nwarps]; + } + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q4_0 64 + #define MMQ_Y_Q4_0 128 + #define NWARPS_Q4_0 8 +#else + #define MMQ_X_Q4_0 4 + #define MMQ_Y_Q4_0 32 + #define NWARPS_Q4_0 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2) +#endif + moe_q4_0(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q4_0; + const int mmq_y = MMQ_Y_Q4_0; + const int nwarps = NWARPS_Q4_0; + + moe_q, load_tiles_q4_0, + VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q4_0_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + int mmq_x = MMQ_X_Q4_0; + int mmq_y = MMQ_Y_Q4_0; + int nwarps = NWARPS_Q4_0; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q4_0<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q4_0<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q4_1 64 + #define MMQ_Y_Q4_1 128 + #define NWARPS_Q4_1 8 +#else + #define MMQ_X_Q4_1 4 + #define MMQ_Y_Q4_1 32 + #define NWARPS_Q4_1 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2) +#endif + moe_q4_1(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q4_1; + const int mmq_y = MMQ_Y_Q4_1; + const int nwarps = NWARPS_Q4_1; + + moe_q, load_tiles_q4_1, + VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q4_1_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + int mmq_x = MMQ_X_Q4_1; + int mmq_y = MMQ_Y_Q4_1; + int nwarps = NWARPS_Q4_1; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q4_1<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q4_1<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q5_0 64 + #define MMQ_Y_Q5_0 128 + #define NWARPS_Q5_0 8 +#else + #define MMQ_X_Q5_0 4 + #define MMQ_Y_Q5_0 32 + #define NWARPS_Q5_0 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2) +#endif + moe_q5_0(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q5_0; + const int mmq_y = MMQ_Y_Q5_0; + const int nwarps = NWARPS_Q5_0; + + moe_q, load_tiles_q5_0, + VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q5_0_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q5_0; + const int mmq_y = MMQ_Y_Q5_0; + const int nwarps = NWARPS_Q5_0; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q5_0<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q5_0<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q5_1 64 + #define MMQ_Y_Q5_1 128 + #define NWARPS_Q5_1 8 +#else + #define MMQ_X_Q5_1 4 + #define MMQ_Y_Q5_1 32 + #define NWARPS_Q5_1 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2) +#endif + moe_q5_1(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q5_1; + const int mmq_y = MMQ_Y_Q5_1; + const int nwarps = NWARPS_Q5_1; + + moe_q, load_tiles_q5_1, + VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q5_1_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q5_1; + const int mmq_y = MMQ_Y_Q5_1; + const int nwarps = NWARPS_Q5_1; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q5_1<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q5_1<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q8_0 64 + #define MMQ_Y_Q8_0 128 + #define NWARPS_Q8_0 8 +#else + #define MMQ_X_Q8_0 4 + #define MMQ_Y_Q8_0 32 + #define NWARPS_Q8_0 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2) +#endif + moe_q8_0(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q8_0; + const int mmq_y = MMQ_Y_Q8_0; + const int nwarps = NWARPS_Q8_0; + + moe_q, load_tiles_q8_0, + VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q8_0_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q8_0; + const int mmq_y = MMQ_Y_Q8_0; + const int nwarps = NWARPS_Q8_0; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q8_0<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q8_0<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q2_K 64 + #define MMQ_Y_Q2_K 128 + #define NWARPS_Q2_K 8 +#else + #define MMQ_X_Q2_K 4 + #define MMQ_Y_Q2_K 32 + #define NWARPS_Q2_K 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2) +#endif + moe_q2_K(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q2_K; + const int mmq_y = MMQ_Y_Q2_K; + const int nwarps = NWARPS_Q2_K; + + moe_q, load_tiles_q2_K, + VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q2_K_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q2_K; + const int mmq_y = MMQ_Y_Q2_K; + const int nwarps = NWARPS_Q2_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q2_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q2_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q3_K 64 + #define MMQ_Y_Q3_K 128 + #define NWARPS_Q3_K 8 +#else + #define MMQ_X_Q3_K 4 + #define MMQ_Y_Q3_K 32 + #define NWARPS_Q3_K 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2) +#endif + moe_q3_K(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + + const int mmq_x = MMQ_X_Q3_K; + const int mmq_y = MMQ_Y_Q3_K; + const int nwarps = NWARPS_Q3_K; + + moe_q, load_tiles_q3_K, + VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} +template +static void ggml_moe_q3_K_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q3_K; + const int mmq_y = MMQ_Y_Q3_K; + const int nwarps = NWARPS_Q3_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q3_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q3_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q4_K 64 + #define MMQ_Y_Q4_K 128 + #define NWARPS_Q4_K 8 +#else + #define MMQ_X_Q4_K 4 + #define MMQ_Y_Q4_K 32 + #define NWARPS_Q4_K 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2) +#endif + moe_q4_K(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q4_K; + const int mmq_y = MMQ_Y_Q4_K; + const int nwarps = NWARPS_Q4_K; + + moe_q, load_tiles_q4_K, + VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q4_K_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q4_K; + const int mmq_y = MMQ_Y_Q4_K; + const int nwarps = NWARPS_Q4_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q4_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q4_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q5_K 64 + #define MMQ_Y_Q5_K 128 + #define NWARPS_Q5_K 8 +#else + #define MMQ_X_Q5_K 4 + #define MMQ_Y_Q5_K 32 + #define NWARPS_Q5_K 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2) +#endif + moe_q5_K(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q5_K; + const int mmq_y = MMQ_Y_Q5_K; + const int nwarps = NWARPS_Q5_K; + + moe_q, load_tiles_q5_K, + VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q5_K_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q5_K; + const int mmq_y = MMQ_Y_Q5_K; + const int nwarps = NWARPS_Q5_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q5_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q5_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} + +#if defined(USE_ROCM) + #define MMQ_X_Q6_K 64 + #define MMQ_Y_Q6_K 128 + #define NWARPS_Q6_K 8 +#else + #define MMQ_X_Q6_K 4 + #define MMQ_Y_Q6_K 32 + #define NWARPS_Q6_K 4 +#endif + +template +static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2) +#endif + moe_q6_K(const void* __restrict__ vx, const void* __restrict__ vy, + scalar_t* __restrict__ dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, + const int top_k) { + const int mmq_x = MMQ_X_Q6_K; + const int mmq_y = MMQ_Y_Q6_K; + const int nwarps = NWARPS_Q6_K; + + moe_q, load_tiles_q6_K, + VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>( + vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); +} + +template +static void ggml_moe_q6_K_q8_1_cuda( + const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, + const int* expert_ids, const int* num_tokens_post_padded, + const int exp_stride, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, + const int tokens_post_padded, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q6_K; + const int mmq_y = MMQ_Y_Q6_K; + const int nwarps = NWARPS_Q6_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (tokens_post_padded) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + moe_q6_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } else { + constexpr bool need_check = true; + moe_q6_K<<>>( + w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, + exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k); + } +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b06b12220793..eac27e648f80 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -305,6 +305,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); + // moe kernel for GGML. + ops.def( + "ggml_moe_a8(Tensor X, Tensor W, " + "Tensor sorted_token_ids, Tensor expert_ids, Tensor " + "num_tokens_post_padded, " + "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); + ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8); + + ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); + #ifndef USE_ROCM // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. ops.def( diff --git a/tests/kernels/test_ggml.py b/tests/kernels/test_ggml.py index dc728fd4861d..23fa1fdfda17 100644 --- a/tests/kernels/test_ggml.py +++ b/tests/kernels/test_ggml.py @@ -22,3 +22,16 @@ def test_ggml_opcheck(quant_type): (qweight, x, quant_type, qweight.shape[0])) opcheck(torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0])) + + shape = [256, 1024, 336] + qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) + x = torch.rand((1, 1024), device='cuda', dtype=torch.float16) + sorted_token_ids = torch.arange(776, device='cuda') + expert_ids = torch.randint(0, 256, (194, ), device='cuda') + num_tokens_post_padded = torch.tensor([1], + dtype=torch.int64, + device='cuda') + + opcheck(torch.ops._C.ggml_moe_a8, + (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, + quant_type, qweight.shape[0], 1, x.shape[0])) diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py index dde3741d3c4f..ede941844dc0 100644 --- a/tests/kernels/test_gguf.py +++ b/tests/kernels/test_gguf.py @@ -8,9 +8,13 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize from huggingface_hub import snapshot_download import vllm._custom_ops as ops +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf from vllm.platforms import current_platform GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") +GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample") def get_gguf_sample_tensors( @@ -22,6 +26,15 @@ def get_gguf_sample_tensors( return GGUFReader(sample_file).tensors +def get_gguf_MoE_tensors( + hidden_size: int, + quant_type: GGMLQuantizationType) -> list[ReaderTensor]: + sample_dir = GGUF_SAMPLE_MOE + filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" + sample_file = Path(sample_dir) / filename + return GGUFReader(sample_file).tensors + + DTYPES = [torch.half, torch.bfloat16, torch.float32] # Hidden_size for testing, must match the sample file in HF repo, # we have `hidden_size = 256, 1024` for test in HF repo currently. @@ -132,3 +145,54 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, ref_output, atol=atols[dtype], rtol=rtols[dtype]) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", [512]) +@pytest.mark.parametrize("top_k", [4, 8]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize( + "quant_type", + [ + # k-quants + GGMLQuantizationType.Q2_K, + GGMLQuantizationType.Q3_K, + GGMLQuantizationType.Q4_K, + GGMLQuantizationType.Q5_K, + GGMLQuantizationType.Q6_K, + # standard quants + GGMLQuantizationType.Q4_0, + GGMLQuantizationType.Q5_0, + GGMLQuantizationType.Q8_0, + ]) +@torch.inference_mode() +def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, + quant_type: GGMLQuantizationType, top_k: int): + current_platform.seed_everything(0) + H, E = 1024, 256 + + x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") + + topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) + topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda") + + tensors = get_gguf_MoE_tensors(hidden_size, quant_type) + + w13 = tensors[0] + w2 = tensors[1] + + w13_dequant = torch.tensor(dequantize(w13.data, quant_type), + device="cuda").to(dtype) + + w2_dequant = torch.tensor(dequantize(w2.data, quant_type), + device="cuda").to(dtype) + act = SiluAndMul() + + output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"), + torch.tensor(w2.data, + device="cuda"), topk_weights, + topk_ids, quant_type, quant_type, act) + + ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights, + topk_ids).reshape(output.shape) + torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 14cfe751514f..9f5b48714e1f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -448,6 +448,23 @@ if hasattr(torch.ops._C, "ggml_dequantize"): batch = X.size(0) return torch.empty((batch, row), dtype=X.dtype, device=W.device) + @register_fake("_C::ggml_moe_a8") + def _ggml_moe_a8_fake( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: torch.SymInt, + top_k: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=torch.float16, + device=W.device) + # cutlass def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, @@ -1034,6 +1051,26 @@ def ggml_mul_mat_a8( return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) +def ggml_moe_a8( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: int, + top_k: int, + tokens: int, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids, + num_tokens_post_padded, quant_type, row, + top_k, tokens) + + +def ggml_moe_get_block_size(quant_type: int) -> int: + return torch.ops._C.ggml_moe_get_block_size(quant_type) + + # mamba def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index d4e971776934..5d4c1c6ec893 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -8,7 +8,9 @@ from gguf import GGMLQuantizationType as WeightType from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEMethodBase) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase @@ -18,6 +20,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.utils import set_weight_attrs +logger = init_logger(__name__) + class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" @@ -119,6 +123,59 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, return y +def _fused_moe_gguf( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + qweight_type: int, + qweight_type2: int, + act, +) -> torch.Tensor: + out_hidden_states = torch.empty_like(x) + if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES: + num_tokens, _ = x.shape + E, N, _ = w1.shape + top_k = topk_ids.shape[1] + BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type) + + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, BLOCK_SIZE, E) + out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids, + num_tokens_post_padded, qweight_type, N, top_k, + num_tokens) + out = act(out) + out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids, + num_tokens_post_padded, qweight_type2, + w2.shape[1], 1, num_tokens * top_k) + out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( + topk_weights.view(num_tokens, top_k, 1)) + ops.moe_sum(out, out_hidden_states) + else: + logger.warning_once("There is no support for fast MoE kernel " + "for current quantization method. " + "Falling back to slow implementation. ") + for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)): + inp = x[tok].reshape((1, ) + x.shape[1:]) + current_hidden_state = None + for ww, ii in zip(w, idx): + expert_up = w1[ii] + + out = _fuse_mul_mat(inp, expert_up, qweight_type) + out = act(out) + + expert_down = w2[ii] + current_state = _fuse_mul_mat(out, expert_down, + qweight_type2).mul_(ww) + if current_hidden_state is None: + current_hidden_state = current_state + else: + current_hidden_state.add_(current_state) + out_hidden_states[tok] = current_hidden_state + return out_hidden_states + + class GGUFLinearMethod(LinearMethodBase): """Linear method for GGUF. @@ -285,27 +342,10 @@ class GGUFMoEMethod(FusedMoEMethodBase): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - final_hidden_states = torch.empty_like(x) - for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)): - inp = x[tok].reshape((1, ) + x.shape[1:]) - current_hidden_state = None - for ww, ii in zip(w, idx): - expert_up = layer.w13_qweight[ii] - - out = _fuse_mul_mat(inp, expert_up, - layer.w13_qweight_type.weight_type) - out = self.act(out) - - expert_down = layer.w2_qweight[ii] - current_state = _fuse_mul_mat( - out, expert_down, - layer.w2_qweight_type.weight_type).mul_(ww) - if current_hidden_state is None: - current_hidden_state = current_state - else: - current_hidden_state.add_(current_state) - final_hidden_states[tok] = current_hidden_state - return final_hidden_states + return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, + topk_weights, topk_ids, + layer.w13_qweight_type.weight_type, + layer.w2_qweight_type.weight_type, self.act) class GGUFEmbeddingMethod(GGUFLinearMethod):