diff --git a/benchmarks/kernels/bench_int8_gemm.py b/benchmarks/kernels/bench_int8_gemm.py
new file mode 100644
index 000000000000..e6adcaa00ded
--- /dev/null
+++ b/benchmarks/kernels/bench_int8_gemm.py
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+import copy
+import itertools
+
+import torch
+from weight_shapes import WEIGHT_SHAPES
+
+from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
+from vllm.triton_utils import triton
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size"],
+        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
+        x_log=False,
+        line_arg="provider",
+        line_vals=[
+            "torch-bf16",
+            # "int8-tensor-w-token-a",
+            "int8-tensor-w-tensor-a",
+            "int8-channel-w-token-a",
+            # "int8-channel-w-tensor-a",
+            # "int8-tensor-w-token-a-noquant",
+            "int8-tensor-w-tensor-a-noquant",
+            "int8-channel-w-token-a-noquant",
+            # "int8-channel-w-tensor-a-noquant",
+        ],
+        line_names=[
+            "torch-bf16",
+            # "int8-tensor-w-token-a",
+            "int8-tensor-w-tensor-a",
+            "int8-channel-w-token-a",
+            # "int8-channel-w-tensor-a",
+            # "int8-tensor-w-token-a-noquant",
+            "int8-tensor-w-tensor-a-noquant",
+            "int8-channel-w-token-a-noquant",
+            # "int8-channel-w-tensor-a-noquant",
+        ],
+        ylabel="TFLOP/s (larger is better)",
+        plot_name="BF16 vs INT8 GEMMs",
+        args={},
+    )
+)
+def benchmark(batch_size, provider, N, K):
+    M = batch_size
+    device = "cuda"
+    dtype = torch.bfloat16
+    a = torch.randn((M, K), device=device, dtype=dtype)
+    b = torch.randn((N, K), device=device, dtype=dtype)
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if "torch-bf16" in provider:
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
+        )
+
+    elif "int8" in provider:
+        # Weights are always quantized ahead of time
+        if "noquant" in provider:
+            # For "no quant", we don't measure the time for activations
+            if "tensor-w-token-a" in provider:
+                # Dynamic per-token quant for A, static per-tensor quant for B
+                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
+                assert scale_b_int8.numel() == 1
+                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
+
+            elif "tensor-w-tensor-a" in provider:
+                # Static per-tensor quantization with fixed scales for both A and B
+                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
+                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
+                assert scale_b_int8.numel() == 1
+                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
+
+            elif "channel-w-token-a" in provider:
+                # Dynamic per-channel quantization for weights, per-token quant for A
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
+                assert scale_b_int8.numel() == N
+                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
+
+            elif "channel-w-tensor-a" in provider:
+                # Dynamic per-channel quantization for weights, per-tensor quant for A
+                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
+                assert scale_b_int8.numel() == N
+                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
+
+            def run_quant():
+                return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
+
+        else:
+            # Quantize the activations during the GEMM call
+            if "tensor-w-token-a" in provider:
+                # Dynamic per-token quant for A, static per-tensor quant for B
+                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
+                assert scale_b_int8.numel() == 1
+
+                def run_quant():
+                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
+                    return vllm_scaled_mm(
+                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
+                    )
+
+            elif "tensor-w-tensor-a" in provider:
+                # Static per-tensor quantization with fixed scales for both A and B
+                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
+                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
+                assert scale_b_int8.numel() == 1
+
+                def run_quant():
+                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
+                    return vllm_scaled_mm(
+                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
+                    )
+
+            elif "channel-w-token-a" in provider:
+                # Dynamic per-channel quant for weights, per-token quant for A
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
+                assert scale_b_int8.numel() == N
+
+                def run_quant():
+                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
+                    return vllm_scaled_mm(
+                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
+                    )
+
+            elif "channel-w-tensor-a" in provider:
+                # Dynamic per-channel quant for weights, static per-tensor quant for A
+                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
+                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
+                assert scale_b_int8.numel() == N
+
+                def run_quant():
+                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
+                    return vllm_scaled_mm(
+                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
+                    )
+
+        b_int8 = b_int8.t()
+
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+            lambda: run_quant(), quantiles=quantiles
+        )
+
+    # Calculate TFLOP/s, two flops per multiply-add
+    tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
+    return tflops(ms), tflops(max_ms), tflops(min_ms)
+
+
+def prepare_shapes(args):
+    KN_model_names = []
+    models_tps = list(itertools.product(args.models, args.tp_sizes))
+    for model, tp_size in models_tps:
+        assert model in WEIGHT_SHAPES
+        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
+            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+            KN.append(model)
+            KN_model_names.append(KN)
+    return KN_model_names
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--models",
+        nargs="+",
+        type=str,
+        default=["meta-llama/Llama-3.1-8B-Instruct"],
+        choices=[*WEIGHT_SHAPES.keys()],
+        help="List of models to benchmark",
+    )
+    parser.add_argument(
+        "--tp-sizes",
+        nargs="+",
+        type=int,
+        default=[1],
+        help="List of tensor parallel sizes",
+    )
+    args = parser.parse_args()
+
+    KN_model_names = prepare_shapes(args)
+    for K, N, model_name in KN_model_names:
+        print(f"{model_name}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
+        benchmark.run(
+            print_data=True,
+            show_plots=True,
+            save_path=f"bench_int8_res_n{N}_k{K}",
+            N=N,
+            K=K,
+        )
+
+    print("Benchmark finished!")
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index bf46cce60a23..87117a165fe9 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -1,15 +1,17 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
+
 #include <cmath>
 
 #include "../../dispatch_utils.h"
+#include "../vectorization_utils.cuh"
 
 #ifndef USE_ROCM
-  #include <cub/util_type.cuh>
   #include <cub/cub.cuh>
+  #include <cub/util_type.cuh>
 #else
-  #include <hipcub/util_type.hpp>
   #include <hipcub/hipcub.hpp>
+  #include <hipcub/util_type.hpp>
 #endif
 
 static inline __device__ int8_t float_to_int8_rn(float x) {
@@ -103,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
 
 namespace vllm {
 
-template <typename scalar_t, typename scale_type>
+template <typename scalar_t, typename scale_t>
 __global__ void static_scaled_int8_quant_kernel(
-    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
-    scale_type const* scale_ptr, const int hidden_size) {
-  int const tid = threadIdx.x;
-  int64_t const token_idx = blockIdx.x;
-  scale_type const scale = *scale_ptr;
+    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
+    const scale_t* scale_ptr, const int hidden_size) {
+  const int tid = threadIdx.x;
+  const int stride = blockDim.x;
+  const int64_t token_idx = blockIdx.x;
+  const float scale = *scale_ptr;
 
   // Must be performed using 64-bit math to avoid integer overflow.
-  out += token_idx * hidden_size;
-  input += token_idx * hidden_size;
+  const scalar_t* row_in = input + token_idx * hidden_size;
+  int8_t* row_out = output + token_idx * hidden_size;
 
-  for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
-  }
+  vectorize_with_alignment<16>(
+      row_in, row_out, hidden_size, tid, stride,
+      [=] __device__(int8_t& dst, const scalar_t& src) {
+        dst = float_to_int8_rn(static_cast<float>(src) / scale);
+      });
 }
 
-template <typename scalar_t, typename scale_type, typename azp_type>
+template <typename scalar_t, typename scale_t, typename azp_t>
 __global__ void static_scaled_int8_azp_quant_kernel(
-    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
-    scale_type const* scale_ptr, azp_type const* azp_ptr,
-    const int hidden_size) {
-  int const tid = threadIdx.x;
-  int64_t const token_idx = blockIdx.x;
-  scale_type const scale = *scale_ptr;
-  azp_type const azp = *azp_ptr;
+    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
+    const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
+  const int tid = threadIdx.x;
+  const int stride = blockDim.x;
+  const int64_t token_idx = blockIdx.x;
+  const float scale = *scale_ptr;
+  const azp_t azp = *azp_ptr;
+  const float inv_s = 1.0f / scale;
 
   // Must be performed using 64-bit math to avoid integer overflow.
-  out += token_idx * hidden_size;
-  input += token_idx * hidden_size;
+  const scalar_t* row_in = input + token_idx * hidden_size;
+  int8_t* row_out = output + token_idx * hidden_size;
 
-  for (int i = tid; i < hidden_size; i += blockDim.x) {
-    auto const val = static_cast<float>(input[i]);
-    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
-    out[i] = quant_val;
-  }
+  vectorize_with_alignment<16>(
+      row_in, row_out, hidden_size, tid, stride,
+      [=] __device__(int8_t& dst, const scalar_t& src) {
+        const auto v = static_cast<float>(src) * inv_s;
+        dst = int32_to_int8(float_to_int32_rn(v) + azp);
+      });
 }
 
-template <typename scalar_t, typename scale_type>
+template <typename scalar_t, typename scale_t>
 __global__ void dynamic_scaled_int8_quant_kernel(
-    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
-    scale_type* scale, const int hidden_size) {
-  int const tid = threadIdx.x;
-  int64_t const token_idx = blockIdx.x;
-  float absmax_val = 0.0f;
-  float const zero = 0.0f;
+    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
+    scale_t* scale_out, const int hidden_size) {
+  const int tid = threadIdx.x;
+  const int stride = blockDim.x;
+  const int64_t token_idx = blockIdx.x;
 
   // Must be performed using 64-bit math to avoid integer overflow.
-  out += token_idx * hidden_size;
-  input += token_idx * hidden_size;
+  const scalar_t* row_in = input + token_idx * hidden_size;
+  int8_t* row_out = output + token_idx * hidden_size;
 
-  for (int i = tid; i < hidden_size; i += blockDim.x) {
-    float val = static_cast<float>(input[i]);
-    val = val > zero ? val : -val;
-    absmax_val = val > absmax_val ? val : absmax_val;
+  // calculate for absmax
+  float thread_max = 0.f;
+  for (int i = tid; i < hidden_size; i += stride) {
+    const auto v = fabsf(static_cast<float>(row_in[i]));
+    thread_max = fmaxf(thread_max, v);
   }
-
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStorage;
-  float const block_absmax_val_maybe =
-      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
-  __shared__ float block_absmax_val;
+  using BlockReduce = cub::BlockReduce<float, 256>;
+  __shared__ typename BlockReduce::TempStorage tmp;
+  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
+  __shared__ float absmax;
   if (tid == 0) {
-    block_absmax_val = block_absmax_val_maybe;
-    scale[token_idx] = block_absmax_val / 127.0f;
+    absmax = block_max;
+    scale_out[blockIdx.x] = absmax / 127.f;
   }
   __syncthreads();
 
-  float const tmp_scale = 127.0f / block_absmax_val;
-  for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
-  }
+  float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
+
+  // 2. quantize
+  vectorize_with_alignment<16>(
+      row_in, row_out, hidden_size, tid, stride,
+      [=] __device__(int8_t& dst, const scalar_t& src) {
+        dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
+      });
 }
 
-template <typename scalar_t, typename scale_type, typename azp_type>
+// MinMax structure to hold min and max values in one go
+struct MinMax {
+  float min, max;
+
+  __host__ __device__ MinMax()
+      : min(std::numeric_limits<float>::max()),
+        max(std::numeric_limits<float>::lowest()) {}
+
+  __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
+
+  // add a value to the MinMax
+  __host__ __device__ MinMax& operator+=(float v) {
+    min = fminf(min, v);
+    max = fmaxf(max, v);
+    return *this;
+  }
+
+  // merge two MinMax objects
+  __host__ __device__ MinMax& operator&=(const MinMax& other) {
+    min = fminf(min, other.min);
+    max = fmaxf(max, other.max);
+    return *this;
+  }
+};
+
+__host__ __device__ inline MinMax operator+(MinMax a, float v) {
+  return a += v;
+}
+__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
+  return a &= b;
+}
+
+template <typename scalar_t, typename scale_t, typename azp_t>
 __global__ void dynamic_scaled_int8_azp_quant_kernel(
-    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
-    scale_type* scale, azp_type* azp, const int hidden_size) {
-  int64_t const token_idx = blockIdx.x;
+    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
+    scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
+  const int tid = threadIdx.x;
+  const int stride = blockDim.x;
+  const int64_t token_idx = blockIdx.x;
 
   // Must be performed using 64-bit math to avoid integer overflow.
-  out += token_idx * hidden_size;
-  input += token_idx * hidden_size;
+  const scalar_t* row_in = input + token_idx * hidden_size;
+  int8_t* row_out = output + token_idx * hidden_size;
 
-  // Scan for the min and max value for this token
-  float max_val = std::numeric_limits<float>::min();
-  float min_val = std::numeric_limits<float>::max();
-  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    auto val = static_cast<float>(input[i]);
-    max_val = std::max(max_val, val);
-    min_val = std::min(min_val, val);
+  // 1. calculate min & max
+  MinMax thread_mm;
+  for (int i = tid; i < hidden_size; i += stride) {
+    thread_mm += static_cast<float>(row_in[i]);
   }
 
-  // Reduce the max and min values across the block
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStorage;
-  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
-  __syncthreads();  // Make sure min doesn't mess with max shared memory
-  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
+  using BlockReduce = cub::BlockReduce<MinMax, 256>;
+  __shared__ typename BlockReduce::TempStorage tmp;
 
-  __shared__ scale_type scale_sh;
-  __shared__ azp_type azp_sh;
+  MinMax mm = BlockReduce(tmp).Reduce(
+      thread_mm,
+      [] __device__(MinMax a, const MinMax& b) {
+        a &= b;
+        return a;
+      },
+      blockDim.x);
 
-  // Compute the scale and zero point and store them, only on the first thread
-  if (threadIdx.x == 0) {
-    float const scale_val = (max_val - min_val) / 255.0f;
-    // Use rounding to even (same as torch.round)
-    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
-    auto const azp_val = static_cast<azp_type>(azp_float);
-
-    // Store the scale and azp into shared and global
-    scale[token_idx] = scale_sh = scale_val;
-    azp[token_idx] = azp_sh = azp_val;
+  __shared__ float scale_sh;
+  __shared__ azp_t azp_sh;
+  if (tid == 0) {
+    float s = (mm.max - mm.min) / 255.f;
+    float zp = nearbyintf(-128.f - mm.min / s);  // round-to-even
+    scale_sh = s;
+    azp_sh = azp_t(zp);
+    scale_out[blockIdx.x] = s;
+    azp_out[blockIdx.x] = azp_sh;
   }
-
-  // Wait for the scale and azp to be computed
   __syncthreads();
 
-  float const scale_val = scale_sh;
-  azp_type const azp_val = azp_sh;
+  const float inv_s = 1.f / scale_sh;
+  const azp_t azp = azp_sh;
 
-  // Quantize the values
-  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    auto const val = static_cast<float>(input[i]);
-    auto const quant_val =
-        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
-    out[i] = quant_val;
-  }
+  // 2. quantize
+  vectorize_with_alignment<16>(
+      row_in, row_out, hidden_size, tid, stride,
+      [=] __device__(int8_t& dst, const scalar_t& src) {
+        const auto v = static_cast<float>(src) * inv_s;
+        dst = int32_to_int8(float_to_int32_rn(v) + azp);
+      });
 }
 
 }  // namespace vllm
 
@@ -247,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, 1024));
+  dim3 const block(std::min(hidden_size, 256));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -278,7 +316,7 @@ void dynamic_scaled_int8_quant(
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, 1024));
+  dim3 const block(std::min(hidden_size, 256));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
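Note (not part of the patch): the kernels above keep the same quantization math as before; what changes is the reduction (a single MinMax pass instead of separate max and min reductions) and the store path, which now goes through vectorize_with_alignment with 16-element vectors and a 256-thread block. Folding min and max into one block reduction also removes one of the two BlockReduce passes, and the __syncthreads() between them, that the old azp kernel needed. As a reference for what one row of dynamic_scaled_int8_azp_quant_kernel computes, here is a minimal torch-only sketch; the function name is illustrative and it assumes each row has some spread (scale > 0), as the kernel itself does.

# Illustrative torch reference for the per-token asymmetric (azp) quantization
# performed row-by-row in dynamic_scaled_int8_azp_quant_kernel (not part of the patch).
import torch

def dynamic_azp_quant_reference(x: torch.Tensor):
    x = x.float()
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min) / 255.0
    # zero point, rounded to nearest-even like nearbyintf in the kernel
    azp = torch.round(-128.0 - x_min / scale)
    q = (torch.round(x / scale) + azp).clamp(-128, 127).to(torch.int8)
    return q, scale, azp.to(torch.int32)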
diff --git a/csrc/quantization/vectorization_utils.cuh b/csrc/quantization/vectorization_utils.cuh
new file mode 100644
index 000000000000..8d3c1d6d3b9f
--- /dev/null
+++ b/csrc/quantization/vectorization_utils.cuh
@@ -0,0 +1,75 @@
+#pragma once
+#include "vectorization.cuh"
+
+namespace vllm {
+
+template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
+struct DefaultVecOp {
+  ScaOp scalar_op;
+
+  __device__ __forceinline__ void operator()(
+      vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      scalar_op(dst.val[i], src.val[i]);
+    }
+  }
+};
+
+template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
+          typename ScaOp>
+__device__ inline void vectorize_with_alignment(
+    const InT* in, OutT* out, int len, int tid, int stride,
+    VecOp&& vec_op,       // vec_n_t<InT> -> vec_n_t<OutT>
+    ScaOp&& scalar_op) {  // InT -> OutT
+  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
+                "VEC_SIZE must be a positive power-of-two");
+  constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // eg: 64 B
+  uintptr_t addr = reinterpret_cast<uintptr_t>(in);
+
+  int misalignment_offset = addr & (WIDTH - 1);       // addr % 64
+  int alignment_bytes = WIDTH - misalignment_offset;  // 64 - (addr % 64)
+  int prefix_elems = alignment_bytes & (WIDTH - 1);   // handle 64
+  prefix_elems /= sizeof(InT);
+  prefix_elems = min(prefix_elems, len);  // 0 ≤ prefix < 16
+
+  // 1. handle the unaligned prefix with scalar ops (unsafe to vectorize)
+  for (int i = tid; i < prefix_elems; i += stride) {
+    scalar_op(out[i], in[i]);
+  }
+
+  in += prefix_elems;
+  out += prefix_elems;
+  len -= prefix_elems;
+
+  int num_vec = len / VEC_SIZE;
+  using vin_t = vec_n_t<InT, VEC_SIZE>;
+  using vout_t = vec_n_t<OutT, VEC_SIZE>;
+  auto* v_in = reinterpret_cast<const vin_t*>(in);
+  auto* v_out = reinterpret_cast<vout_t*>(out);
+
+  // 2. vectorize the main part
+  for (int i = tid; i < num_vec; i += stride) {
+    vout_t tmp;
+    vec_op(tmp, v_in[i]);
+    v_out[i] = tmp;
+  }
+
+  // 3. handle the tail
+  int tail_start = num_vec * VEC_SIZE;
+  for (int i = tid + tail_start; i < len; i += stride) {
+    scalar_op(out[i], in[i]);
+  }
+}
+
+template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
+__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
+                                                         OutT* out, int len,
+                                                         int tid, int stride,
+                                                         ScaOp&& scalar_op) {
+  using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;
+  vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op},
+                                     std::forward<ScaOp>(scalar_op));
+}
+
+}  // namespace vllm
diff --git a/tests/kernels/quantization/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py
index 63ccf4a91736..5a37b976db9e 100644
--- a/tests/kernels/quantization/test_int8_quant.py
+++ b/tests/kernels/quantization/test_int8_quant.py
@@ -11,6 +11,7 @@ from vllm.platforms import current_platform
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [16, 67, 768, 5137, 8193]  # Arbitrary values for testing
+HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
 SEEDS = [0]
 SCALE = [0.1, 2.1]
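Note (not part of the patch): the extra HIDDEN_SIZES entries (1024 through 1032) deliberately straddle the 16-element vector width so that the scalar prefix and tail paths of vectorize_with_alignment are exercised, not just the aligned fast path. Below is a hedged sketch of the kind of standalone check this enables for the dynamic (symmetric) path; it uses only APIs already referenced in this diff, but the helper name and the off-by-one rounding tolerance are illustrative, not taken from the test file.

# Illustrative standalone check across the edge-case hidden sizes added above.
# Requires a CUDA build of vLLM and a CUDA device.
import torch
from vllm._custom_ops import scaled_int8_quant

@torch.inference_mode()
def check_dynamic_quant(hidden_size: int, num_tokens: int = 7) -> None:
    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    out, scales, _ = scaled_int8_quant(x)  # dynamic per-token quantization
    # reference: scale = absmax / 127, q = round(x / scale) saturated to int8
    ref_scales = x.abs().amax(dim=-1, keepdim=True).float() / 127.0
    ref = torch.round(x.float() / ref_scales).clamp(-128, 127).to(torch.int8)
    torch.testing.assert_close(scales.reshape(-1), ref_scales.reshape(-1))
    # tolerate off-by-one values from borderline float rounding
    torch.testing.assert_close(out.float(), ref.float(), atol=1, rtol=0)

if __name__ == "__main__":
    for hs in range(1024, 1033):  # the sizes added to HIDDEN_SIZES above
        check_dynamic_quant(hs)
        print(f"hidden_size={hs}: OK")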