[Perf] Tune scaled_fp8_quant by increasing vectorization (#18844)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-06-03 16:48:25 -04:00 committed by GitHub
parent bdf13965ab
commit e31446b6c8
4 changed files with 118 additions and 113 deletions
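
The tuning widens the per-thread vector from 4 elements to 16 and shrinks the thread blocks from 1024 to 256 threads. A rough sanity check of the new vector shape, as a standalone sketch (uint8_t/uint16_t stand in for the fp8 outputs and half/bf16 inputs):

#include <cstdint>

// 16 fp8 values fill exactly one 16-byte (128-bit) store ...
static_assert(16 * sizeof(uint8_t) == 16, "16 fp8 outputs -> one 128-bit store");
// ... and 16 half/bf16 inputs span 32 bytes, i.e. two 128-bit loads, which is
// why the alignment comment in the kernel below moves to 32-byte / 16-byte.
static_assert(16 * sizeof(uint16_t) == 32, "16 fp16/bf16 inputs -> two 128-bit loads");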

View File

@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be
- // aligned at 8-byte and 4-byte addresses respectively.
- bool const can_vectorize = hidden_size % 4 == 0;
+ // aligned at 32-byte and 16-byte addresses respectively.
+ bool const can_vectorize = hidden_size % 16 == 0;
float absmax_val = 0.0f;
if (can_vectorize) {
@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
- absmax_val = max(absmax_val, fabs(x));
+ absmax_val = fmaxf(absmax_val, fabsf(x));
}
}
- using BlockReduce = cub::BlockReduce<float, 1024>;
+ using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float token_scale;
if (tid == 0) {
if (scale_ub) {
- token_scale = min(block_absmax_val_maybe, *scale_ub);
+ token_scale = fminf(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
- token_scale = max(token_scale / quant_type_max_v<fp8_type>,
- min_scaling_factor<fp8_type>::val());
+ token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
+ min_scaling_factor<fp8_type>::val());
scale[token_idx] = token_scale;
}
__syncthreads();
@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
- int64_t num_tokens = input.numel() / input.size(-1);
- int64_t num_elems = input.numel();
- dim3 grid(num_tokens);
- dim3 block(1024);
+ int const block_size = 256;
+ int const num_tokens = input.numel() / input.size(-1);
+ int const num_elems = input.numel();
+ dim3 const grid(num_tokens);
+ dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
- int64_t num_tokens = input.numel() / input.size(-1);
- int64_t num_elems = input.numel();
- dim3 grid(num_tokens);
- dim3 block(1024);
+ int const block_size = 256;
+ int const num_tokens = input.numel() / input.size(-1);
+ int const num_elems = input.numel();
+ dim3 const grid(num_tokens);
+ dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
+ int const block_size = 256;
dim3 const grid(num_tokens);
- dim3 const block(std::min(hidden_size, 1024));
+ dim3 const block(std::min(hidden_size, block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
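
For reference, a minimal self-contained sketch of the per-token pattern these hunks tune: one 256-thread block per token computes a thread-local absmax, reduces it with cub::BlockReduce<float, 256> (matching the new block size), and thread 0 derives the clamped per-token scale. FP8_E4M3_MAX and MIN_SCALING_FACTOR are illustrative stand-ins for the kernel's quant_type_max_v / min_scaling_factor helpers.

#include <cub/cub.cuh>
#include <cstdint>

constexpr int kBlockSize = 256;
constexpr float FP8_E4M3_MAX = 448.0f;  // e4m3 finite max
constexpr float MIN_SCALING_FACTOR = 1.0f / (FP8_E4M3_MAX * 512.0f);  // illustrative floor

__global__ void per_token_scale_sketch(float* __restrict__ scale,
                                       float const* __restrict__ input,
                                       int const hidden_size) {
  int const token_idx = blockIdx.x;
  int const tid = threadIdx.x;
  float const* token_input =
      input + static_cast<int64_t>(token_idx) * hidden_size;

  // Thread-local absmax over a strided slice of this token's row.
  float absmax_val = 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    absmax_val = fmaxf(absmax_val, fabsf(token_input[i]));
  }

  // Block-wide max; the template argument must match the launch block size.
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float const block_absmax = BlockReduce(tmp).Reduce(absmax_val, cub::Max{});

  // Thread 0 converts the absmax into a clamped per-token scale.
  if (tid == 0) {
    scale[token_idx] = fmaxf(block_absmax / FP8_E4M3_MAX, MIN_SCALING_FACTOR);
  }
}
// Launched roughly as: per_token_scale_sketch<<<num_tokens, kBlockSize>>>(...)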

View File

@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
}
float r =
- fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
+ fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
return static_cast<fp8_type>(r);
#else
@ -65,7 +65,7 @@ template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
- __shared__ float cache[1024];
+ __shared__ float cache[256];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
// First store maximum for all values processes by
@ -73,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
- tmp = max(tmp, fabs(x));
+ tmp = fmaxf(tmp, fabsf(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;
@ -100,25 +100,27 @@ template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
+ constexpr size_t VEC_SIZE = 16;
+ using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
// Vectorized input/output to better utilize memory bandwidth.
- vec4_t<scalar_t> const* vectorized_in =
- reinterpret_cast<vec4_t<scalar_t> const*>(input);
+ auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
- int64_t const num_vec_elems = num_elems >> 2;
+ // num_elems / VEC_SIZE (which is 16)
+ int64_t const num_vec_elems = num_elems >> 4;
float absmax_val = 0.0f;
- #pragma unroll 4
+ #pragma unroll
for (int64_t i = tid; i < num_vec_elems; i += step) {
- vec4_t<scalar_t> in_vec = vectorized_in[i];
- absmax_val = max(absmax_val, fabs(in_vec.x));
- absmax_val = max(absmax_val, fabs(in_vec.y));
- absmax_val = max(absmax_val, fabs(in_vec.z));
- absmax_val = max(absmax_val, fabs(in_vec.w));
+ scalarxN_t in_vec = vectorized_in[i];
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
+ }
}
- // Handle the remaining elements if num_elems is not divisible by 4
- for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
- absmax_val = max(absmax_val, fabs(input[i]));
+ // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+ for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
+ absmax_val = fmaxf(absmax_val, fabsf(input[i]));
}
return absmax_val;
@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
- using float8x4_t = q8x4_t<fp8_type>;
+ constexpr size_t VEC_SIZE = 16;
+ using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
+ using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
// Vectorized input/output to better utilize memory bandwidth.
- auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
- auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+ auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
+ auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
- int64_t const num_vec_elems = num_elems >> 2;
+ // num_elems / VEC_SIZE (which is 16)
+ int64_t const num_vec_elems = num_elems >> 4;
- #pragma unroll 4
+ #pragma unroll
for (int64_t i = tid; i < num_vec_elems; i += step) {
- vec4_t<scalar_t> in_vec = vectorized_in[i];
- float8x4_t out_vec;
+ scalarxN_t in_vec = vectorized_in[i];
+ float8xN_t out_vec;
- out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
- static_cast<float>(in_vec.x), scale);
- out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
- static_cast<float>(in_vec.y), scale);
- out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
- static_cast<float>(in_vec.z), scale);
- out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
- static_cast<float>(in_vec.w), scale);
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
+ static_cast<float>(in_vec.val[j]), scale);
+ }
vectorized_out[i] = out_vec;
}
- // Handle the remaining elements if num_elems is not divisible by 4
- for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+ // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+ for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
static_cast<float>(input[i]), scale);
}
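
A minimal, self-contained sketch of the widened pattern above: each thread loads 16 scalars through an aligned vector type, converts every lane, stores 16 fp8 values at once, and falls back to a scalar tail for the remainder. Vec is a local stand-in for vec_n_t, the clamp constant 448.0f is the e4m3 finite max, and the saturating __nv_fp8_e4m3 constructor plays the role of scaled_fp8_conversion; the pointers are assumed to be suitably aligned for the reinterpret_cast.

#include <cuda_fp8.h>
#include <cstdint>

template <typename T, int N>
struct __align__(N * sizeof(T)) Vec {
  T val[N];
};

__global__ void quant_fp8_vec16_sketch(__nv_fp8_e4m3* __restrict__ out,
                                       float const* __restrict__ in,
                                       float const inv_scale,
                                       int64_t const num_elems) {
  constexpr int VEC_SIZE = 16;
  int64_t const tid = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  int64_t const step = static_cast<int64_t>(blockDim.x) * gridDim.x;

  auto const* vin = reinterpret_cast<Vec<float, VEC_SIZE> const*>(in);
  auto* vout = reinterpret_cast<Vec<__nv_fp8_e4m3, VEC_SIZE>*>(out);
  int64_t const num_vec_elems = num_elems / VEC_SIZE;

  // Main loop: one 16-lane struct copy in, one 16-byte fp8 struct copy out.
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    Vec<float, VEC_SIZE> in_vec = vin[i];
    Vec<__nv_fp8_e4m3, VEC_SIZE> out_vec;
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      // Clamp to the e4m3 representable range before converting.
      float const x = fmaxf(-448.0f, fminf(in_vec.val[j] * inv_scale, 448.0f));
      out_vec.val[j] = __nv_fp8_e4m3(x);
    }
    vout[i] = out_vec;
  }

  // Scalar tail for the elements left over when num_elems % VEC_SIZE != 0.
  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
    float const x = fmaxf(-448.0f, fminf(in[i] * inv_scale, 448.0f));
    out[i] = __nv_fp8_e4m3(x);
  }
}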

View File

@ -140,6 +140,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
// sum of squares
float ss = 0.0f;
+ const int VEC_SIZE = 4;
int32_t const num_vec_elems = hidden_size >> 2;
#pragma unroll 4
@ -147,22 +148,23 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
vec4_t<scalar_t> in = vec_input[i];
vec4_t<float> x;
- x.x = static_cast<float>(in.x);
- x.y = static_cast<float>(in.y);
- x.z = static_cast<float>(in.z);
- x.w = static_cast<float>(in.w);
- if constexpr (has_residual) {
- vec4_t<scalar_t> r = vec_residual[i];
- x.x += static_cast<float>(r.x);
- x.y += static_cast<float>(r.y);
- x.z += static_cast<float>(r.z);
- x.w += static_cast<float>(r.w);
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ x.val[j] = static_cast<float>(in.val[j]);
+ }
- ss += x.x * x.x;
- ss += x.y * x.y;
- ss += x.z * x.z;
- ss += x.w * x.w;
+ if constexpr (has_residual) {
+ vec4_t<scalar_t> r = vec_residual[i];
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ x.val[j] += static_cast<float>(r.val[j]);
+ }
+ }
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ ss += x.val[j] * x.val[j];
+ }
}
using BlockReduce = cub::BlockReduce<float, 1024>;
@ -203,6 +205,7 @@ __device__ void compute_dynamic_per_token_scales(
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
+ const int VEC_SIZE = 4;
int32_t const num_vec_elems = hidden_size >> 2;
float block_absmax_val_maybe = 0.0f;
@ -212,26 +215,25 @@ __device__ void compute_dynamic_per_token_scales(
vec4_t<scalar_t> const w = vec_weight[i];
vec4_t<float> x;
- x.x = static_cast<float>(in.x);
- x.y = static_cast<float>(in.y);
- x.z = static_cast<float>(in.z);
- x.w = static_cast<float>(in.w);
- if constexpr (has_residual) {
- vec4_t<scalar_t> r = vec_residual[i];
- x.x += static_cast<float>(r.x);
- x.y += static_cast<float>(r.y);
- x.z += static_cast<float>(r.z);
- x.w += static_cast<float>(r.w);
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ x.val[j] = static_cast<float>(in.val[j]);
+ }
- block_absmax_val_maybe = fmaxf(
- block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
- block_absmax_val_maybe = fmaxf(
- block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
- block_absmax_val_maybe = fmaxf(
- block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
- block_absmax_val_maybe = fmaxf(
- block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
+ if constexpr (has_residual) {
+ vec4_t<scalar_t> r = vec_residual[i];
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ x.val[j] += static_cast<float>(r.val[j]);
+ }
+ }
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ block_absmax_val_maybe =
+ fmaxf(block_absmax_val_maybe,
+ fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
+ }
}
using BlockReduce = cub::BlockReduce<float, 1024>;
@ -282,6 +284,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
}
+ const int VEC_SIZE = 4;
int32_t const num_vec_elems = hidden_size >> 2;
// TODO(luka/varun) extract into type-agnostic vectorized quant function to
@ -292,33 +295,31 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
vec4_t<scalar_t> const w = vec_weight[i];
vec4_t<float> x;
- x.x = static_cast<float>(in.x);
- x.y = static_cast<float>(in.y);
- x.z = static_cast<float>(in.z);
- x.w = static_cast<float>(in.w);
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ x.val[j] = static_cast<float>(in.val[j]);
+ }
if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
- x.x += static_cast<float>(r.x);
- x.y += static_cast<float>(r.y);
- x.z += static_cast<float>(r.z);
- x.w += static_cast<float>(r.w);
- // Update residual
- r.x = static_cast<scalar_t>(x.x);
- r.y = static_cast<scalar_t>(x.y);
- r.z = static_cast<scalar_t>(x.z);
- r.w = static_cast<scalar_t>(x.w);
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ x.val[j] += static_cast<float>(r.val[j]);
+ }
+ // Update residual
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ r.val[j] = static_cast<scalar_t>(x.val[j]);
+ }
vec_residual[i] = r;
}
q8x4_t<scalar_out_t> out;
- out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
- static_cast<scalar_t>(x.x * rms) * w.x, scale);
- out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
- static_cast<scalar_t>(x.y * rms) * w.y, scale);
- out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
- static_cast<scalar_t>(x.z * rms) * w.z, scale);
- out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
- static_cast<scalar_t>(x.w * rms) * w.w, scale);
+ #pragma unroll
+ for (int j = 0; j < VEC_SIZE; ++j) {
+ out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
+ static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
+ }
vec_output[i] = out;
}
}
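
The rewrites above all follow the same idiom: instead of spelling out .x/.y/.z/.w, the lanes sit in the val[] array of the vec_n_t container (last file of this commit), and a fully unrolled loop over a compile-time VEC_SIZE touches each lane, which the compiler flattens back into straight-line code. A small illustrative device helper, assuming vllm::vec4_t from that header is in scope:

// Sum of squares over one 4-wide vector, written with the lane loop.
template <typename scalar_t>
__device__ float sum_of_squares_vec4(vllm::vec4_t<scalar_t> const& in) {
  constexpr int VEC_SIZE = 4;
  float ss = 0.0f;
#pragma unroll
  for (int j = 0; j < VEC_SIZE; ++j) {
    float const x = static_cast<float>(in.val[j]);
    ss += x * x;  // same per-lane accumulation compute_rms performs
  }
  return ss;
}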

View File

@ -10,23 +10,22 @@
namespace vllm {
// Vectorization containers
- template <typename scalar_t>
- struct __align__(8) vec4_t {
- scalar_t x;
- scalar_t y;
- scalar_t z;
- scalar_t w;
+ template <typename scalar_t, size_t vec_size>
+ struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
+ scalar_t val[vec_size];
};
- template <typename quant_type_t>
- struct __align__(4) q8x4_t {
+ template <typename quant_type_t, size_t vec_size>
+ struct __align__(vec_size * sizeof(quant_type_t)) q8_n_t {
static_assert(std::is_same_v<quant_type_t, int8_t> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
- quant_type_t x;
- quant_type_t y;
- quant_type_t z;
- quant_type_t w;
+ quant_type_t val[vec_size];
};
+ template <typename scalar_t>
+ using vec4_t = vec_n_t<scalar_t, 4>;
+ template <typename quant_type_t>
+ using q8x4_t = q8_n_t<quant_type_t, 4>;
} // namespace vllm
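
A short usage sketch for the generalized containers: the 16-wide instantiations used by the fp8 kernels give a 32-byte input vector for 2-byte scalars and a 16-byte output vector for fp8, exactly the alignment the updated comment in the quant kernel asks for. The includes and the plain divide-by-scale conversion are illustrative only; the real kernels go through scaled_fp8_conversion.

#include <cuda_fp16.h>
#include <c10/util/Float8_e4m3fn.h>

using InVec = vllm::vec_n_t<__half, 16>;              // 32 bytes, __align__(32)
using OutVec = vllm::q8_n_t<c10::Float8_e4m3fn, 16>;  // 16 bytes, __align__(16)

static_assert(sizeof(InVec) == 32 && alignof(InVec) == 32);
static_assert(sizeof(OutVec) == 16 && alignof(OutVec) == 16);

// Quantize one 16-element vector; in/out must really be 32/16-byte aligned,
// which the hidden_size % 16 == 0 check above is there to guarantee row-wise
// (given an aligned base pointer).
__device__ void quantize_one_vec(c10::Float8_e4m3fn* __restrict__ out,
                                 __half const* __restrict__ in,
                                 float const scale) {
  InVec const v = *reinterpret_cast<InVec const*>(in);
  OutVec q;
#pragma unroll
  for (int j = 0; j < 16; ++j) {
    q.val[j] = static_cast<c10::Float8_e4m3fn>(static_cast<float>(v.val[j]) / scale);
  }
  *reinterpret_cast<OutVec*>(out) = q;
}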