mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[Kernel] Add more dtype support for GGUF kernels (#14043)
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com> Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
This commit is contained in:
parent
b0746fae3d
commit
89cdaa83e7
@ -5,6 +5,7 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "vecdotq.cuh"
|
||||
@ -13,7 +14,8 @@
|
||||
#include "mmq.cuh"
|
||||
|
||||
// Q8 gemv
|
||||
static __global__ void quantize_q8_1(const half* __restrict__ x,
|
||||
template <typename scalar_t>
|
||||
static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
|
||||
void* __restrict__ vy, const int kx,
|
||||
const int kx_padded) {
|
||||
const int ix = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
@ -28,7 +30,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
|
||||
const int ib = i_padded / QK8_1; // block index
|
||||
const int iqs = i_padded % QK8_1; // quant index
|
||||
|
||||
const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
|
||||
const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
|
||||
float amax = fabsf(xi);
|
||||
float sum = xi;
|
||||
|
||||
@ -51,14 +53,16 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
|
||||
y[ib].ds.y = __float2half(sum);
|
||||
}
|
||||
|
||||
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
|
||||
template <typename scalar_t>
|
||||
static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
|
||||
const int ky, cudaStream_t stream) {
|
||||
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
|
||||
const int block_num_x =
|
||||
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
const dim3 num_blocks(block_num_x, ky, 1);
|
||||
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
|
||||
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
|
||||
quantize_q8_1<scalar_t>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
|
||||
}
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
@ -79,101 +83,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
int col = X.sizes()[1];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({1, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
|
||||
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
|
||||
stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 29:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
}
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
|
||||
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
|
||||
(void*)quant_X.data_ptr(), col, 1, stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 29:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
@ -184,66 +199,67 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
int batch = X.sizes()[0];
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({batch, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
|
||||
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
|
||||
batch, stream);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
|
||||
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
col, batch, stream);
|
||||
|
||||
switch (type) {
|
||||
case 2:
|
||||
ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
}
|
||||
switch (type) {
|
||||
case 2:
|
||||
ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
|
||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||
template <typename scalar_t, int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||
static __device__ __forceinline__ void mul_mat_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
const block_q_t * x = (const block_q_t *) vx;
|
||||
@ -38,7 +38,7 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
|
||||
|
||||
#pragma unroll
|
||||
for (int ir = 0; ir < qr; ++ir) {
|
||||
for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp/qr < blocks_per_row_x; ++ir) {
|
||||
const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
|
||||
const int kbxd = kqs / QI8_1;
|
||||
|
||||
@ -98,7 +98,7 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
if (row_dst >= nrows_dst) {
|
||||
continue;
|
||||
}
|
||||
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]);
|
||||
dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE_GGUF][j/nwarps];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -113,24 +113,25 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
#define NWARPS_Q4_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
|
||||
#endif
|
||||
mul_mat_q4_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_0;
|
||||
const int mmq_y = MMQ_Y_Q4_0;
|
||||
const int nwarps = NWARPS_Q4_0;
|
||||
|
||||
mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
|
||||
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
int mmq_x = MMQ_X_Q4_0;
|
||||
@ -144,11 +145,11 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -163,24 +164,25 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
#define NWARPS_Q4_1 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
|
||||
#endif
|
||||
mul_mat_q4_1(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_1;
|
||||
const int mmq_y = MMQ_Y_Q4_1;
|
||||
const int nwarps = NWARPS_Q4_1;
|
||||
|
||||
mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
|
||||
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
int mmq_x = MMQ_X_Q4_1;
|
||||
@ -194,11 +196,11 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -213,24 +215,25 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
#define NWARPS_Q5_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
|
||||
#endif
|
||||
mul_mat_q5_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
|
||||
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
@ -244,11 +247,11 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -263,24 +266,25 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
#define NWARPS_Q5_1 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
|
||||
#endif
|
||||
mul_mat_q5_1(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
|
||||
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
@ -293,11 +297,11 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -312,24 +316,25 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
#define NWARPS_Q8_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
|
||||
#endif
|
||||
mul_mat_q8_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
|
||||
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
@ -342,11 +347,11 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -361,24 +366,25 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
#define NWARPS_Q2_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
|
||||
#endif
|
||||
mul_mat_q2_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
|
||||
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
@ -391,11 +397,11 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -410,25 +416,26 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
#define NWARPS_Q3_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
|
||||
#endif
|
||||
mul_mat_q3_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
|
||||
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
@ -442,11 +449,11 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -461,24 +468,25 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
#define NWARPS_Q4_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
|
||||
#endif
|
||||
mul_mat_q4_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
|
||||
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
@ -491,11 +499,11 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -510,24 +518,25 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
#define NWARPS_Q5_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
|
||||
#endif
|
||||
mul_mat_q5_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
|
||||
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
@ -541,11 +550,11 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -560,24 +569,25 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
#define NWARPS_Q6_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
|
||||
#endif
|
||||
mul_mat_q6_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
|
||||
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
@ -590,11 +600,11 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) {
|
||||
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
@ -33,158 +33,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dst[row] = __float2half(tmp);
|
||||
dst[row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
@ -22,7 +22,7 @@ def get_gguf_sample_tensors(
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
DTYPES = [torch.half]
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||
# Hidden_size for testing, must match the sample file in HF repo,
|
||||
# we have `hidden_size = 256, 1024` for test in HF repo currently.
|
||||
HIDDEN_SIZES = [256, 1024]
|
||||
@ -52,7 +52,7 @@ QUANT_TYPES = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("dtype", [torch.half])
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_dequantize(hidden_size: int, dtype: torch.dtype,
|
||||
@ -122,7 +122,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
ref_output = x @ weight.T
|
||||
|
||||
qweight = torch.tensor(tensor.data, device="cuda")
|
||||
output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
|
||||
qweight.shape[0]).to(dtype)
|
||||
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
|
||||
atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
|
||||
# test matrix has inputs centered around 0 and lower precision from
|
||||
# bfloat16 tends to accumulate and can greatly inflate rtol
|
||||
# since outputs are also very close to 0
|
||||
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
|
||||
torch.testing.assert_close(output,
|
||||
ref_output,
|
||||
atol=atols[dtype],
|
||||
rtol=rtols[dtype])
|
||||
|
||||
@ -436,7 +436,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
|
||||
quant_type: int,
|
||||
row: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((1, row), dtype=torch.float16, device=W.device)
|
||||
return torch.empty((1, row), dtype=X.dtype, device=W.device)
|
||||
|
||||
@register_fake("_C::ggml_mul_mat_a8")
|
||||
def _ggml_mul_mat_a8_fake(
|
||||
@ -446,7 +446,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
|
||||
row: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
batch = X.size(0)
|
||||
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
|
||||
return torch.empty((batch, row), dtype=X.dtype, device=W.device)
|
||||
|
||||
|
||||
# cutlass
|
||||
|
||||
@ -32,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -134,6 +134,7 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
self.params_dtype = params_dtype
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
tensor_shape = (output_size_per_partition, input_size_per_partition)
|
||||
@ -326,7 +327,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
x_flat = x.flatten()
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0])
|
||||
x_flat.shape[0]).to(self.params_dtype)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user