[Perf] Vectorize static / dynamic INT8 quant kernels (#19233)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-06-12 09:51:41 -04:00 committed by GitHub
parent 1129e2b1ab
commit b6efafd9e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 411 additions and 97 deletions

View File

@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import torch
from weight_shapes import WEIGHT_SHAPES
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
from vllm.triton_utils import triton
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
x_log=False,
line_arg="provider",
line_vals=[
"torch-bf16",
# "int8-tensor-w-token-a",
"int8-tensor-w-tensor-a",
"int8-channel-w-token-a",
# "int8-channel-w-tensor-a",
# "int8-tensor-w-token-a-noquant",
"int8-tensor-w-tensor-a-noquant",
"int8-channel-w-token-a-noquant",
# "int8-channel-w-tensor-a-noquant",
],
line_names=[
"torch-bf16",
# "int8-tensor-w-token-a",
"int8-tensor-w-tensor-a",
"int8-channel-w-token-a",
# "int8-channel-w-tensor-a",
# "int8-tensor-w-token-a-noquant",
"int8-tensor-w-tensor-a-noquant",
"int8-channel-w-token-a-noquant",
# "int8-channel-w-tensor-a-noquant",
],
ylabel="TFLOP/s (larger is better)",
plot_name="BF16 vs INT8 GEMMs",
args={},
)
)
def benchmark(batch_size, provider, N, K):
M = batch_size
device = "cuda"
dtype = torch.bfloat16
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
if "torch-bf16" in provider:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
)
elif "int8" in provider:
# Weights are always quantized ahead of time
if "noquant" in provider:
# For "no quant", we don't measure the time for activations
if "tensor-w-token-a" in provider:
# Dynamic per-token quant for A, static per-tensor quant for B
scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
assert scale_b_int8.numel() == 1
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
elif "tensor-w-tensor-a" in provider:
# Static per-tensor quantization with fixed scales for both A and B
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
assert scale_b_int8.numel() == 1
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
elif "channel-w-token-a" in provider:
# Dynamic per-channel quantization for weights, per-token quant for A
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
assert scale_b_int8.numel() == N
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
elif "channel-w-tensor-a" in provider:
# Dynamic per-channel quantization for weights, per-tensor quant for A
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
assert scale_b_int8.numel() == N
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
def run_quant():
return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
else:
# Quantize the activations during the GEMM call
if "tensor-w-token-a" in provider:
# Dynamic per-token quant for A, static per-tensor quant for B
scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
assert scale_b_int8.numel() == 1
def run_quant():
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
return vllm_scaled_mm(
a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
)
elif "tensor-w-tensor-a" in provider:
# Static per-tensor quantization with fixed scales for both A and B
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
assert scale_b_int8.numel() == 1
def run_quant():
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
return vllm_scaled_mm(
a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
)
elif "channel-w-token-a" in provider:
# Dynamic per-channel quant for weights, per-token quant for A
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
assert scale_b_int8.numel() == N
def run_quant():
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
return vllm_scaled_mm(
a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
)
elif "channel-w-tensor-a" in provider:
# Dynamic per-channel quant for weights, static per-tensor quant for A
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
assert scale_b_int8.numel() == N
def run_quant():
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
return vllm_scaled_mm(
a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
)
b_int8 = b_int8.t()
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: run_quant(), quantiles=quantiles
)
# Calculate TFLOP/s, two flops per multiply-add
tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
return tflops(ms), tflops(max_ms), tflops(min_ms)
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
choices=[*WEIGHT_SHAPES.keys()],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_int8_res_n{N}_k{K}",
N=N,
K=K,
)
print("Benchmark finished!")

View File

@ -1,15 +1,17 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif
static inline __device__ int8_t float_to_int8_rn(float x) {
@ -103,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
namespace vllm {
template <typename scalar_t, typename scale_type>
template <typename scalar_t, typename scale_t>
__global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
const scale_t* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const float scale = *scale_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
}
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
dst = float_to_int8_rn(static_cast<float>(src) / scale);
});
}
template <typename scalar_t, typename scale_type, typename azp_type>
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void static_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
int const tid = threadIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const float scale = *scale_ptr;
const azp_t azp = *azp_ptr;
const float inv_s = 1.0f / scale;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[i] = quant_val;
}
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
const auto v = static_cast<float>(src) * inv_s;
dst = int32_to_int8(float_to_int32_rn(v) + azp);
});
}
template <typename scalar_t, typename scale_type>
template <typename scalar_t, typename scale_t>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int64_t const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
scale_t* scale_out, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
// calculate for absmax
float thread_max = 0.f;
for (int i = tid; i < hidden_size; i += stride) {
const auto v = fabsf(static_cast<float>(row_in[i]));
thread_max = fmaxf(thread_max, v);
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float block_absmax_val;
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
__shared__ float absmax;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / 127.0f;
absmax = block_max;
scale_out[blockIdx.x] = absmax / 127.f;
}
__syncthreads();
float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
}
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
});
}
template <typename scalar_t, typename scale_type, typename azp_type>
// MinMax structure to hold min and max values in one go
struct MinMax {
float min, max;
__host__ __device__ MinMax()
: min(std::numeric_limits<float>::max()),
max(std::numeric_limits<float>::lowest()) {}
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
// add a value to the MinMax
__host__ __device__ MinMax& operator+=(float v) {
min = fminf(min, v);
max = fmaxf(max, v);
return *this;
}
// merge two MinMax objects
__host__ __device__ MinMax& operator&=(const MinMax& other) {
min = fminf(min, other.min);
max = fmaxf(max, other.max);
return *this;
}
};
__host__ __device__ inline MinMax operator+(MinMax a, float v) {
return a += v;
}
__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
return a &= b;
}
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
int64_t const token_idx = blockIdx.x;
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
// 1. calculate min & max
MinMax thread_mm;
for (int i = tid; i < hidden_size; i += stride) {
thread_mm += static_cast<float>(row_in[i]);
}
// Reduce the max and min values across the block
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
__syncthreads(); // Make sure min doesn't mess with max shared memory
min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
using BlockReduce = cub::BlockReduce<MinMax, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
__shared__ scale_type scale_sh;
__shared__ azp_type azp_sh;
MinMax mm = BlockReduce(tmp).Reduce(
thread_mm,
[] __device__(MinMax a, const MinMax& b) {
a &= b;
return a;
},
blockDim.x);
// Compute the scale and zero point and store them, only on the first thread
if (threadIdx.x == 0) {
float const scale_val = (max_val - min_val) / 255.0f;
// Use rounding to even (same as torch.round)
auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
auto const azp_val = static_cast<azp_type>(azp_float);
// Store the scale and azp into shared and global
scale[token_idx] = scale_sh = scale_val;
azp[token_idx] = azp_sh = azp_val;
__shared__ float scale_sh;
__shared__ azp_t azp_sh;
if (tid == 0) {
float s = (mm.max - mm.min) / 255.f;
float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
scale_sh = s;
azp_sh = azp_t(zp);
scale_out[blockIdx.x] = s;
azp_out[blockIdx.x] = azp_sh;
}
// Wait for the scale and azp to be computed
__syncthreads();
float const scale_val = scale_sh;
azp_type const azp_val = azp_sh;
const float inv_s = 1.f / scale_sh;
const azp_t azp = azp_sh;
// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[i] = quant_val;
}
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
const auto v = static_cast<float>(src) * inv_s;
dst = int32_to_int8(float_to_int32_rn(v) + azp);
});
}
} // namespace vllm
@ -247,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
dim3 const block(std::min(hidden_size, 256));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@ -278,7 +316,7 @@ void dynamic_scaled_int8_quant(
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
dim3 const block(std::min(hidden_size, 256));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {

View File

@ -0,0 +1,75 @@
#pragma once
#include "vectorization.cuh"
namespace vllm {
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
struct DefaultVecOp {
ScaOp scalar_op;
__device__ __forceinline__ void operator()(
vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
scalar_op(dst.val[i], src.val[i]);
}
}
};
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
typename ScaOp>
__device__ inline void vectorize_with_alignment(
const InT* in, OutT* out, int len, int tid, int stride,
VecOp&& vec_op, // vec_n_t<InT,16> -> vec_n_t<OutT,16>
ScaOp&& scalar_op) { // InT -> OutT
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
"VEC_SIZE must be a positive power-of-two");
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
prefix_elems /= sizeof(InT);
prefix_elems = min(prefix_elems, len); // 0 ≤ prefix < 16
// 1. prefill the when it is unsafe to vectorize
for (int i = tid; i < prefix_elems; i += stride) {
scalar_op(out[i], in[i]);
}
in += prefix_elems;
out += prefix_elems;
len -= prefix_elems;
int num_vec = len / VEC_SIZE;
using vin_t = vec_n_t<InT, VEC_SIZE>;
using vout_t = vec_n_t<OutT, VEC_SIZE>;
auto* v_in = reinterpret_cast<const vin_t*>(in);
auto* v_out = reinterpret_cast<vout_t*>(out);
// 2. vectorize the main part
for (int i = tid; i < num_vec; i += stride) {
vout_t tmp;
vec_op(tmp, v_in[i]);
v_out[i] = tmp;
}
// 3. handle the tail
int tail_start = num_vec * VEC_SIZE;
for (int i = tid + tail_start; i < len; i += stride) {
scalar_op(out[i], in[i]);
}
}
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
OutT* out, int len,
int tid, int stride,
ScaOp&& scalar_op) {
using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;
vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op},
std::forward<ScaOp>(scalar_op));
}
} // namespace vllm

View File

@ -11,6 +11,7 @@ from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 2.1]