diff --git a/csrc/ops.h b/csrc/ops.h
index 97a247d9d628c..207291eceb169 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -292,6 +292,11 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
                                torch::Tensor& output_q, torch::Tensor& output_s,
                                int64_t group_size, double eps, double fp8_min,
                                double fp8_max, bool scale_ue8m0);
+
+void per_token_group_quant_int8(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double int8_min, double int8_max);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index 5cd2ac179768b..6a81f159f46ae 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -1,6 +1,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 
+#include "../per_token_group_quant_8bit.h"
+
 #include <cmath>
 
 #include "../../dispatch_utils.h"
@@ -336,3 +338,11 @@ void dynamic_scaled_int8_quant(
         }
       });
 }
+
+void per_token_group_quant_int8(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double int8_min, double int8_max) {
+  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
+                             int8_min, int8_max);
+}
\ No newline at end of file
diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/fp8/per_token_group_quant.cu
index afc41faeca902..2609054f2072b 100644
--- a/csrc/quantization/fp8/per_token_group_quant.cu
+++ b/csrc/quantization/fp8/per_token_group_quant.cu
@@ -1,6 +1,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/util/Float8_e4m3fn.h>
 
+#include "../per_token_group_quant_8bit.h"
+
 #include <cmath>
 
 #include <cuda_fp16.h>
@@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
                                 torch::Tensor& output_q,
                                 torch::Tensor& output_s, int64_t group_size,
                                 double eps, double min_8bit, double max_8bit,
-                                bool scale_ue8m0 = false) {
+                                bool scale_ue8m0) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(output_q.is_contiguous());
@@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
       input.scalar_type(), "per_token_group_quant_8bit", ([&] {
         if (dst_type == at::ScalarType::Float8_e4m3fn) {
           LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+        } else if (dst_type == at::ScalarType::Char) {
+          LAUNCH_KERNEL(scalar_t, int8_t);
         }
       }));
diff --git a/csrc/quantization/per_token_group_quant_8bit.h b/csrc/quantization/per_token_group_quant_8bit.h
new file mode 100644
index 0000000000000..537b61bc4303f
--- /dev/null
+++ b/csrc/quantization/per_token_group_quant_8bit.h
@@ -0,0 +1,10 @@
+#pragma once
+#include <torch/all.h>
+
+// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
+// 8-bit per-token-group quantization helper used by both FP8 and INT8
+void per_token_group_quant_8bit(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double min_8bit, double max_8bit,
+                                bool scale_ue8m0 = false);
\ No newline at end of file
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 95f8541bc9e2d..85b6abef00b03 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -624,6 +624,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("per_token_group_fp8_quant", torch::kCUDA,
            &per_token_group_quant_fp8);
 
+  // Compute per-token-group INT8 quantized tensor and scaling factor.
+  ops.def(
+      "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
" + "output_s, int group_size, float eps, float int8_min, float int8_max) -> " + "()"); + ops.impl("per_token_group_quant_int8", torch::kCUDA, + &per_token_group_quant_int8); + // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel ops.def( "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 1fdf7d174e25e..6840cabbf1ae3 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -238,13 +238,20 @@ def per_token_group_quant_int8( int8_min = iinfo.min x_q = torch.empty_like(x, device=x.device, dtype=dtype) - M = x.numel() // group_size - N = group_size x_s = torch.empty( x.shape[:-1] + (x.shape[-1] // group_size, ), device=x.device, dtype=torch.float32, ) + # prefer CUDA kernel if available + if current_platform.is_cuda(): + torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps, + float(int8_min), + float(int8_max)) + return x_q, x_s + + M = x.numel() // group_size + N = group_size BLOCK = triton.next_power_of_2(N) # heuristics for number of warps