mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:15:01 +08:00
[Kernel] GGUF MoeVec kernel (#16780)
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com> Signed-off-by: SzymonOzog <szymon.ozog@gmail.com> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
c3e9d5060e
commit
1a45a61387
@ -178,6 +178,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens);
|
||||
|
||||
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
|
||||
torch::Tensor topk_ids, int64_t top_k,
|
||||
int64_t type, int64_t row, int64_t tokens);
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include "mmvq.cuh"
|
||||
#include "mmq.cuh"
|
||||
#include "moe.cuh"
|
||||
#include "moe_vec.cuh"
|
||||
|
||||
// Q8 gemv
|
||||
template <typename scalar_t>
|
||||
@ -377,6 +378,142 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
|
||||
return Y;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input
|
||||
torch::Tensor W, // expert weights
|
||||
torch::Tensor topk_ids, int64_t top_k,
|
||||
int64_t type, int64_t row, int64_t tokens) {
|
||||
int col = X.sizes()[1];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
|
||||
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
|
||||
(void*)quant_X.data_ptr(), col, tokens,
|
||||
stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
moe_vec_q4_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 3:
|
||||
moe_vec_q4_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 6:
|
||||
moe_vec_q5_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 7:
|
||||
moe_vec_q5_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 8:
|
||||
moe_vec_q8_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 10:
|
||||
moe_vec_q2_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 11:
|
||||
moe_vec_q3_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 12:
|
||||
moe_vec_q4_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 13:
|
||||
moe_vec_q5_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 14:
|
||||
moe_vec_q6_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 16:
|
||||
moe_vec_iq2_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 17:
|
||||
moe_vec_iq2_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 18:
|
||||
moe_vec_iq3_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 19:
|
||||
moe_vec_iq1_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 20:
|
||||
moe_vec_iq4_nl_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 21:
|
||||
moe_vec_iq3_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 22:
|
||||
moe_vec_iq2_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 23:
|
||||
moe_vec_iq4_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
case 29:
|
||||
moe_vec_iq1_m_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
|
||||
col, row, quant_X.stride(0), stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type) {
|
||||
switch (type) {
|
||||
case 2:
|
||||
|
||||
338
csrc/quantization/gguf/moe_vec.cuh
Normal file
338
csrc/quantization/gguf/moe_vec.cuh
Normal file
@ -0,0 +1,338 @@
|
||||
// copied and adapted from
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
|
||||
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr,
|
||||
vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void moe_vec_q(const void* __restrict__ vx,
|
||||
const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst,
|
||||
const int* topk_ids, const int topk,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride) {
|
||||
const auto row = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
const auto token = blockIdx.z / topk;
|
||||
const auto expert = (topk_ids)[blockIdx.z];
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
const block_q_t* x = ((const block_q_t*)vx) + expert * nrows * blocks_per_row;
|
||||
const block_q8_1* y =
|
||||
(const block_q8_1*)(((const int*)vy) + token * token_stride);
|
||||
|
||||
for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row;
|
||||
i += blocks_per_warp) {
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
|
||||
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
|
||||
|
||||
const int iqs =
|
||||
vdr *
|
||||
(threadIdx.x %
|
||||
(qi / vdr)); // x block quant index when casting the quants to int
|
||||
|
||||
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp += VLLM_SHFL_XOR_SYNC(tmp, mask);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dst[blockIdx.z * nrows + row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q4_0_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ,
|
||||
vec_dot_q4_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q4_1_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ,
|
||||
vec_dot_q4_1_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q5_0_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ,
|
||||
vec_dot_q5_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q5_1_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ,
|
||||
vec_dot_q5_1_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q8_0_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ,
|
||||
vec_dot_q8_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q2_K_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ,
|
||||
vec_dot_q2_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q3_K_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ,
|
||||
vec_dot_q3_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q4_K_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ,
|
||||
vec_dot_q4_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q5_K_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ,
|
||||
vec_dot_q5_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_q6_K_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ,
|
||||
vec_dot_q6_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq2_xxs_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq2_xs_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq2_s_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq3_xxs_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq1_s_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq1_m_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq4_nl_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ,
|
||||
vec_dot_iq4_nl_q8_1><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq4_xs_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void moe_vec_iq3_s_q8_1_cuda(const void* vx, const void* vy,
|
||||
scalar_t* dst, const int* topk_ids,
|
||||
const int top_k, const int tokens,
|
||||
const int ncols, const int nrows,
|
||||
const int token_stride,
|
||||
cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, tokens * top_k);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
moe_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
|
||||
ncols, nrows, token_stride);
|
||||
}
|
||||
@ -337,6 +337,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
|
||||
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
|
||||
|
||||
ops.def(
|
||||
"ggml_moe_a8_vec(Tensor X, Tensor W, "
|
||||
"Tensor topk_ids, int top_k, "
|
||||
"int type, SymInt row, SymInt tokens) -> Tensor");
|
||||
ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
|
||||
|
||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type):
|
||||
opcheck(torch.ops._C.ggml_moe_a8,
|
||||
(x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
quant_type, qweight.shape[0], 1, x.shape[0]))
|
||||
|
||||
topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.ggml_moe_a8_vec,
|
||||
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]))
|
||||
|
||||
@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
@pytest.mark.parametrize("hidden_size", [512])
|
||||
@pytest.mark.parametrize("top_k", [4, 8])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize(
|
||||
"quant_type",
|
||||
[
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quants
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
])
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType, top_k: int):
|
||||
@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
|
||||
topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda")
|
||||
topk_ids = torch.randint(0,
|
||||
E, (num_tokens, top_k),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
|
||||
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
|
||||
|
||||
|
||||
@ -497,6 +497,24 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
|
||||
device=W.device)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "ggml_moe_a8_vec"):
|
||||
|
||||
@register_fake("_C::ggml_moe_a8_vec")
|
||||
def _ggml_moe_a8_vec_fake(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
quant_type: int,
|
||||
row: torch.SymInt,
|
||||
tokens: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
tokens = X.size(0)
|
||||
return torch.empty((tokens * top_k, row),
|
||||
dtype=X.dtype,
|
||||
device=W.device)
|
||||
|
||||
|
||||
# cutlass
|
||||
def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
|
||||
return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability)
|
||||
@ -1146,6 +1164,19 @@ def ggml_moe_a8(
|
||||
top_k, tokens)
|
||||
|
||||
|
||||
def ggml_moe_a8_vec(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
quant_type: int,
|
||||
row: torch.SymInt,
|
||||
tokens: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row,
|
||||
tokens)
|
||||
|
||||
|
||||
def ggml_moe_get_block_size(quant_type: int) -> int:
|
||||
return torch.ops._C.ggml_moe_get_block_size(quant_type)
|
||||
|
||||
|
||||
@ -145,7 +145,9 @@ def _fused_moe_gguf(
|
||||
moe_align_block_size)
|
||||
|
||||
out_hidden_states = torch.empty_like(x)
|
||||
if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
|
||||
# unless we decent expert reuse we are better off running moe_vec kernel
|
||||
if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
|
||||
and x.shape[0] > 64):
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
@ -163,6 +165,20 @@ def _fused_moe_gguf(
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1))
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
|
||||
out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N,
|
||||
num_tokens)
|
||||
out = act(out)
|
||||
|
||||
out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
|
||||
w2.shape[1], num_tokens * top_k)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1))
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
else:
|
||||
logger.warning_once("There is no support for fast MoE kernel "
|
||||
"for current quantization method. "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user