[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-07-25 20:07:07 -04:00 committed by GitHub
parent 41d3082c41
commit 75d29cf4e1
6 changed files with 47 additions and 3 deletions

View File

@@ -292,6 +292,11 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0);
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

View File

@@ -1,6 +1,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include "../../dispatch_utils.h"
@@ -336,3 +338,11 @@ void dynamic_scaled_int8_quant(
}
});
}
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}

View File

@@ -1,6 +1,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include <cuda_fp16.h>
@@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
- bool scale_ue8m0 = false) {
+ bool scale_ue8m0) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous());
@@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
} else if (dst_type == at::ScalarType::Char) {
LAUNCH_KERNEL(scalar_t, int8_t);
}
}));

View File

@@ -0,0 +1,10 @@
#pragma once
#include <torch/all.h>
// TODO(wentao): refactor the folder to 8bit, which then includes fp8 and int8 folders
// 8-bit per-token-group quantization helper used by both FP8 and INT8
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
bool scale_ue8m0 = false);
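
For orientation, the scheme this shared helper implements is per-token-group absmax quantization: each contiguous group of group_size elements along the last dimension gets one float32 scale, and values are scaled and clamped into the 8-bit range. Below is a minimal PyTorch sketch of that formulation, a reference illustration assuming the standard absmax/round-to-nearest scheme rather than code from this commit:

# Reference sketch only (not from this commit): per-token-group absmax
# quantization to a signed 8-bit range, mirroring the helper's parameters.
import torch

def per_token_group_quant_8bit_ref(x: torch.Tensor, group_size: int,
                                   eps: float = 1e-10,
                                   min_8bit: float = -128.0,
                                   max_8bit: float = 127.0):
    # View the last dimension as groups of group_size elements.
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    # One scale per group: absmax / max_8bit, floored by eps to avoid divide-by-zero.
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scales = absmax / max_8bit
    # Scale, round to nearest, and clamp into the representable 8-bit range.
    q = torch.clamp(torch.round(groups / scales), min_8bit, max_8bit).to(torch.int8)
    return q.reshape_as(x), scales.squeeze(-1)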

View File

@@ -624,6 +624,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8);
// Compute per-token-group INT8 quantized tensor and scaling factor.
ops.def(
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
ops.impl("per_token_group_quant_int8", torch::kCUDA,
&per_token_group_quant_int8);
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "

View File

@@ -238,13 +238,20 @@ def per_token_group_quant_int8(
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size, ),
device=x.device,
dtype=torch.float32,
)
# prefer CUDA kernel if available
if current_platform.is_cuda():
torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
float(int8_min),
float(int8_max))
return x_q, x_s
M = x.numel() // group_size
N = group_size
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
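
On CUDA the wrapper now returns early from the new kernel path; the M/N setup was moved below that branch and only feeds the existing Triton fallback, which continues past this hunk. For completeness, a hedged usage sketch of the wrapper (the import path and keyword defaults are assumptions, not shown in this diff):

# Usage sketch; import path and argument defaults are assumptions.
import torch
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_group_quant_int8)

x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
x_q, x_s = per_token_group_quant_int8(x, group_size=128)
print(x_q.dtype, tuple(x_q.shape))  # torch.int8, (8, 512)
print(x_s.dtype, tuple(x_s.shape))  # torch.float32, (8, 4)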