From 90969fb39a58593515f6a087d9200bc72333ab9a Mon Sep 17 00:00:00 2001
From: LukasBluebaum <38468743+LukasBluebaum@users.noreply.github.com>
Date: Wed, 2 Apr 2025 10:58:48 +0200
Subject: [PATCH] [Kernel] Add more dtype support for GGUF dequantization
 (#15879)

Signed-off-by: lukas.bluebaum
---
 csrc/ops.h | 3 +-
 csrc/quantization/gguf/dequantize.cuh | 65 ++++++++++---------
 csrc/quantization/gguf/ggml-common.h | 17 ++++-
 csrc/quantization/gguf/gguf_kernel.cu | 15 +++--
 csrc/torch_bindings.cpp | 4 +-
 tests/kernels/test_ggml.py | 3 +-
 tests/kernels/test_gguf.py | 4 +-
 vllm/_custom_ops.py | 15 +++--
 .../layers/quantization/gguf.py | 4 +-
 9 files changed, 80 insertions(+), 50 deletions(-)

diff --git a/csrc/ops.h b/csrc/ops.h
index a0985d3242662..152c94e860032 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -145,7 +145,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
 #endif
 
 torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
-                              int64_t n);
+                              int64_t n,
+                              std::optional<at::ScalarType> const& dtype);
 
 torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                   int64_t type, int64_t row);
diff --git a/csrc/quantization/gguf/dequantize.cuh b/csrc/quantization/gguf/dequantize.cuh
index 41fc032ff1a56..9d355003ef91d 100644
--- a/csrc/quantization/gguf/dequantize.cuh
+++ b/csrc/quantization/gguf/dequantize.cuh
@@ -94,8 +94,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
-    y[iybs + iqs + 0] = v.x;
-    y[iybs + iqs + y_offset] = v.y;
+    y[iybs + iqs + 0] = convert_from_half<dst_t>(v.x);
+    y[iybs + iqs + y_offset] = convert_from_half<dst_t>(v.y);
 }
 
 template<typename dst_t>
@@ -114,10 +114,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
 
     half dall = __low2half(x[i].dm);
     half dmin = __high2half(x[i].dm);
-    y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)));
-    y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)));
-    y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)));
-    y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)));
+    y[l+ 0] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))));
+    y[l+32] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))));
+    y[l+64] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))));
+    y[l+96] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))));
 }
 
 template<typename dst_t>
@@ -148,7 +148,9 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
     const uint8_t * q = x[i].qs + 32*n;
     const uint8_t * hm = x[i].hmask;
 
-    for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
+    for (int l = l0; l < l0+4; ++l) {
+        y[l] = convert_from_half<dst_t>(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
+    }
 }
 
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
@@ -188,8 +190,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
     const half d2 = __hmul(dall, __int2half_rn(sc));
     const half m2 = __hmul(dmin, __int2half_rn(m));
     for (int l = 0; l < n; ++l) {
-        y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
-        y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2);
+        y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
+        y[l +32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
     }
 }
 
@@ -220,11 +222,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
     const half d2 = __hmul(dall, __int2half_rn(sc));
     const half m2 = __hmul(dmin, __int2half_rn(m));
     uint8_t hm = 1 << (2*il);
-    y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1);
-    y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1);
+    y[ 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
+    y[ 1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
     hm <<= 1;
-    y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2);
-    y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2);
+    y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
+    y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
 }
 
 template<typename dst_t>
@@ -247,10 +249,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
    const uint8_t   qh = x[i].qh[32*ip + il];
    const int8_t  * sc = x[i].scales + is;
 
-    y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
-    y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
-    y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >>  4) | (((qh >> 4) & 3) << 4)) - 32)));
-    y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >>  4) | (((qh >> 6) & 3) << 4)) - 32)));
+    y[ 0] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
+    y[32] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
+    y[64] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >>  4) | (((qh >> 4) & 3) << 4)) - 32))));
+    y[96] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >>  4) | (((qh >> 6) & 3) << 4)) - 32))));
 }
 
 template<typename dst_t>
@@ -269,7 +271,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
     const uint32_t aux32 = q2[2] | (q2[3] << 16);
     const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
-    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 }
 
 template<typename dst_t>
@@ -286,7 +288,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
     const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
     const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
     const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
-    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 }
 
 
@@ -303,7 +305,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
     const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
     const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
     const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
-    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 }
 
 template<typename dst_t>
@@ -324,8 +326,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
     const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
     for (int j = 0; j < 4; ++j) {
-        y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
-        y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
 }
 
@@ -345,8 +347,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
     const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
     const uint8_t signs = x[i].signs[4*ib + il];
     for (int j = 0; j < 4; ++j) {
-        y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
-        y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
 }
 
@@ -367,7 +369,7 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
     grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
     grid32[0] &= 0x0f0f0f0f;
     for (int j = 0; j < 8; ++j) {
-        y[j] = __float2half(d * (q[j] + delta));
+        y[j] = d * (q[j] + delta);
     }
 }
 
@@ -392,7 +394,7 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
     grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
     grid32[0] &= 0x0f0f0f0f;
     for (int j = 0; j < 8; ++j) {
-        y[j] = __float2half(d * (q[j] + delta));
+        y[j] = d * (q[j] + delta);
     }
 }
 
@@ -409,8 +411,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
     const uint8_t * q4 = x[ib].qs + 4*il;
     const float d = __half2float(x[ib].d);
     for (int j = 0; j < 4; ++j) {
-        y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
-        y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >>  4]);
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
     }
 }
 
@@ -427,8 +429,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
     const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
     const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
     for (int j = 0; j < 4; ++j) {
-        y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
-        y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >>  4]);
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
     }
 }
 
@@ -522,7 +524,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
+template <typename dst_t>
+static to_cuda_ggml_t<dst_t> ggml_get_to_cuda(int64_t type) {
   switch (type) {
     case 2:
       return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h
index d42205a6571db..99a7ea0fb277e 100644
--- a/csrc/quantization/gguf/ggml-common.h
+++ b/csrc/quantization/gguf/ggml-common.h
@@ -1063,7 +1063,8 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
 typedef half dfloat; // dequantize float
 typedef half2 dfloat2;
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream);
+template <typename dst_t>
+using to_cuda_ggml_t = void (*)(const void * __restrict__ x, dst_t * __restrict__ y, int k, cudaStream_t stream);
 typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
 typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
 typedef void (*load_tiles_cuda_t)(
@@ -1075,6 +1076,20 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(
 
 // Utility function
 
+template <typename dst_t>
+static __device__ __forceinline__ dst_t convert_from_half(half val) {
+  return val;
+}
+
+template <>
+__device__ __forceinline__ c10::BFloat16 convert_from_half<c10::BFloat16>(half val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __float2bfloat16(__half2float(val));
+#else
+  return __half2float(val);
+#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+}
+
 #if defined(USE_ROCM)
 
 #ifndef __has_builtin
diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu
index dbbb97e6fb3a9..56b78f1834d15 100644
--- a/csrc/quantization/gguf/gguf_kernel.cu
+++ b/csrc/quantization/gguf/gguf_kernel.cu
@@ -71,14 +71,19 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
-                              int64_t type, int64_t m, int64_t n) {
+                              int64_t type, int64_t m, int64_t n,
+                              std::optional<at::ScalarType> const& dtype) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+  auto dtype_ = dtype.value_or(torch::kFloat16);
+  auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
   at::Tensor DW = torch::empty({m, n}, options);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
-  to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);
+
+  VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
+    auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
+    to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
+  });
+
   return DW;
 }
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index feb3882c4d54e..d3b80572b6ead 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -295,7 +295,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #endif
 
   // Dequantization for GGML.
-  ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
+  ops.def(
+      "ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
+      "dtype) -> Tensor");
   ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
 
   // mmvq kernel for GGML.
diff --git a/tests/kernels/test_ggml.py b/tests/kernels/test_ggml.py
index 23fa1fdfda179..cc157da518cbf 100644
--- a/tests/kernels/test_ggml.py
+++ b/tests/kernels/test_ggml.py
@@ -15,7 +15,8 @@ def test_ggml_opcheck(quant_type):
     qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
     m = qweight.shape[0]
     n = qweight.shape[1] // type_size * block_size
-    opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n))
+    opcheck(torch.ops._C.ggml_dequantize,
+            (qweight, quant_type, m, n, torch.float16))
 
     x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
     opcheck(torch.ops._C.ggml_mul_mat_a8,
diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py
index ede941844dc0e..4c0fae9d9fd75 100644
--- a/tests/kernels/test_gguf.py
+++ b/tests/kernels/test_gguf.py
@@ -65,7 +65,7 @@ QUANT_TYPES = [
 
 
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
-@pytest.mark.parametrize("dtype", [torch.half])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_type", QUANT_TYPES)
 @torch.inference_mode()
 def test_dequantize(hidden_size: int, dtype: torch.dtype,
@@ -78,7 +78,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
     ref_output = torch.tensor(dequantize(tensor.data, quant_type),
                               device="cuda").to(dtype)
     output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
-                                 quant_type, *list(shape)).to(dtype)
+                                 quant_type, *list(shape), dtype)
 
     torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
 
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 039397f5a5ef5..fe41a2d963b2e 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -436,9 +436,12 @@ if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
 if hasattr(torch.ops._C, "ggml_dequantize"):
 
     @register_fake("_C::ggml_dequantize")
-    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
-                              m: torch.SymInt,
-                              n: torch.SymInt) -> torch.Tensor:
+    def _ggml_dequantize_fake(
+            W: torch.Tensor,
+            quant_type: int,
+            m: torch.SymInt,
+            n: torch.SymInt,
+            dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         return torch.empty((m, n), dtype=torch.float16, device=W.device)
 
     @register_fake("_C::ggml_mul_mat_vec_a8")
@@ -1097,9 +1100,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # gguf
-def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
-                    n: int) -> torch.Tensor:
-    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
+def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
+                    dtype: Optional[torch.dtype]) -> torch.Tensor:
+    return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)
 
 
 def ggml_mul_mat_vec_a8(
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index c8ab12d9a0aa2..9861e0a85b3f1 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
@@ -377,7 +377,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
         x_flat = x.flatten()
         quant = torch.index_select(qweight, dim=0, index=x_flat)
         dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-                                      x_flat.shape[0]).to(self.params_dtype)
+                                      x_flat.shape[0], self.params_dtype)
         return dequant.view(*x.shape, hidden_size)
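
A minimal usage sketch of the new dtype argument (not part of the patch). It assumes a CUDA build of vLLM with these kernels compiled and the gguf Python package installed; the random bytes below stand in for real GGUF block data, so only the output shapes and dtypes are meaningful:

    import gguf
    import torch

    from vllm import _custom_ops as ops

    # Q4_K packs 256 weights into 144 bytes per block.
    quant_type = gguf.GGMLQuantizationType.Q4_K
    block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]

    # 1152 bytes per row is exactly 8 Q4_K blocks, i.e. 2048 columns.
    qweight = torch.randint(0, 100, (256, 1152),
                            device="cuda", dtype=torch.uint8)
    m = qweight.shape[0]
    n = qweight.shape[1] // type_size * block_size

    # The trailing dtype is new: None keeps the previous float16 default,
    # while torch.bfloat16 (and torch.float32) now dispatch natively.
    w_fp16 = ops.ggml_dequantize(qweight, quant_type, m, n, None)
    w_bf16 = ops.ggml_dequantize(qweight, quant_type, m, n, torch.bfloat16)
    assert w_fp16.dtype == torch.float16
    assert w_bf16.dtype == torch.bfloat16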
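Design note: convert_from_half<dst_t> is an identity for the existing half path, so float16 output and the kernels' half/float intermediate math are unchanged; only the final store differs per destination type. For bfloat16 the value is round-tripped through float, using __float2bfloat16 on SM80 and newer, while older architectures return the float and let c10::BFloat16's converting constructor do the narrowing. On the host side, ggml_get_to_cuda is templated on the destination type and selected through VLLM_DISPATCH_FLOATING_TYPES, so each supported dtype gets its own instantiation of the dequantization kernels.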