mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:37:23 +08:00
[Core] Support loading GGUF model (#5191)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent
ef527be06c
commit
360bd67cf0
5
.github/workflows/clang-format.yml
vendored
5
.github/workflows/clang-format.yml
vendored
@ -30,6 +30,11 @@ jobs:
|
||||
run: |
|
||||
EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
||||
|
||||
@ -208,6 +208,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
|
||||
@ -107,6 +107,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m,
|
||||
int64_t n);
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
|
||||
531
csrc/quantization/gguf/dequantize.cuh
Normal file
531
csrc/quantization/gguf/dequantize.cuh
Normal file
@ -0,0 +1,531 @@
|
||||
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/convert.cu
|
||||
// Dequant functions
|
||||
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
|
||||
const int vui = x[ib].qs[iqs];
|
||||
|
||||
v.x = __int2half_rn(vui & 0xF);
|
||||
v.y = __int2half_rn(vui >> 4);
|
||||
|
||||
v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f));
|
||||
v = __hmul2(v, {d, d});
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||
|
||||
const dfloat d = __low2half(x[ib].dm);
|
||||
const dfloat m = __high2half(x[ib].dm);
|
||||
|
||||
const int vui = x[ib].qs[iqs];
|
||||
|
||||
v.x = __int2half_rn(vui & 0xF);
|
||||
v.y = __int2half_rn(vui >> 4);
|
||||
|
||||
v = __hmul2(v, {d, d});
|
||||
v = __hadd2(v, {m, m});
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q5_0 * x = (const block_q5_0 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||
|
||||
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||
|
||||
v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||
v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
|
||||
|
||||
v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f));
|
||||
v = __hmul2(v, {d, d});
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||
|
||||
const dfloat d = __low2half(x[ib].dm);
|
||||
const dfloat m = __high2half(x[ib].dm);
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||
|
||||
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||
|
||||
v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||
v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
|
||||
|
||||
v = __hmul2(v, {d, d});
|
||||
v = __hadd2(v, {m, m});
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
|
||||
v.x = __int2half_rn(x[ib].qs[iqs + 0]);
|
||||
v.y = __int2half_rn(x[ib].qs[iqs + 1]);
|
||||
|
||||
v = __hmul2(v, {d, d});
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
|
||||
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ib = i/qk; // block index
|
||||
const int iqs = (i%qk)/qr; // quant index
|
||||
const int iybs = i - i%qk; // y block start index
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
// dequantize
|
||||
dfloat2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
y[iybs + iqs + 0] = v.x;
|
||||
y[iybs + iqs + y_offset] = v.y;
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int n = tid/32;
|
||||
const int l = tid - 32*n;
|
||||
const int is = 8*n + l/16;
|
||||
|
||||
const uint8_t q = x[i].qs[32*n + l];
|
||||
dst_t * y = yy + i*QK_K + 128*n;
|
||||
|
||||
half dall = __low2half(x[i].dm);
|
||||
half dmin = __high2half(x[i].dm);
|
||||
y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)));
|
||||
y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)));
|
||||
y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)));
|
||||
y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q3_K * x = (const block_q3_K *) vx;
|
||||
|
||||
const int r = threadIdx.x/4;
|
||||
const int tid = r/2;
|
||||
const int is0 = r%2;
|
||||
const int l0 = 16*is0 + 4*(threadIdx.x%4);
|
||||
const int n = tid / 4;
|
||||
const int j = tid - 4*n;
|
||||
|
||||
uint8_t m = 1 << (4*n + j);
|
||||
int is = 8*n + 2*j + is0;
|
||||
int shift = 2*j;
|
||||
|
||||
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
|
||||
is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
|
||||
is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
|
||||
(x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
|
||||
half d_all = x[i].d;
|
||||
half dl = __hmul(d_all, __int2half_rn(us - 32));
|
||||
|
||||
dst_t * y = yy + i*QK_K + 128*n + 32*j;
|
||||
const uint8_t * q = x[i].qs + 32*n;
|
||||
const uint8_t * hm = x[i].hmask;
|
||||
|
||||
for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
|
||||
}
|
||||
|
||||
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
|
||||
if (j < 4) {
|
||||
d = q[j] & 63; m = q[j + 4] & 63;
|
||||
} else {
|
||||
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
||||
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q4_K * x = (const block_q4_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8;
|
||||
const int ir = tid%8;
|
||||
const int is = 2*il;
|
||||
const int n = 4;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
|
||||
const half dall = __low2half(x[i].dm);
|
||||
const half dmin = __high2half(x[i].dm);
|
||||
|
||||
const uint8_t * q = x[i].qs + 32*il + n*ir;
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, x[i].scales, sc, m);
|
||||
const half d1 = __hmul(dall, __int2half_rn(sc));
|
||||
const half m1 = __hmul(dmin, __int2half_rn(m));
|
||||
get_scale_min_k4(is + 1, x[i].scales, sc, m);
|
||||
const half d2 = __hmul(dall, __int2half_rn(sc));
|
||||
const half m2 = __hmul(dmin, __int2half_rn(m));
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
|
||||
y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q5_K * x = (const block_q5_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
|
||||
// assume 64 threads - this is very slightly better than the one below
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/16; // il is in 0...3
|
||||
const int ir = tid%16; // ir is in 0...15
|
||||
const int is = 2*il; // is is in 0...6
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
||||
|
||||
const half dall = __low2half(x[i].dm);
|
||||
const half dmin = __high2half(x[i].dm);
|
||||
|
||||
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
|
||||
const uint8_t * qh = x[i].qh + 2*ir;
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, x[i].scales, sc, m);
|
||||
const half d1 = __hmul(dall, __int2half_rn(sc)); const half m1 = __hmul(dmin, __int2half_rn(m));
|
||||
get_scale_min_k4(is + 1, x[i].scales, sc, m);
|
||||
const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));
|
||||
|
||||
uint8_t hm = 1 << (2*il);
|
||||
y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1);
|
||||
y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1);
|
||||
hm <<= 1;
|
||||
y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2);
|
||||
y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
|
||||
// assume 64 threads - this is very slightly better than the one below
|
||||
const int tid = threadIdx.x;
|
||||
const int ip = tid/32; // ip is 0 or 1
|
||||
const int il = tid - 32*ip; // 0...32
|
||||
const int is = 8*ip + il/16;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 128*ip + il;
|
||||
|
||||
const half d = x[i].d;
|
||||
|
||||
const uint8_t * ql = x[i].ql + 64*ip + il;
|
||||
const uint8_t qh = x[i].qh[32*ip + il];
|
||||
const int8_t * sc = x[i].scales + is;
|
||||
|
||||
y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
|
||||
y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
|
||||
y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
|
||||
y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const uint16_t * q2 = x[i].qs + 4*ib;
|
||||
const uint8_t * aux8 = (const uint8_t *)q2;
|
||||
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
|
||||
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
||||
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
|
||||
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
|
||||
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const uint16_t * q2 = x[i].qs + 4*ib;
|
||||
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
|
||||
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
||||
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
|
||||
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
|
||||
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq2_s * x = (const block_iq2_s *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
|
||||
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
||||
const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
|
||||
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const uint8_t * q3 = x[i].qs + 8*ib;
|
||||
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
|
||||
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
||||
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
|
||||
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
|
||||
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq3_s * x = (const block_iq3_s *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const uint8_t * qs = x[i].qs + 8*ib;
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
|
||||
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
|
||||
const uint8_t signs = x[i].signs[4*ib + il];
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
|
||||
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq1_s * x = (const block_iq1_s *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const int i8 = 4*ib+il;
|
||||
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
|
||||
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
|
||||
const float d = __half2float(x[i].d) * (2*(h & 7) + 1);
|
||||
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||
const uint8_t * q4 = x[ib].qs + 4*il;
|
||||
const float d = __half2float(x[ib].d);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
|
||||
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const int i = blockIdx.x;
|
||||
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
|
||||
const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
|
||||
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
|
||||
}
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
|
||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = (k + QK_K - 1) / QK_K;
|
||||
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = (k + QK_K - 1) / QK_K;
|
||||
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) {
|
||||
switch (type) {
|
||||
case 2:
|
||||
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
||||
case 3:
|
||||
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
|
||||
case 6:
|
||||
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
||||
case 7:
|
||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||
case 8:
|
||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||
case 10:
|
||||
return dequantize_row_q2_K_cuda;
|
||||
case 11:
|
||||
return dequantize_row_q3_K_cuda;
|
||||
case 12:
|
||||
return dequantize_row_q4_K_cuda;
|
||||
case 13:
|
||||
return dequantize_row_q5_K_cuda;
|
||||
case 14:
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case 16:
|
||||
return dequantize_row_iq2_xxs_cuda;
|
||||
case 17:
|
||||
return dequantize_row_iq2_xs_cuda;
|
||||
case 18:
|
||||
return dequantize_row_iq3_xxs_cuda;
|
||||
case 19:
|
||||
return dequantize_row_iq1_s_cuda;
|
||||
case 20:
|
||||
return dequantize_row_iq4_nl_cuda;
|
||||
case 21:
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case 22:
|
||||
return dequantize_row_iq2_s_cuda;
|
||||
case 23:
|
||||
return dequantize_row_iq4_xs_cuda;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
969
csrc/quantization/gguf/ggml-common.h
Normal file
969
csrc/quantization/gguf/ggml-common.h
Normal file
@ -0,0 +1,969 @@
|
||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
#define WARP_SIZE 32
|
||||
#define K_SCALE_SIZE 12
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
#define GGML_CUDA_DMMV_X 32
|
||||
#define GGML_CUDA_MMV_Y 1
|
||||
|
||||
|
||||
// Data Structures
|
||||
// QK = number of values after dequantization
|
||||
// QR = QK / number of values before dequantization
|
||||
// QI = number of 32 bit integers before dequantization
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QI4_0 (QK4_0 / (4 * QR4_0))
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||
} block_q4_0;
|
||||
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QI4_1 (QK4_1 / (4 * QR4_1))
|
||||
typedef struct {
|
||||
half2 dm; // dm.x = delta, dm.y = min
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
} block_q4_1;
|
||||
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
#define QI5_0 (QK5_0 / (4 * QR5_0))
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
uint8_t qh[4]; // 5-th bit of quants
|
||||
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
||||
} block_q5_0;
|
||||
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QI5_1 (QK5_1 / (4 * QR5_1))
|
||||
typedef struct {
|
||||
half2 dm; // dm.x = delta, dm.y = min
|
||||
uint8_t qh[4]; // 5-th bit of quants
|
||||
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
||||
} block_q5_1;
|
||||
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QI8_0 (QK8_0 / (4 * QR8_0))
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
int8_t qs[QK8_0]; // quants
|
||||
} block_q8_0;
|
||||
|
||||
#define QK8_1 32
|
||||
#define QR8_1 1
|
||||
#define QI8_1 (QK8_1 / (4 * QR8_1))
|
||||
typedef struct {
|
||||
half2 ds; // ds.x = delta, ds.y = sum
|
||||
int8_t qs[QK8_0]; // quants
|
||||
} block_q8_1;
|
||||
|
||||
#define QR2_K 4
|
||||
#define QI2_K (QK_K / (4*QR2_K))
|
||||
typedef struct {
|
||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||
uint8_t qs[QK_K/4]; // quants
|
||||
half2 dm; // super-block scale for quantized scales/mins
|
||||
} block_q2_K;
|
||||
|
||||
#define QR3_K 4
|
||||
#define QI3_K (QK_K / (4*QR3_K))
|
||||
typedef struct {
|
||||
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
||||
half d; // super-block scale
|
||||
} block_q3_K;
|
||||
|
||||
#define QR4_K 2
|
||||
#define QI4_K (QK_K / (4*QR4_K))
|
||||
typedef struct {
|
||||
half2 dm; // super-block scale for quantized scales/mins
|
||||
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_K;
|
||||
|
||||
#define QR5_K 2
|
||||
#define QI5_K (QK_K / (4*QR5_K))
|
||||
typedef struct {
|
||||
half2 dm; // super-block scale for quantized scales/mins
|
||||
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qh[QK_K/8]; // quants, high bit
|
||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_K;
|
||||
|
||||
#define QR6_K 2
|
||||
#define QI6_K (QK_K / (4*QR6_K))
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||
int8_t scales[QK_K/16]; // scales
|
||||
half d; // delta
|
||||
} block_q6_K;
|
||||
|
||||
#define QR2_XXS 8
|
||||
#define QI2_XXS (QK_K / (4*QR2_XXS))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint16_t qs[QK_K/8];
|
||||
} block_iq2_xxs;
|
||||
|
||||
#define QR2_XS 8
|
||||
#define QI2_XS (QK_K / (4*QR2_XS))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint16_t qs[QK_K/8];
|
||||
uint8_t scales[QK_K/32];
|
||||
} block_iq2_xs;
|
||||
|
||||
#define QR2_S 8
|
||||
#define QI2_S (QK_K / (4*QR2_S))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint8_t qs[QK_K/4];
|
||||
uint8_t qh[QK_K/32];
|
||||
uint8_t scales[QK_K/32];
|
||||
} block_iq2_s;
|
||||
|
||||
#define QR3_XXS 8
|
||||
#define QI3_XXS (QK_K / (4*QR3_XXS))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint8_t qs[3*(QK_K/8)];
|
||||
} block_iq3_xxs;
|
||||
|
||||
#define QR3_XS 8
|
||||
#define QI3_XS (QK_K / (4*QR3_XS))
|
||||
#define IQ3S_N_SCALE QK_K/64
|
||||
typedef struct {
|
||||
half d;
|
||||
uint8_t qs[QK_K/4];
|
||||
uint8_t qh[QK_K/32];
|
||||
uint8_t signs[QK_K/8];
|
||||
uint8_t scales[IQ3S_N_SCALE];
|
||||
} block_iq3_s;
|
||||
|
||||
#define QR1_S 8
|
||||
#define QI1_S (QK_K / (4*QR1_S))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint8_t qs[QK_K/8];
|
||||
uint8_t scales[QK_K/16];
|
||||
} block_iq1_s;
|
||||
|
||||
#define QK4_NL 32
|
||||
#define QR4_NL 2
|
||||
#define QI4_NL (QK4_NL / (4*QR4_NL))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint8_t qs[QK4_NL/2];
|
||||
} block_iq4_nl;
|
||||
|
||||
#define QR4_XS 8
|
||||
#define QI4_XS (QK_K / (4*QR4_XS))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint16_t scales_h;
|
||||
uint8_t scales_l[QK_K/64];
|
||||
uint8_t qs[QK_K/2];
|
||||
} block_iq4_xs;
|
||||
|
||||
static const __device__ uint64_t iq2xxs_grid[256] = {
|
||||
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
||||
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
||||
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
||||
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
|
||||
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
|
||||
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
|
||||
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
|
||||
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
|
||||
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
|
||||
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
|
||||
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
|
||||
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
|
||||
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
|
||||
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
|
||||
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
|
||||
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
|
||||
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
|
||||
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
|
||||
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
|
||||
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
|
||||
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
|
||||
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
|
||||
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
|
||||
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
|
||||
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
|
||||
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
|
||||
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
|
||||
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
|
||||
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
|
||||
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
|
||||
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
|
||||
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
|
||||
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
|
||||
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
|
||||
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
|
||||
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
|
||||
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
|
||||
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
|
||||
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
|
||||
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
|
||||
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
|
||||
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
|
||||
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
|
||||
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
|
||||
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
|
||||
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
|
||||
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
|
||||
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
|
||||
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
|
||||
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
|
||||
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
|
||||
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
|
||||
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
|
||||
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
|
||||
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
|
||||
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
|
||||
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
|
||||
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
|
||||
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
|
||||
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
|
||||
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
|
||||
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
|
||||
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
|
||||
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
||||
};
|
||||
|
||||
static const __device__ uint64_t iq2xs_grid[512] = {
|
||||
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
||||
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
||||
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
||||
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
||||
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
||||
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
||||
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
||||
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
||||
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
||||
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
||||
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
||||
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
||||
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
||||
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
||||
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
||||
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
||||
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
||||
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
||||
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
||||
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
||||
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
||||
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
||||
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
||||
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
||||
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
||||
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
||||
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
||||
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
||||
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
||||
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
||||
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
||||
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
||||
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
||||
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
||||
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
||||
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
||||
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
||||
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
||||
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
||||
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
||||
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
||||
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
||||
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
||||
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
||||
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
||||
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
||||
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
||||
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
||||
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
||||
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
||||
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
||||
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
||||
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
||||
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
||||
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
||||
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
||||
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
||||
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
||||
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
||||
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
||||
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
||||
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
||||
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
||||
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
||||
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
||||
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
||||
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
||||
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
||||
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
||||
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
||||
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
||||
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
||||
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
||||
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
||||
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
||||
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
||||
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
||||
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
||||
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
||||
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
||||
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
||||
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
||||
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
||||
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
||||
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
||||
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
||||
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
||||
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
||||
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
||||
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
||||
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
||||
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
||||
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
||||
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
||||
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
||||
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
||||
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
||||
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
||||
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
||||
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
||||
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
||||
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
||||
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
||||
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
||||
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
||||
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
||||
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
||||
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
||||
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
||||
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
||||
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
||||
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
||||
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
||||
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
||||
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
||||
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
||||
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
||||
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
||||
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
||||
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
||||
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
||||
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
||||
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
||||
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
||||
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
||||
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
||||
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
||||
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
||||
};
|
||||
|
||||
static const __device__ uint64_t iq2s_grid[1024] = {
|
||||
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
||||
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
||||
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
||||
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
||||
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
||||
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
|
||||
0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
|
||||
0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
|
||||
0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
|
||||
0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
|
||||
0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
|
||||
0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
|
||||
0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
|
||||
0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
|
||||
0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
|
||||
0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
|
||||
0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
|
||||
0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
|
||||
0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
|
||||
0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
|
||||
0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
|
||||
0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
|
||||
0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
|
||||
0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
|
||||
0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
|
||||
0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
|
||||
0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
|
||||
0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
|
||||
0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
|
||||
0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
|
||||
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
|
||||
0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
|
||||
0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
|
||||
0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
|
||||
0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
|
||||
0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
|
||||
0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
|
||||
0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
|
||||
0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
|
||||
0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
|
||||
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
|
||||
0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
|
||||
0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
|
||||
0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
|
||||
0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
|
||||
0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
|
||||
0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
|
||||
0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
|
||||
0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
|
||||
0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
|
||||
0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
|
||||
0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
|
||||
0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
|
||||
0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
|
||||
0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
|
||||
0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
|
||||
0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
|
||||
0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
|
||||
0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
|
||||
0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
|
||||
0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
|
||||
0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
|
||||
0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
|
||||
0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
|
||||
0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
|
||||
0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
|
||||
0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
|
||||
0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
|
||||
0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
|
||||
0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
|
||||
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
|
||||
0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
|
||||
0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
|
||||
0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
|
||||
0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
|
||||
0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
|
||||
0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
|
||||
0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
|
||||
0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
|
||||
0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
|
||||
0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
|
||||
0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
|
||||
0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
|
||||
0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
|
||||
0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
|
||||
0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
|
||||
0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
|
||||
0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
|
||||
0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
|
||||
0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
|
||||
0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
|
||||
0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
|
||||
0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
|
||||
0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
|
||||
0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
|
||||
0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
|
||||
0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
|
||||
0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
|
||||
0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
|
||||
0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
|
||||
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
|
||||
0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
|
||||
0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
|
||||
0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
|
||||
0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
|
||||
0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
|
||||
0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
|
||||
0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
|
||||
0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
|
||||
0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
|
||||
0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
|
||||
0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
|
||||
0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
|
||||
0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
|
||||
0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
|
||||
0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
|
||||
0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
|
||||
0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
|
||||
0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
|
||||
0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
|
||||
0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
|
||||
0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
|
||||
0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
|
||||
0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
|
||||
0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
|
||||
0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
|
||||
0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
|
||||
0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
|
||||
0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
|
||||
0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
|
||||
0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
|
||||
0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
|
||||
0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
|
||||
0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
|
||||
0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
|
||||
0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
|
||||
0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
|
||||
0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
|
||||
0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
|
||||
0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
|
||||
0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
|
||||
0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
|
||||
0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
|
||||
0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
|
||||
0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
|
||||
0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
|
||||
0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
|
||||
0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
|
||||
0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
|
||||
0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
|
||||
0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
|
||||
0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
|
||||
0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
|
||||
0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
|
||||
0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
|
||||
0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
|
||||
0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
|
||||
0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
|
||||
0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
|
||||
0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
|
||||
0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
|
||||
0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
|
||||
0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
|
||||
0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
|
||||
0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
|
||||
0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
|
||||
0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
|
||||
0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
|
||||
0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
|
||||
0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
|
||||
0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
|
||||
0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
|
||||
0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
|
||||
0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
|
||||
0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
|
||||
0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
|
||||
0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
|
||||
0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
|
||||
0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
|
||||
0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
|
||||
0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
|
||||
0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
|
||||
0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
|
||||
0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
|
||||
0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
|
||||
0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
|
||||
0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
|
||||
0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
|
||||
0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
|
||||
0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
|
||||
0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
|
||||
0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
|
||||
0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
|
||||
0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
|
||||
0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
|
||||
0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
|
||||
0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
|
||||
0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
|
||||
0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
|
||||
0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
|
||||
0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
|
||||
0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
|
||||
0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
|
||||
0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
|
||||
0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
|
||||
0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
|
||||
0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
|
||||
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
|
||||
0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
|
||||
0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
|
||||
0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
|
||||
0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
|
||||
0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
|
||||
0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
|
||||
0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
|
||||
0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
|
||||
0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
|
||||
0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
|
||||
0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
|
||||
0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
|
||||
0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
|
||||
0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
|
||||
0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
|
||||
0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
|
||||
0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
|
||||
0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
|
||||
0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
|
||||
0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
|
||||
0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
|
||||
0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
|
||||
0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
|
||||
0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
|
||||
0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
|
||||
0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
|
||||
0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
|
||||
0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
|
||||
0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
|
||||
0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
|
||||
0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
|
||||
0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
|
||||
0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
|
||||
0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
|
||||
0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
|
||||
0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
|
||||
0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
|
||||
0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
|
||||
0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
|
||||
0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
|
||||
0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
|
||||
0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
|
||||
0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
|
||||
0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
|
||||
0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
|
||||
0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
|
||||
0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
|
||||
0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
|
||||
};
|
||||
|
||||
static const __device__ uint32_t iq3xxs_grid[256] = {
|
||||
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
||||
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
||||
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
|
||||
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
|
||||
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
||||
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
|
||||
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
|
||||
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
||||
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
||||
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
||||
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
||||
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
|
||||
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
||||
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
||||
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
|
||||
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
||||
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
|
||||
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
||||
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
||||
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
|
||||
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
||||
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
|
||||
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
|
||||
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
||||
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
|
||||
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
|
||||
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
|
||||
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
||||
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
||||
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
|
||||
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
||||
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
||||
};
|
||||
|
||||
static const __device__ uint32_t iq3xs_grid[512] = {
|
||||
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
|
||||
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
|
||||
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
|
||||
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
|
||||
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
|
||||
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
|
||||
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
|
||||
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
|
||||
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
|
||||
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
|
||||
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
|
||||
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
|
||||
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
|
||||
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
|
||||
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
|
||||
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
|
||||
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
|
||||
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
|
||||
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
|
||||
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
|
||||
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
|
||||
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
|
||||
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
|
||||
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
|
||||
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
|
||||
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
|
||||
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
|
||||
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
|
||||
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
|
||||
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
|
||||
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
|
||||
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
|
||||
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
|
||||
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
|
||||
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
|
||||
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
|
||||
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
|
||||
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
|
||||
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
|
||||
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
|
||||
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
|
||||
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
|
||||
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
|
||||
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
|
||||
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
|
||||
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
|
||||
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
|
||||
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
|
||||
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
|
||||
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
|
||||
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
|
||||
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
|
||||
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
|
||||
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
|
||||
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
|
||||
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
|
||||
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
|
||||
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
|
||||
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
|
||||
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
|
||||
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
|
||||
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
|
||||
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
|
||||
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
|
||||
};
|
||||
|
||||
static const __device__ uint64_t iq1s_grid[512] = {
|
||||
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
||||
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
||||
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
||||
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
||||
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
||||
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
||||
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
||||
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
||||
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
||||
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
||||
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
||||
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
||||
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
||||
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
||||
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
||||
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
||||
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
||||
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
||||
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
||||
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
||||
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
||||
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
||||
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
||||
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
||||
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
||||
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
||||
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
||||
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
||||
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
||||
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
||||
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
||||
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
||||
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
||||
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
||||
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
||||
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
||||
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
||||
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
||||
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
||||
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
||||
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
||||
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
||||
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
||||
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
||||
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
||||
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
||||
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
||||
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
||||
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
||||
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
||||
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
||||
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
||||
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
||||
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
||||
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
||||
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
||||
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
||||
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
||||
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
||||
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
||||
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
||||
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
||||
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
||||
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
||||
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
||||
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
||||
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
||||
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
||||
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
||||
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
||||
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
||||
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
||||
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
||||
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
||||
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
||||
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
||||
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
||||
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
||||
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
||||
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
||||
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
||||
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
||||
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
||||
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
||||
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
||||
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
||||
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
||||
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
||||
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
||||
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
||||
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
||||
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
||||
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
||||
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
||||
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
||||
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
||||
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
||||
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
||||
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
||||
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
||||
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
||||
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
||||
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
||||
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
||||
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
||||
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
||||
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
||||
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
||||
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
||||
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
||||
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
||||
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
||||
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
||||
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
||||
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
||||
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
||||
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
||||
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
||||
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
||||
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
||||
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
||||
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
||||
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
||||
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
||||
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
||||
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
||||
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
||||
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
||||
};
|
||||
|
||||
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
||||
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
||||
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
||||
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
|
||||
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
|
||||
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
|
||||
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
|
||||
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
||||
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
||||
};
|
||||
|
||||
static const __device__ uint64_t ksigns64[128] = {
|
||||
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
|
||||
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
|
||||
0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
|
||||
0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
|
||||
0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
|
||||
0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
|
||||
0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
|
||||
0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
|
||||
0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
|
||||
0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
|
||||
0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
|
||||
0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
|
||||
0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
|
||||
0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
|
||||
0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
|
||||
0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
|
||||
0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
|
||||
0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
|
||||
0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
|
||||
0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
|
||||
0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
|
||||
0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
|
||||
0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
|
||||
0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
|
||||
0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
|
||||
0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
|
||||
0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
|
||||
0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
|
||||
0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
|
||||
0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
|
||||
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
|
||||
0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
|
||||
};
|
||||
|
||||
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
||||
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
|
||||
|
||||
typedef half dfloat; // dequantize float
|
||||
typedef half2 dfloat2;
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
|
||||
typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream);
|
||||
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
|
||||
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
|
||||
typedef void (*load_tiles_cuda_t)(
|
||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
|
||||
typedef float (*vec_dot_q_mul_mat_cuda_t)(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
|
||||
|
||||
// Utility function
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
||||
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
||||
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||
#if __has_builtin(__builtin_elementwise_sub_sat)
|
||||
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
||||
return reinterpret_cast<const int &>(c);
|
||||
#else
|
||||
int8x4_t c;
|
||||
int16_t tmp;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tmp = va[i] - vb[i];
|
||||
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
|
||||
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
||||
c[i] = tmp;
|
||||
}
|
||||
return reinterpret_cast<int &>(c);
|
||||
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
||||
#if __has_builtin(__builtin_amdgcn_sdot4)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
#else
|
||||
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
#endif // defined(USE_ROCM)
|
||||
242
csrc/quantization/gguf/gguf_kernel.cu
Normal file
242
csrc/quantization/gguf/gguf_kernel.cu
Normal file
@ -0,0 +1,242 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "vecdotq.cuh"
|
||||
#include "dequantize.cuh"
|
||||
#include "mmvq.cuh"
|
||||
#include "mmq.cuh"
|
||||
|
||||
// Q8 gemv
|
||||
static __global__ void quantize_q8_1(const half* __restrict__ x,
|
||||
void* __restrict__ vy, const int kx,
|
||||
const int kx_padded) {
|
||||
const int ix = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (ix >= kx_padded) {
|
||||
return;
|
||||
}
|
||||
const int iy = blockDim.y * blockIdx.y + threadIdx.y;
|
||||
const int i_padded = iy * kx_padded + ix;
|
||||
|
||||
block_q8_1* y = (block_q8_1*)vy;
|
||||
|
||||
const int ib = i_padded / QK8_1; // block index
|
||||
const int iqs = i_padded % QK8_1; // quant index
|
||||
|
||||
const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
|
||||
float amax = fabsf(xi);
|
||||
float sum = xi;
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
|
||||
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
|
||||
}
|
||||
|
||||
const float d = amax / 127;
|
||||
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
||||
|
||||
y[ib].qs[iqs] = q;
|
||||
|
||||
if (iqs > 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
y[ib].ds.x = __float2half(d);
|
||||
y[ib].ds.y = __float2half(sum);
|
||||
}
|
||||
|
||||
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
|
||||
const int ky, cudaStream_t stream) {
|
||||
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
|
||||
const int block_num_x =
|
||||
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
const dim3 num_blocks(block_num_x, ky, 1);
|
||||
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
|
||||
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
|
||||
}
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
int8_t type, int64_t m, int64_t n) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
at::Tensor DW = torch::empty({m, n}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
|
||||
to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);
|
||||
return DW;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int8_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
at::Tensor Y = torch::empty({1, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
|
||||
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
|
||||
stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
}
|
||||
return Y;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int8_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
int batch = X.sizes()[0];
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
at::Tensor Y = torch::empty({batch, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
|
||||
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
|
||||
batch, stream);
|
||||
|
||||
switch (type) {
|
||||
case 2:
|
||||
ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
}
|
||||
return Y;
|
||||
}
|
||||
600
csrc/quantization/gguf/mmq.cuh
Normal file
600
csrc/quantization/gguf/mmq.cuh
Normal file
@ -0,0 +1,600 @@
|
||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
|
||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||
static __device__ __forceinline__ void mul_mat_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
const block_q_t * x = (const block_q_t *) vx;
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
const int blocks_per_row_x = ncols_x / qk;
|
||||
const int blocks_per_col_y = nrows_y / QK8_1;
|
||||
const int blocks_per_warp = WARP_SIZE / qi;
|
||||
|
||||
const int & ncols_dst = ncols_y;
|
||||
|
||||
const int row_dst_0 = blockIdx.x*mmq_y;
|
||||
const int & row_x_0 = row_dst_0;
|
||||
|
||||
const int col_dst_0 = blockIdx.y*mmq_x;
|
||||
const int & col_y_0 = col_dst_0;
|
||||
|
||||
int * tile_x_ql = nullptr;
|
||||
half2 * tile_x_dm = nullptr;
|
||||
int * tile_x_qh = nullptr;
|
||||
int * tile_x_sc = nullptr;
|
||||
|
||||
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
|
||||
|
||||
__shared__ int tile_y_qs[mmq_x * WARP_SIZE];
|
||||
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
|
||||
|
||||
float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
|
||||
|
||||
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
|
||||
|
||||
load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
|
||||
threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
|
||||
|
||||
#pragma unroll
|
||||
for (int ir = 0; ir < qr; ++ir) {
|
||||
const int kqs = ir*WARP_SIZE + threadIdx.x;
|
||||
const int kbxd = kqs / QI8_1;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < mmq_x; i += nwarps) {
|
||||
const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
|
||||
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
|
||||
const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
|
||||
tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
|
||||
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
|
||||
const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
|
||||
const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
|
||||
|
||||
// if the sum is not needed it's faster to transform the scale to f32 ahead of time
|
||||
const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
|
||||
half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
|
||||
if (need_sum) {
|
||||
*dsi_dst = *dsi_src;
|
||||
} else {
|
||||
float * dfi_dst = (float *) dsi_dst;
|
||||
*dfi_dst = __low2float(*dsi_src);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// #pragma unroll // unrolling this loop causes too much register pressure
|
||||
for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < mmq_x; j += nwarps) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < mmq_y; i += WARP_SIZE) {
|
||||
sum[i/WARP_SIZE][j/nwarps] += vec_dot(
|
||||
tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
|
||||
threadIdx.x + i, threadIdx.y + j, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < mmq_x; j += nwarps) {
|
||||
const int col_dst = col_dst_0 + j + threadIdx.y;
|
||||
if (col_dst >= ncols_dst) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < mmq_y; i += WARP_SIZE) {
|
||||
const int row_dst = row_dst_0 + threadIdx.x + i;
|
||||
if (row_dst >= nrows_dst) {
|
||||
continue;
|
||||
}
|
||||
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_0 64
|
||||
#define MMQ_Y_Q4_0 128
|
||||
#define NWARPS_Q4_0 8
|
||||
#else
|
||||
#define MMQ_X_Q4_0 4
|
||||
#define MMQ_Y_Q4_0 32
|
||||
#define NWARPS_Q4_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2)
|
||||
#endif
|
||||
mul_mat_q4_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_0;
|
||||
const int mmq_y = MMQ_Y_Q4_0;
|
||||
const int nwarps = NWARPS_Q4_0;
|
||||
|
||||
mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
|
||||
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
int mmq_x = MMQ_X_Q4_0;
|
||||
int mmq_y = MMQ_Y_Q4_0;
|
||||
int nwarps = NWARPS_Q4_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_1 64
|
||||
#define MMQ_Y_Q4_1 128
|
||||
#define NWARPS_Q4_1 8
|
||||
#else
|
||||
#define MMQ_X_Q4_1 4
|
||||
#define MMQ_Y_Q4_1 32
|
||||
#define NWARPS_Q4_1 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2)
|
||||
#endif
|
||||
mul_mat_q4_1(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_1;
|
||||
const int mmq_y = MMQ_Y_Q4_1;
|
||||
const int nwarps = NWARPS_Q4_1;
|
||||
|
||||
mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
|
||||
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
int mmq_x = MMQ_X_Q4_1;
|
||||
int mmq_y = MMQ_Y_Q4_1;
|
||||
int nwarps = NWARPS_Q4_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_0 64
|
||||
#define MMQ_Y_Q5_0 128
|
||||
#define NWARPS_Q5_0 8
|
||||
#else
|
||||
#define MMQ_X_Q5_0 4
|
||||
#define MMQ_Y_Q5_0 32
|
||||
#define NWARPS_Q5_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2)
|
||||
#endif
|
||||
mul_mat_q5_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
|
||||
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_1 64
|
||||
#define MMQ_Y_Q5_1 128
|
||||
#define NWARPS_Q5_1 8
|
||||
#else
|
||||
#define MMQ_X_Q5_1 4
|
||||
#define MMQ_Y_Q5_1 32
|
||||
#define NWARPS_Q5_1 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2)
|
||||
#endif
|
||||
mul_mat_q5_1(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
|
||||
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q8_0 64
|
||||
#define MMQ_Y_Q8_0 128
|
||||
#define NWARPS_Q8_0 8
|
||||
#else
|
||||
#define MMQ_X_Q8_0 4
|
||||
#define MMQ_Y_Q8_0 32
|
||||
#define NWARPS_Q8_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2)
|
||||
#endif
|
||||
mul_mat_q8_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
|
||||
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q2_K 64
|
||||
#define MMQ_Y_Q2_K 128
|
||||
#define NWARPS_Q2_K 8
|
||||
#else
|
||||
#define MMQ_X_Q2_K 4
|
||||
#define MMQ_Y_Q2_K 32
|
||||
#define NWARPS_Q2_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2)
|
||||
#endif
|
||||
mul_mat_q2_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
|
||||
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q3_K 64
|
||||
#define MMQ_Y_Q3_K 128
|
||||
#define NWARPS_Q3_K 8
|
||||
#else
|
||||
#define MMQ_X_Q3_K 4
|
||||
#define MMQ_Y_Q3_K 32
|
||||
#define NWARPS_Q3_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2)
|
||||
#endif
|
||||
mul_mat_q3_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
|
||||
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_K 64
|
||||
#define MMQ_Y_Q4_K 128
|
||||
#define NWARPS_Q4_K 8
|
||||
#else
|
||||
#define MMQ_X_Q4_K 4
|
||||
#define MMQ_Y_Q4_K 32
|
||||
#define NWARPS_Q4_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2)
|
||||
#endif
|
||||
mul_mat_q4_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
|
||||
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_K 64
|
||||
#define MMQ_Y_Q5_K 128
|
||||
#define NWARPS_Q5_K 8
|
||||
#else
|
||||
#define MMQ_X_Q5_K 4
|
||||
#define MMQ_Y_Q5_K 32
|
||||
#define NWARPS_Q5_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2)
|
||||
#endif
|
||||
mul_mat_q5_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
|
||||
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q6_K 64
|
||||
#define MMQ_Y_Q6_K 128
|
||||
#define NWARPS_Q6_K 8
|
||||
#else
|
||||
#define MMQ_X_Q6_K 4
|
||||
#define MMQ_Y_Q6_K 32
|
||||
#define NWARPS_Q6_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2)
|
||||
#endif
|
||||
mul_mat_q6_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
|
||||
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
182
csrc/quantization/gguf/mmvq.cuh
Normal file
182
csrc/quantization/gguf/mmvq.cuh
Normal file
@ -0,0 +1,182 @@
|
||||
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
const block_q_t * x = (const block_q_t *) vx;
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
|
||||
const int ibx = row*blocks_per_row + i; // x block index
|
||||
|
||||
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
||||
|
||||
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
|
||||
|
||||
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dst[row] = __float2half(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
1745
csrc/quantization/gguf/vecdotq.cuh
Normal file
1745
csrc/quantization/gguf/vecdotq.cuh
Normal file
File diff suppressed because it is too large
Load Diff
@ -145,6 +145,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("awq_marlin_repack", &awq_marlin_repack);
|
||||
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||
|
||||
// Dequantization for GGML.
|
||||
ops.def("ggml_dequantize", &ggml_dequantize);
|
||||
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
|
||||
|
||||
// mmvq kernel for GGML.
|
||||
ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
|
||||
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
|
||||
|
||||
// mmq kernel for GGML.
|
||||
ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
|
||||
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
|
||||
|
||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
||||
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
|
||||
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
|
||||
|
||||
38
examples/gguf_inference.py
Normal file
38
examples/gguf_inference.py
Normal file
@ -0,0 +1,38 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def run_gguf_inference(model_path):
|
||||
PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n" # noqa: E501
|
||||
system_message = "You are a friendly chatbot who always responds in the style of a pirate." # noqa: E501
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"How many helicopters can a human eat in one sitting?",
|
||||
"What's the future of AI?",
|
||||
]
|
||||
prompts = [
|
||||
PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
|
||||
for prompt in prompts
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=128)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model=model_path,
|
||||
tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
gpu_memory_utilization=0.95)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
|
||||
filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
|
||||
model = hf_hub_download(repo_id, filename=filename)
|
||||
run_gguf_inference(model)
|
||||
@ -242,6 +242,11 @@ echo 'vLLM isort: Done'
|
||||
# NOTE: Keep up to date with .github/workflows/clang-format.yml
|
||||
CLANG_FORMAT_EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
|
||||
# Format specified files with clang-format
|
||||
|
||||
@ -22,3 +22,4 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
|
||||
typing_extensions
|
||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||
pyzmq
|
||||
gguf == 0.9.1
|
||||
|
||||
76
tests/models/test_gguf.py
Normal file
76
tests/models/test_gguf.py
Normal file
@ -0,0 +1,76 @@
|
||||
"""
|
||||
Tests gguf models against unquantized models generations
|
||||
Note: To pass the test, quantization higher than Q4 should be used
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
|
||||
# FIXME: Move this to confest
|
||||
MODELS = [
|
||||
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
|
||||
filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")),
|
||||
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
|
||||
filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
|
||||
("Qwen/Qwen2-1.5B-Instruct",
|
||||
hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
|
||||
filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
|
||||
("Qwen/Qwen2-1.5B-Instruct",
|
||||
hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
|
||||
filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
|
||||
reason="gguf is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
original_model, gguf_model = model
|
||||
|
||||
# Run unquantized model.
|
||||
with vllm_runner(model_name=original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1) as original_model:
|
||||
|
||||
original_outputs = original_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
|
||||
# Run gguf model.
|
||||
with vllm_runner(model_name=gguf_model,
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1) as gguf_model:
|
||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=original_outputs,
|
||||
outputs_1_lst=gguf_outputs,
|
||||
name_0="original",
|
||||
name_1="gguf",
|
||||
)
|
||||
@ -7,11 +7,12 @@ from typing import Tuple
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod)
|
||||
|
||||
PROMPT = "On the surface of Mars, we found"
|
||||
|
||||
@ -37,7 +38,8 @@ def test_lm_head(
|
||||
lm_head_layer.linear_method,
|
||||
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
|
||||
else:
|
||||
assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)
|
||||
assert isinstance(lm_head_layer.linear_method,
|
||||
UnquantizedEmbeddingMethod)
|
||||
|
||||
print(
|
||||
vllm_model.generate_greedy(prompts=["Hello my name is"],
|
||||
|
||||
@ -404,6 +404,38 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
workspace, size_m, size_n, size_k)
|
||||
|
||||
|
||||
# gguf
|
||||
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
|
||||
return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
|
||||
|
||||
|
||||
def ggml_mul_mat_vec(
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
quant_type: int,
|
||||
row: int,
|
||||
):
|
||||
return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
|
||||
|
||||
|
||||
def ggml_mul_mat_vec_a8(
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
quant_type: int,
|
||||
row: int,
|
||||
):
|
||||
return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
|
||||
|
||||
|
||||
def ggml_mul_mat_a8(
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
quant_type: int,
|
||||
row: int,
|
||||
):
|
||||
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
|
||||
|
||||
|
||||
# moe
|
||||
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
|
||||
block_size: int, sorted_token_ids: torch.Tensor,
|
||||
|
||||
@ -582,6 +582,7 @@ class LoadFormat(str, enum.Enum):
|
||||
DUMMY = "dummy"
|
||||
TENSORIZER = "tensorizer"
|
||||
SHARDED_STATE = "sharded_state"
|
||||
GGUF = "gguf"
|
||||
BITSANDBYTES = "bitsandbytes"
|
||||
|
||||
|
||||
|
||||
@ -672,6 +672,9 @@ class EngineArgs:
|
||||
return engine_args
|
||||
|
||||
def create_engine_config(self, ) -> EngineConfig:
|
||||
# gguf file needs a specific model loader and doesn't use hf_repo
|
||||
if self.model.endswith(".gguf"):
|
||||
self.quantization = self.load_format = "gguf"
|
||||
|
||||
# bitsandbytes quantization needs a specific model loader
|
||||
# so we make sure the quant method and the load format are consistent
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -311,6 +311,17 @@ class ColumnParallelLinear(LinearBase):
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
# Special case for GGUF
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
param.weight_type = loaded_weight.item()
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
if output_dim is not None:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
@ -398,6 +409,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None):
|
||||
|
||||
# Special case for GGUF
|
||||
# initialize GGUF param after we know the quantize type
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
param.data[loaded_shard_id].copy_(loaded_weight)
|
||||
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
|
||||
return
|
||||
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
from gguf.constants import GGML_QUANT_SIZES
|
||||
|
||||
ori_shape = param.tensor_shape
|
||||
weight_types = self.qweight_type.shard_weight_type.values()
|
||||
row_size = []
|
||||
for weight_type in weight_types:
|
||||
block_size, type_size = GGML_QUANT_SIZES[weight_type]
|
||||
row_size.append(ori_shape[1] // block_size * type_size)
|
||||
q_shape = (ori_shape[0], max(row_size))
|
||||
param.materialize(q_shape, dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
@ -460,6 +492,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = loaded_weight.shape[output_dim] * \
|
||||
loaded_shard_id
|
||||
|
||||
if is_gguf_weight:
|
||||
shard_size = loaded_weight.shape[output_dim]
|
||||
shard_offset = loaded_weight.shape[output_dim] * \
|
||||
loaded_shard_id
|
||||
param.shard_id.append(loaded_shard_id)
|
||||
param.shard_size[loaded_shard_id] = loaded_weight.shape
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
@ -563,6 +602,29 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
|
||||
# Special case for GGUF
|
||||
# initialize GGUF param after we know the quantize type
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type and loaded_shard_id is not None:
|
||||
idx_map = {"q": 0, "k": 1, "v": 2}
|
||||
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
|
||||
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
|
||||
return
|
||||
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
from gguf.constants import GGML_QUANT_SIZES
|
||||
|
||||
ori_shape = param.tensor_shape
|
||||
weight_types = self.qweight_type.shard_weight_type.values()
|
||||
row_size = []
|
||||
for weight_type in weight_types:
|
||||
block_size, type_size = GGML_QUANT_SIZES[weight_type]
|
||||
row_size.append(ori_shape[1] // block_size * type_size)
|
||||
q_shape = (ori_shape[0], max(row_size))
|
||||
param.materialize(q_shape, dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
@ -650,6 +712,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size, shard_offset = adjust_bitsandbytes_shard(
|
||||
param, orig_qkv_offsets, loaded_shard_id)
|
||||
|
||||
if is_gguf_weight:
|
||||
param.shard_id.append(loaded_shard_id)
|
||||
param.shard_size[loaded_shard_id] = loaded_weight.shape
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
input_size = loaded_weight.shape[input_dim]
|
||||
param_data = param_data.narrow(input_dim, 0, input_size)
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
if loaded_shard_id == "q":
|
||||
@ -755,6 +824,17 @@ class RowParallelLinear(LinearBase):
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
|
||||
# Special case for GGUF
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
param.weight_type = loaded_weight.item()
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
if input_dim is not None:
|
||||
shard_size = param_data.shape[input_dim]
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.deepspeedfp import (
|
||||
DeepSpeedFPConfig)
|
||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig)
|
||||
@ -31,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
# The order of gptq methods is important for config.py iteration over
|
||||
# override_quantization_method(..)
|
||||
"marlin": MarlinConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -23,6 +24,14 @@ class QuantizeMethodBase(ABC):
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
# Not required functions
|
||||
def embedding(self, layer: torch.nn.Module, *args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""Gather embeddings in the layer based on indices in the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
|
||||
@ -31,6 +40,21 @@ class QuantizeMethodBase(ABC):
|
||||
return
|
||||
|
||||
|
||||
def method_has_implemented_embedding(
|
||||
method_class: Type[QuantizeMethodBase]) -> bool:
|
||||
"""
|
||||
Not all quant methods have embedding implemented, so we need to check that
|
||||
it exists for our given method. We check this by making sure the function
|
||||
has been changed from the base implementation.
|
||||
"""
|
||||
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
|
||||
None)
|
||||
class_embedding = inspect.getattr_static(method_class, "embedding", None)
|
||||
|
||||
return (class_embedding is not None
|
||||
and class_embedding is not base_embedding)
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
|
||||
165
vllm/model_executor/layers/quantization/gguf.py
Normal file
165
vllm/model_executor/layers/quantization/gguf.py
Normal file
@ -0,0 +1,165 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("GGUFConfig()")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
raise ValueError(
|
||||
"GGUF quantization hasn't supported tensor parallelism yet.")
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GGUFLinearMethod(self)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
return GGUFEmbeddingMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
|
||||
qweight_type: int) -> torch.Tensor:
|
||||
# use dequantize mulmat for IQmatrix, mmq for k-quants
|
||||
if qweight_type >= 16:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
|
||||
y = x @ weight.T
|
||||
else:
|
||||
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
return y
|
||||
|
||||
|
||||
class GGUFLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GGUFConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
tensor_shape = (output_size_per_partition, input_size_per_partition)
|
||||
qweight = UninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"shard_size": {},
|
||||
"shard_id": [],
|
||||
})
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qweight", qweight)
|
||||
|
||||
qweight_type = Parameter(torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"shard_weight_type": {},
|
||||
"ignore_warning": True
|
||||
})
|
||||
set_weight_attrs(qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("qweight_type", qweight_type)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
shard_size = getattr(layer.qweight, "shard_size", None)
|
||||
shard_id = getattr(layer.qweight, "shard_id", None)
|
||||
|
||||
if shard_id and shard_size:
|
||||
result = []
|
||||
offset = 0
|
||||
# dequantize shard weights respectively
|
||||
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
|
||||
for id in shard_id:
|
||||
shard_weight = layer.qweight[
|
||||
offset:offset +
|
||||
shard_size[id][0], :shard_size[id][1]].contiguous()
|
||||
qweight_type = layer.qweight_type.shard_weight_type[id]
|
||||
result.append(_fuse_mul_mat(x, shard_weight, qweight_type))
|
||||
offset += shard_size[id][0]
|
||||
out = torch.cat(result, axis=1)
|
||||
else:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
out = _fuse_mul_mat(x, qweight, qweight_type)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
"""Embedding method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
x: torch.Tensor) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
hidden_size = qweight.shape[1] // type_size * block_size
|
||||
if qweight_type < 2:
|
||||
return torch.embedding(qweight, x)
|
||||
x_flat = x.flatten()
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0])
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
@ -3,19 +3,46 @@ from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
|
||||
|
||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
"""Unquantized method for embeddings."""
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for embedding layer."""
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
input_: torch.Tensor) -> torch.Tensor:
|
||||
return F.embedding(input_, layer.weight)
|
||||
|
||||
|
||||
def pad_vocab_size(vocab_size: int,
|
||||
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
|
||||
"""Pad the vocab size to the given value."""
|
||||
@ -199,7 +226,19 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
if quant_config is not None:
|
||||
linear_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
linear_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
|
||||
linear_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(linear_method))
|
||||
if is_embedding_layer and not linear_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(linear_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
||||
|
||||
self.linear_method: QuantizeMethodBase = linear_method
|
||||
|
||||
if params_dtype is None:
|
||||
@ -306,6 +345,14 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
|
||||
# If the parameter is a gguf weight, then load it directly.
|
||||
if getattr(param, "is_gguf_weight_type", None):
|
||||
param.data.copy_(loaded_weight)
|
||||
param.weight_type = loaded_weight.item()
|
||||
return
|
||||
elif isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
|
||||
# If parameter does not have output dim, then it should
|
||||
# be copied onto all gpus (e.g. g_idx for act_order gptq).
|
||||
if output_dim is None:
|
||||
@ -344,7 +391,8 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input.long(), self.weight)
|
||||
output_parallel = self.linear_method.embedding(self,
|
||||
masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
@ -389,6 +437,7 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype,
|
||||
org_num_embeddings, padding_size, quant_config,
|
||||
prefix)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
|
||||
@ -10,11 +10,13 @@ from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
import gguf
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from torch import nn
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
|
||||
LoRAConfig, ModelConfig, MultiModalConfig,
|
||||
@ -31,8 +33,9 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
|
||||
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
||||
safetensors_weights_iterator)
|
||||
from vllm.model_executor.models.interfaces import (has_inner_state,
|
||||
supports_lora,
|
||||
supports_vision)
|
||||
@ -948,6 +951,90 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
return model.eval()
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load GGUF files. This is useful for loading models
|
||||
that are quantized with GGUF and saved in the GGUF format. This loader
|
||||
supports loading both full models and sharded models.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str):
|
||||
if os.path.isfile(model_name_or_path):
|
||||
return model_name_or_path
|
||||
else:
|
||||
raise ValueError(f"{model_name_or_path} is not a file.")
|
||||
|
||||
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
||||
"""
|
||||
GGUF uses this naming convention for their tensors from HF checkpoint:
|
||||
`blk.N.BB.weight` and `blk.N.BB.bias`
|
||||
where N signifies the block number of a layer, and BB signifies the
|
||||
attention/mlp layer components.
|
||||
See "Standardized tensor names" in
|
||||
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
|
||||
"""
|
||||
config = model_config.hf_config
|
||||
model_type = config.model_type
|
||||
# hack: ggufs have a different name than transformers
|
||||
if model_type == "cohere":
|
||||
model_type = "command-r"
|
||||
arch = None
|
||||
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
||||
if value == model_type:
|
||||
arch = key
|
||||
break
|
||||
if arch is None:
|
||||
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
|
||||
num_layers = config.num_hidden_layers
|
||||
name_map = gguf.get_tensor_name_map(arch, num_layers)
|
||||
with torch.device("meta"):
|
||||
dummy_model = AutoModelForCausalLM.from_config(config)
|
||||
state_dict = dummy_model.state_dict()
|
||||
|
||||
gguf_to_hf_name_map = {}
|
||||
for hf_name in state_dict:
|
||||
name, suffix = hf_name.rsplit(".", 1)
|
||||
gguf_name = name_map.get_name(name)
|
||||
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
|
||||
return gguf_to_hf_name_map
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
return gguf_quant_weights_iterator(model_name_or_path,
|
||||
gguf_to_hf_name_map)
|
||||
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
|
||||
local_model_path = self._prepare_weights(model_config.model)
|
||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||
# we can only know if tie word embeddings after mapping weights
|
||||
if "lm_head.weight" in get_gguf_extra_tensor_names(
|
||||
local_model_path, gguf_weights_map):
|
||||
model_config.hf_config.update({"tie_word_embeddings": True})
|
||||
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(local_model_path, gguf_weights_map))
|
||||
return model
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
|
||||
@ -966,4 +1053,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
if load_config.load_format == LoadFormat.BITSANDBYTES:
|
||||
return BitsAndBytesModelLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.GGUF:
|
||||
return GGUFModelLoader(load_config)
|
||||
|
||||
return DefaultModelLoader(load_config)
|
||||
|
||||
@ -6,9 +6,10 @@ import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Any, Generator, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import filelock
|
||||
import gguf
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -121,6 +122,11 @@ def get_quant_config(model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> QuantizationConfig:
|
||||
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
# GGUF doesn't have config file
|
||||
if model_config.quantization == "gguf":
|
||||
return quant_cls.from_config({})
|
||||
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
None)
|
||||
@ -409,6 +415,45 @@ def pt_weights_iterator(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def get_gguf_extra_tensor_names(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
expected_gguf_keys = set(gguf_to_hf_name_map.keys())
|
||||
exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
|
||||
extra_keys = expected_gguf_keys - exact_gguf_keys
|
||||
return [gguf_to_hf_name_map[key] for key in extra_keys]
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""
|
||||
Iterate over the quant weights in the model gguf files and convert
|
||||
them to torch tensors
|
||||
"""
|
||||
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
|
||||
for tensor in reader.tensors:
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
weight_type_name = name.replace("weight", "qweight_type")
|
||||
weight_type = torch.tensor(weight_type)
|
||||
yield weight_type_name, weight_type
|
||||
|
||||
for tensor in reader.tensors:
|
||||
weight = tensor.data
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
name = name.replace("weight", "qweight")
|
||||
param = torch.tensor(weight)
|
||||
yield name, param
|
||||
|
||||
|
||||
def kv_cache_scales_loader(
|
||||
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
|
||||
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
|
||||
|
||||
@ -140,6 +140,7 @@ class LlamaAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
@ -148,12 +149,17 @@ class LlamaAttention(nn.Module):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
is_neox_style = True
|
||||
if quant_config is not None and quant_config.get_name() == "gguf":
|
||||
is_neox_style = False
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -279,6 +285,7 @@ class LlamaModel(nn.Module):
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
@ -238,6 +238,7 @@ class Qwen2Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import contextlib
|
||||
from typing import Dict, Optional, Type
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
from transformers.models.auto.modeling_auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
|
||||
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
@ -36,18 +39,29 @@ for name, cls in _CONFIG_REGISTRY.items():
|
||||
AutoConfig.register(name, cls)
|
||||
|
||||
|
||||
def get_config(model: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
rope_theta: Optional[float] = None) -> PretrainedConfig:
|
||||
def get_config(
|
||||
model: Union[str, Path],
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
**kwargs,
|
||||
) -> PretrainedConfig:
|
||||
|
||||
# Separate model folder from file path for GGUF models
|
||||
is_gguf = Path(model).is_file() and Path(model).suffix == ".gguf"
|
||||
if is_gguf:
|
||||
kwargs["gguf_file"] = Path(model).name
|
||||
model = Path(model).parent
|
||||
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
model,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
code_revision=code_revision)
|
||||
code_revision=code_revision,
|
||||
**kwargs)
|
||||
except ValueError as e:
|
||||
if (not trust_remote_code and
|
||||
"requires you to execute the configuration file" in str(e)):
|
||||
@ -64,12 +78,22 @@ def get_config(model: str,
|
||||
config = config_class.from_pretrained(model,
|
||||
revision=revision,
|
||||
code_revision=code_revision)
|
||||
|
||||
# Special architecture mapping check for GGUF models
|
||||
if is_gguf:
|
||||
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
raise RuntimeError(
|
||||
f"Can't get gguf config for {config.model_type}.")
|
||||
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
|
||||
config.update({"architectures": [model_type]})
|
||||
|
||||
for key, value in [("rope_scaling", rope_scaling),
|
||||
("rope_theta", rope_theta)]:
|
||||
if value is not None:
|
||||
logger.info("Updating %s from %r to %r", key,
|
||||
getattr(config, key, None), value)
|
||||
config.update({key: value})
|
||||
|
||||
return config
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import huggingface_hub
|
||||
@ -55,7 +56,7 @@ def get_cached_tokenizer(
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
tokenizer_name: str,
|
||||
tokenizer_name: Union[str, Path],
|
||||
*args,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
@ -91,6 +92,13 @@ def get_tokenizer(
|
||||
if "truncation_side" not in kwargs:
|
||||
kwargs["truncation_side"] = "left"
|
||||
|
||||
# Separate model folder from file path for GGUF models
|
||||
is_gguf = Path(tokenizer_name).is_file() and Path(
|
||||
tokenizer_name).suffix == ".gguf"
|
||||
if is_gguf:
|
||||
kwargs["gguf_file"] = Path(tokenizer_name).name
|
||||
tokenizer_name = Path(tokenizer_name).parent
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user