mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-19 01:17:18 +08:00
[Kernel] GGUF MMVQ kernel for multiple input vectors (#18754)
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
This commit is contained in:
parent
8d120701fd
commit
dec66d253b
@ -92,111 +92,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int64_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
int vecs = X.sizes()[0];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({1, row}, options);
|
||||
at::Tensor Y = torch::empty({vecs, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
|
||||
at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
|
||||
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
|
||||
(void*)quant_X.data_ptr(), col, 1, stream);
|
||||
quantize_row_q8_1_cuda<scalar_t>(
|
||||
(scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 29:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
@ -1,16 +1,19 @@
|
||||
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
|
||||
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows, const int nvecs) {
|
||||
const auto row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const auto vec = blockIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
if (row >= nrows || vec >= nvecs) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
const int nrows_y = (ncols + 512 - 1) / 512 * 512;
|
||||
|
||||
// partial sum for each thread
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
const block_q_t * x = (const block_q_t *) vx;
|
||||
@ -19,7 +22,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
|
||||
for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
|
||||
const int ibx = row*blocks_per_row + i; // x block index
|
||||
|
||||
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
||||
const int iby = vec*(nrows_y/QK8_1) + i * (qk/QK8_1); // y block index that aligns with ibx
|
||||
|
||||
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
|
||||
|
||||
@ -33,177 +36,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dst[row] = tmp;
|
||||
dst[vec*nrows + row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(block_num_y, nvecs, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
|
||||
}
|
||||
|
||||
@ -594,7 +594,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
|
||||
quant_type: int,
|
||||
row: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((1, row), dtype=X.dtype, device=W.device)
|
||||
return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device)
|
||||
|
||||
@register_fake("_C::ggml_mul_mat_a8")
|
||||
def _ggml_mul_mat_a8_fake(
|
||||
|
||||
@ -99,6 +99,10 @@ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
|
||||
qweight_type: int) -> torch.Tensor:
|
||||
if qweight_type in IMATRIX_QUANT_TYPES:
|
||||
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
|
||||
else:
|
||||
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
|
||||
# HACK: when doing chunked prefill we don't generate output tokens
|
||||
# so input to logits generator is empty which causes invalid parameter
|
||||
if x.shape[0] == 0:
|
||||
@ -110,7 +114,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
# enable MMVQ in contiguous batching with batch_size=1
|
||||
if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES:
|
||||
if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
|
||||
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# Use MMQ Kernel if it's available (standard + k-quants)
|
||||
elif qweight_type in MMQ_QUANT_TYPES:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user