mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 23:01:20 +08:00
[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
41d3082c41
commit
75d29cf4e1
@ -292,6 +292,11 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
|
|||||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||||
int64_t group_size, double eps, double fp8_min,
|
int64_t group_size, double eps, double fp8_min,
|
||||||
double fp8_max, bool scale_ue8m0);
|
double fp8_max, bool scale_ue8m0);
|
||||||
|
|
||||||
|
void per_token_group_quant_int8(const torch::Tensor& input,
|
||||||
|
torch::Tensor& output_q,
|
||||||
|
torch::Tensor& output_s, int64_t group_size,
|
||||||
|
double eps, double int8_min, double int8_max);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "../per_token_group_quant_8bit.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "../../dispatch_utils.h"
|
#include "../../dispatch_utils.h"
|
||||||
@ -336,3 +338,11 @@ void dynamic_scaled_int8_quant(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void per_token_group_quant_int8(const torch::Tensor& input,
|
||||||
|
torch::Tensor& output_q,
|
||||||
|
torch::Tensor& output_s, int64_t group_size,
|
||||||
|
double eps, double int8_min, double int8_max) {
|
||||||
|
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||||
|
int8_min, int8_max);
|
||||||
|
}
|
||||||
@ -1,6 +1,8 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
#include <c10/util/Float8_e4m3fn.h>
|
||||||
|
|
||||||
|
#include "../per_token_group_quant_8bit.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
|||||||
torch::Tensor& output_q,
|
torch::Tensor& output_q,
|
||||||
torch::Tensor& output_s, int64_t group_size,
|
torch::Tensor& output_s, int64_t group_size,
|
||||||
double eps, double min_8bit, double max_8bit,
|
double eps, double min_8bit, double max_8bit,
|
||||||
bool scale_ue8m0 = false) {
|
bool scale_ue8m0) {
|
||||||
TORCH_CHECK(input.is_contiguous());
|
TORCH_CHECK(input.is_contiguous());
|
||||||
TORCH_CHECK(output_q.is_contiguous());
|
TORCH_CHECK(output_q.is_contiguous());
|
||||||
|
|
||||||
@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
|||||||
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
||||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
||||||
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
|
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
|
||||||
|
} else if (dst_type == at::ScalarType::Char) {
|
||||||
|
LAUNCH_KERNEL(scalar_t, int8_t);
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|||||||
10
csrc/quantization/per_token_group_quant_8bit.h
Normal file
10
csrc/quantization/per_token_group_quant_8bit.h
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
|
||||||
|
// 8-bit per-token-group quantization helper used by both FP8 and INT8
|
||||||
|
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||||
|
torch::Tensor& output_q,
|
||||||
|
torch::Tensor& output_s, int64_t group_size,
|
||||||
|
double eps, double min_8bit, double max_8bit,
|
||||||
|
bool scale_ue8m0 = false);
|
||||||
@ -624,6 +624,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
|
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
|
||||||
&per_token_group_quant_fp8);
|
&per_token_group_quant_fp8);
|
||||||
|
|
||||||
|
// Compute per-token-group INT8 quantized tensor and scaling factor.
|
||||||
|
ops.def(
|
||||||
|
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
||||||
|
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("per_token_group_quant_int8", torch::kCUDA,
|
||||||
|
&per_token_group_quant_int8);
|
||||||
|
|
||||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||||
ops.def(
|
ops.def(
|
||||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||||
|
|||||||
@ -238,13 +238,20 @@ def per_token_group_quant_int8(
|
|||||||
int8_min = iinfo.min
|
int8_min = iinfo.min
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||||
M = x.numel() // group_size
|
|
||||||
N = group_size
|
|
||||||
x_s = torch.empty(
|
x_s = torch.empty(
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size, ),
|
x.shape[:-1] + (x.shape[-1] // group_size, ),
|
||||||
device=x.device,
|
device=x.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
# prefer CUDA kernel if available
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
|
||||||
|
float(int8_min),
|
||||||
|
float(int8_max))
|
||||||
|
return x_q, x_s
|
||||||
|
|
||||||
|
M = x.numel() // group_size
|
||||||
|
N = group_size
|
||||||
|
|
||||||
BLOCK = triton.next_power_of_2(N)
|
BLOCK = triton.next_power_of_2(N)
|
||||||
# heuristics for number of warps
|
# heuristics for number of warps
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user