[Refactor] Reduce duplicate code in per_token_group_quant cuda kernels (#30496)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-12-12 16:45:23 -05:00 committed by GitHub
parent 13618626df
commit 1e6b115300
GPG Key ID: B5690EEEBB952194


@@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) {
return val;
}
template <typename T, bool SCALE_UE8M0>
__device__ __forceinline__ float ComputeGroupScale(
const T* __restrict__ group_input, T* __restrict__ smem_group,
const int group_size, const int lane_id, const int threads_per_group,
const float eps, const float max_8bit) {
float local_absmax = eps;
constexpr int vec_size = 16 / sizeof(T);
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
}
return y_s;
}
template <typename T, typename DST_DTYPE>
__device__ __forceinline__ void QuantizeGroup(
const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output,
const int group_size, const int lane_id, const int threads_per_group,
const float y_s, const float min_8bit, const float max_8bit) {
constexpr int vec_size = 16 / sizeof(T);
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
}
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
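For reference, a small host-side (CPU) sketch of the arithmetic the two new helpers share; it is not part of this diff, and the function names and float output type are illustrative only. The scale is the group absmax (floored at eps) divided by the 8-bit max, optionally rounded up to a power of two when SCALE_UE8M0 is set; each element is then divided by the scale and clamped to [min_8bit, max_8bit].

#include <algorithm>
#include <cmath>
#include <vector>

// CPU reference of the math in ComputeGroupScale (names are illustrative).
float compute_group_scale(const std::vector<float>& group, float eps,
                          float max_8bit, bool scale_ue8m0) {
  float absmax = eps;
  for (float v : group) absmax = std::max(absmax, std::fabs(v));
  float y_s = absmax / max_8bit;
  if (scale_ue8m0) {
    // round the scale up to the next power of two
    y_s = std::exp2(std::ceil(std::log2(std::max(std::fabs(y_s), 1e-10f))));
  }
  return y_s;
}

// CPU reference of the clamp-and-divide step in QuantizeGroup.
std::vector<float> quantize_group(const std::vector<float>& group, float y_s,
                                  float min_8bit, float max_8bit) {
  std::vector<float> out;
  out.reserve(group.size());
  for (float v : group)
    out.push_back(std::min(std::max(v / y_s, min_8bit), max_8bit));
  return out;
}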
@@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel(
const int64_t global_group_id = block_group_id + local_group_id;
const int64_t block_group_offset = global_group_id * group_size;
float local_absmax = eps;
using scale_element_t = float;
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
@@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel(
T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size;
constexpr int vec_size = 16 / sizeof(T);
using vec_t = vllm::vec_n_t<T, vec_size>;
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
}
const float y_s = ComputeGroupScale<T, SCALE_UE8M0>(
group_input, smem_group, group_size, lane_id, threads_per_group, eps,
max_8bit);
scale_element_t y_s_quant = y_s;
@@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel(
__syncthreads();
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
threads_per_group, y_s, min_8bit, max_8bit);
}
inline int GetGroupsPerBlock(int64_t num_groups) {
if (num_groups % 16 == 0) {
return 16;
}
if (num_groups % 8 == 0) {
return 8;
}
if (num_groups % 4 == 0) {
return 4;
}
if (num_groups % 2 == 0) {
return 2;
}
return 1;
}
void per_token_group_quant_8bit(const torch::Tensor& input,
@@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
constexpr int THREADS_PER_GROUP = 16;
int groups_per_block = 1;
if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}
const int groups_per_block = GetGroupsPerBlock(num_groups);
auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block;
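GetGroupsPerBlock returns the largest power of two no greater than 16 that divides num_groups, so the launch can use num_blocks = num_groups / groups_per_block with no remainder. A standalone sketch of that selection and of the resulting grid arithmetic (the concrete values are illustrative, not taken from the diff):

#include <cassert>
#include <cstdint>

// Same selection rule as GetGroupsPerBlock: largest power of two <= 16
// that divides num_groups evenly.
inline int groups_per_block_for(int64_t num_groups) {
  for (int g = 16; g > 1; g /= 2)
    if (num_groups % g == 0) return g;
  return 1;
}

int main() {
  // e.g. 4096 tokens * 28 groups per token = 114688 groups -> 16 groups/block
  assert(groups_per_block_for(114688) == 16);
  assert(groups_per_block_for(6) == 2);  // falls back to the largest divisor
  assert(groups_per_block_for(7) == 1);  // odd counts get one group per block
  // num_blocks = num_groups / groups_per_block, as in the launch code above
  assert(114688 / groups_per_block_for(114688) == 7168);
}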
@@ -225,8 +253,6 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
const int64_t block_group_offset = global_group_id * group_size;
float local_absmax = eps;
const T* group_input = input + block_group_offset;
DST_DTYPE* group_output =
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
@@ -235,29 +261,9 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
extern __shared__ __align__(16) char smem_raw[];
T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size;
constexpr int vec_size = 16 / sizeof(T);
using vec_t = vllm::vec_n_t<T, vec_size>;
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit;
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
const float y_s =
ComputeGroupScale<T, true>(group_input, smem_group, group_size, lane_id,
threads_per_group, eps, max_8bit);
// pack 4 scales into a uint32
if (lane_id == 0) {
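The packed path always rounds the scale up to a power of two (ComputeGroupScale<T, true>), so only an 8-bit exponent needs to be stored per group. A hedged sketch of that encoding, assuming UE8M0 keeps just the IEEE-754 exponent byte and that four consecutive group scales occupy the four bytes of a uint32; the actual byte order and packing code are not shown in this hunk:

#include <cstdint>
#include <cstring>

// Exponent byte of a positive power-of-two float (illustrative UE8M0 encode).
inline uint8_t ue8m0_from_scale(float pow2_scale) {
  uint32_t bits;
  std::memcpy(&bits, &pow2_scale, sizeof(bits));
  return static_cast<uint8_t>((bits >> 23) & 0xFF);  // IEEE-754 biased exponent
}

// Pack four group scales into one uint32, one byte per group
// (byte order here is an assumption, not taken from the kernel).
inline uint32_t pack4_scales(const float s[4]) {
  uint32_t packed = 0;
  for (int i = 0; i < 4; ++i)
    packed |= static_cast<uint32_t>(ue8m0_from_scale(s[i])) << (8 * i);
  return packed;
}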
@@ -284,19 +290,8 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
__syncthreads();
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
threads_per_group, y_s, min_8bit, max_8bit);
}
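Both helpers delegate the element loop to vllm::vectorize_with_alignment. As a rough mental model only, inferred from the argument comments above and ignoring the 16-byte vectorized fast path, each lane applies the scalar op to every threads_per_group-th element of its group:

// Simplified host C++ model: lane `lane_id` handles elements lane_id,
// lane_id + threads_per_group, lane_id + 2*threads_per_group, ...
template <typename InT, typename OutT, typename Op>
void strided_foreach(const InT* in, OutT* out, int n, int lane_id,
                     int threads_per_group, Op op) {
  for (int i = lane_id; i < n; i += threads_per_group) {
    op(out[i], in[i]);
  }
}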
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
@@ -337,17 +332,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
constexpr int THREADS_PER_GROUP = 16;
int groups_per_block = 1;
if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}
const int groups_per_block = GetGroupsPerBlock(num_groups);
auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block;