mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 05:42:15 +08:00
refactor quant folder
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
7e3a8dc906
commit
f07e10e9bc
@ -243,7 +243,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/sampler.cu"
|
||||
"csrc/cuda_view.cu"
|
||||
"csrc/quantization/gptq/q_gemm.cu"
|
||||
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
||||
"csrc/quantization/8bit/int8/scaled_quant.cu"
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
@ -297,7 +297,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu"
|
||||
"csrc/quantization/fp8/per_token_group_quant.cu")
|
||||
"csrc/quantization/8bit/fp8/per_token_group_quant.cu"
|
||||
"csrc/quantization/8bit/int8/per_token_group_quant.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
|
||||
@ -8,9 +8,9 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "../vectorization.cuh"
|
||||
#include "../vectorization_utils.cuh"
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "../../vectorization.cuh"
|
||||
#include "../../vectorization_utils.cuh"
|
||||
#include "../../../dispatch_utils.h"
|
||||
|
||||
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
|
||||
unsigned mask = 0xffff;
|
||||
@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
double fp8_max, bool scale_ue8m0) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
fp8_min, fp8_max, scale_ue8m0);
|
||||
}
|
||||
}
|
||||
12
csrc/quantization/8bit/int8/per_token_group_quant.cu
Normal file
12
csrc/quantization/8bit/int8/per_token_group_quant.cu
Normal file
@ -0,0 +1,12 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "../per_token_group_quant_8bit.h"
|
||||
|
||||
void per_token_group_quant_int8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double int8_min, double int8_max) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
int8_min, int8_max);
|
||||
}
|
||||
@ -1,14 +1,10 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include "../per_token_group_quant_8bit.h"
|
||||
#endif
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "../vectorization_utils.cuh"
|
||||
#include "../../../dispatch_utils.h"
|
||||
#include "../../vectorization_utils.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
@ -25,19 +21,9 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
static constexpr auto i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
|
||||
// To match the rounding mode of CUDA, we use nearbyint.
|
||||
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
|
||||
// If that changes in the future, we may need to set the rounding mode
|
||||
// explicitly, either at runtime or compile time.
|
||||
float dst = std::nearbyint(x);
|
||||
|
||||
// saturate
|
||||
|
||||
// See https://github.com/pytorch/pytorch/issues/127666
|
||||
// See https://github.com/llvm/llvm-project/issues/95183
|
||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
||||
// dst = std::clamp(dst, i8_min, i8_max);
|
||||
// Replace std::clamp due to hip-clang issues
|
||||
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
|
||||
return static_cast<int8_t>(dst);
|
||||
#else
|
||||
@ -50,26 +36,16 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
|
||||
static inline __device__ int32_t float_to_int32_rn(float x) {
|
||||
#ifdef USE_ROCM
|
||||
// int32_max is not exactly representable as float.
|
||||
// Therefore, we need to be careful and manually return int32_max on overflow.
|
||||
// For symmetry, we also do the same for int32_min, even though it is exactly
|
||||
// representable as float and the conversion should be exact.
|
||||
static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
|
||||
static constexpr auto i32_min_f = static_cast<float>(i32_min);
|
||||
static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
|
||||
static constexpr auto i32_max_f = static_cast<float>(i32_max);
|
||||
|
||||
// To match the rounding mode of CUDA, we use nearbyint.
|
||||
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
|
||||
// If that changes in the future, we may need to set the rounding mode
|
||||
// explicitly, either at runtime or compile time.
|
||||
float dst = std::nearbyint(x);
|
||||
|
||||
// saturate on the higher end.
|
||||
if (dst >= i32_max_f) {
|
||||
return i32_max;
|
||||
}
|
||||
// saturate on the lower end.
|
||||
if (dst <= i32_min_f) {
|
||||
return i32_min;
|
||||
}
|
||||
@ -90,13 +66,7 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
|
||||
static constexpr auto i8_max =
|
||||
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
|
||||
|
||||
// saturate
|
||||
|
||||
// See https://github.com/pytorch/pytorch/issues/127666
|
||||
// See https://github.com/llvm/llvm-project/issues/95183
|
||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
||||
// int32_t dst = std::clamp(x, i8_min, i8_max);
|
||||
// Replace std::clamp due to hip-clang issues
|
||||
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
|
||||
return static_cast<int8_t>(dst);
|
||||
#else
|
||||
@ -118,7 +88,6 @@ __global__ void static_scaled_int8_quant_kernel(
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const float scale = *scale_ptr;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
@ -140,7 +109,6 @@ __global__ void static_scaled_int8_azp_quant_kernel(
|
||||
const azp_t azp = *azp_ptr;
|
||||
const float inv_s = 1.0f / scale;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
@ -160,11 +128,9 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
// calculate for absmax
|
||||
float thread_max = 0.f;
|
||||
vectorize_read_with_alignment<16>(
|
||||
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
|
||||
@ -183,7 +149,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
|
||||
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
|
||||
|
||||
// 2. quantize
|
||||
vectorize_with_alignment<16>(
|
||||
row_in, row_out, hidden_size, tid, stride,
|
||||
[=] __device__(int8_t& dst, const scalar_t& src) {
|
||||
@ -201,14 +166,12 @@ struct MinMax {
|
||||
|
||||
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
|
||||
|
||||
// add a value to the MinMax
|
||||
__host__ __device__ MinMax& operator+=(float v) {
|
||||
min = fminf(min, v);
|
||||
max = fmaxf(max, v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// merge two MinMax objects
|
||||
__host__ __device__ MinMax& operator&=(const MinMax& other) {
|
||||
min = fminf(min, other.min);
|
||||
max = fmaxf(max, other.max);
|
||||
@ -231,11 +194,9 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
// 1. calculate min & max
|
||||
MinMax thread_mm;
|
||||
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
|
||||
[&] __device__(const scalar_t& src) {
|
||||
@ -257,7 +218,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
__shared__ azp_t azp_sh;
|
||||
if (tid == 0) {
|
||||
float s = (mm.max - mm.min) / 255.f;
|
||||
float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
|
||||
float zp = nearbyintf(-128.f - mm.min / s);
|
||||
scale_sh = s;
|
||||
azp_sh = azp_t(zp);
|
||||
scale_out[blockIdx.x] = s;
|
||||
@ -268,7 +229,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
const float inv_s = 1.f / scale_sh;
|
||||
const azp_t azp = azp_sh;
|
||||
|
||||
// 2. quantize
|
||||
vectorize_with_alignment<16>(
|
||||
row_in, row_out, hidden_size, tid, stride,
|
||||
[=] __device__(int8_t& dst, const scalar_t& src) {
|
||||
@ -339,14 +299,4 @@ void dynamic_scaled_int8_quant(
|
||||
hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void per_token_group_quant_int8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double int8_min, double int8_max) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
int8_min, int8_max);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
#include <torch/all.h>
|
||||
|
||||
// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
|
||||
// 8-bit per-token-group quantization helper used by both FP8 and INT8
|
||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
Loading…
x
Reference in New Issue
Block a user