From 783921d889f942d6ec121e84d8eaa4704e5e0f27 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Fri, 4 Jul 2025 03:06:24 -0400 Subject: [PATCH] [Perf] Optimize Vectorization Utils for Int 8 Quantization Kernels (#20331) Signed-off-by: yewentao256 --- .../compressed_tensors/int8_quant_kernels.cu | 16 +-- csrc/quantization/vectorization_utils.cuh | 97 +++++++++++++++++++ 2 files changed, 106 insertions(+), 7 deletions(-) diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 87117a165fe92..5cd2ac179768b 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel( // calculate for absmax float thread_max = 0.f; - for (int i = tid; i < hidden_size; i += stride) { - const auto v = fabsf(static_cast(row_in[i])); - thread_max = fmaxf(thread_max, v); - } + vectorize_read_with_alignment<16>( + row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) { + const float v = fabsf(static_cast(src)); + thread_max = fmaxf(thread_max, v); + }); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x); @@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( // 1. calculate min & max MinMax thread_mm; - for (int i = tid; i < hidden_size; i += stride) { - thread_mm += static_cast(row_in[i]); - } + vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride, + [&] __device__(const scalar_t& src) { + thread_mm += static_cast(src); + }); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; diff --git a/csrc/quantization/vectorization_utils.cuh b/csrc/quantization/vectorization_utils.cuh index 8d3c1d6d3b9fb..8aa0147df6ba8 100644 --- a/csrc/quantization/vectorization_utils.cuh +++ b/csrc/quantization/vectorization_utils.cuh @@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment( constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B uintptr_t addr = reinterpret_cast(in); + // fast path when the whole region is already aligned + // Note: currently the output is guaranteed to be same as the input, so we + // don't check it here, comments here just for future reference. + bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0); + if (can_vec) { + int num_vec = len / VEC_SIZE; + + using vin_t = vec_n_t; + using vout_t = vec_n_t; + auto* v_in = reinterpret_cast(in); + auto* v_out = reinterpret_cast(out); + + for (int i = tid; i < num_vec; i += stride) { + vout_t tmp; + vec_op(tmp, v_in[i]); + v_out[i] = tmp; + } + return; + } + int misalignment_offset = addr & (WIDTH - 1); // addr % 64 int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64) int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64 @@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in, std::forward(scalar_op)); } +template +struct DefaultReadVecOp { + ScaOp scalar_op; + + __device__ __forceinline__ void operator()( + const vec_n_t& src) const { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + scalar_op(src.val[i]); + } + } +}; + +// read-only version: iterate over the input with alignment guarantees +template +__device__ inline void vectorize_read_with_alignment(const InT* in, int len, + int tid, int stride, + VecOp&& vec_op, + ScaOp&& scalar_op) { + static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0, + "VEC_SIZE must be a positive power-of-two"); + constexpr int WIDTH = VEC_SIZE * sizeof(InT); + uintptr_t addr = reinterpret_cast(in); + + // fast path when the whole region is already aligned + bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0); + if (can_vec) { + int num_vec = len / VEC_SIZE; + + using vin_t = vec_n_t; + auto* v_in = reinterpret_cast(in); + + for (int i = tid; i < num_vec; i += stride) { + vec_op(v_in[i]); + } + return; + } + + int misalignment_offset = addr & (WIDTH - 1); + int alignment_bytes = WIDTH - misalignment_offset; + int prefix_elems = alignment_bytes & (WIDTH - 1); + prefix_elems /= sizeof(InT); + prefix_elems = min(prefix_elems, len); + + // 1. handle the possibly unaligned prefix with scalar access. + for (int i = tid; i < prefix_elems; i += stride) { + scalar_op(in[i]); + } + + in += prefix_elems; + len -= prefix_elems; + + int num_vec = len / VEC_SIZE; + using vin_t = vec_n_t; + auto* v_in = reinterpret_cast(in); + + // 2. vectorized traversal of the main aligned region. + for (int i = tid; i < num_vec; i += stride) { + vec_op(v_in[i]); + } + + // 3. handle remaining tail elements. + int tail_start = num_vec * VEC_SIZE; + for (int i = tid + tail_start; i < len; i += stride) { + scalar_op(in[i]); + } +} + +// overload that requires only a scalar_op +template +__device__ __forceinline__ void vectorize_read_with_alignment( + const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) { + using Vec = DefaultReadVecOp>; + vectorize_read_with_alignment(in, len, tid, stride, Vec{scalar_op}, + std::forward(scalar_op)); +} + } // namespace vllm