[Performance] Fused blockwise quant RMS norm (#27883)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Authored by ElizaWszola on 2025-12-07 17:38:04 +01:00, committed by GitHub
parent 0044c4038c
commit af0444bf40
14 changed files with 949 additions and 157 deletions


@ -14,6 +14,9 @@ from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
@dataclass
@ -22,6 +25,7 @@ class bench_params_t:
hidden_size: int
add_residual: bool
dtype: torch.dtype
group_size: list[int]
def description(self):
return (
@ -29,6 +33,7 @@ class bench_params_t:
f"x D {self.hidden_size} "
f"x R {self.add_residual} "
f"x DT {self.dtype}"
f"x GS {self.group_size}"
)
@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]:
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]
GROUP_SIZES = [[1, 64], [1, 128]]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations)
)
return bench_params
@ -52,6 +58,7 @@ def unfused_int8_impl(
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
@ -69,6 +76,7 @@ def unfused_fp8_impl(
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
@ -81,23 +89,63 @@ def unfused_fp8_impl(
torch_out, _ = ops.scaled_fp8_quant(torch_out)
def unfused_groupwise_fp8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
# Quant
torch_out, _ = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False
)
def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
out, _ = ops.rms_norm_dynamic_per_token_quant(
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
)
def fused_groupwise_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
out, _ = ops.rms_norm_per_block_quant(
x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
group_size,
residual=residual,
is_scale_transposed=True,
)
# Bench functions
def bench_fn(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor,
quant_dtype: torch.dtype,
group_size: list[int],
label: str,
sub_label: str,
fn: Callable,
@ -110,10 +158,11 @@ def bench_fn(
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
unfused_int8_impl,
@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
fused_impl,
@ -189,6 +241,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
@ -196,6 +249,36 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
)
)
# unfused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)
# fused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_groupwise_impl,
"fused_groupwise_fp8_impl",
)
)
print_timers(timers)
return timers
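
For reference, a minimal standalone sketch of the two paths being compared above, outside the benchmark harness: the unfused path (RMSNorm forward followed by groupwise FP8 quantization) versus the new fused op. Shapes, dtype and eps are illustrative, and a CUDA build of vLLM with the custom ops compiled is assumed.

import torch
import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
)

num_tokens, hidden_size, group_size = 16, 2048, [1, 128]
layer = RMSNorm(hidden_size, 1e-6).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)

# Unfused: two kernels (norm, then per-token-group FP8 quant).
ref = layer.forward_cuda(x, None)
ref_q, ref_s = per_token_group_quant_fp8(ref, group_size=group_size[1], use_ue8m0=False)

# Fused op added by this PR: one kernel produces the quantized output and the
# per-group scales (stored column-major when is_scale_transposed=True).
out_q, out_s = ops.rms_norm_per_block_quant(
    x, layer.weight, 1e-6, torch.float8_e4m3fn, group_size,
    residual=None, is_scale_transposed=True,
)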


@ -118,6 +118,24 @@
} \
}
#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \
constexpr bool const_expr = true; \
__VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
__VA_ARGS__(); \
}
#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
if (group_size == 128) { \
constexpr int const_group_size = 128; \
__VA_ARGS__(); \
} else if (group_size == 64) { \
constexpr int const_group_size = 64; \
__VA_ARGS__(); \
}
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
switch (NUM_DIMS) { \
case 2: { \

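Loosely, the two macros added above turn a runtime value into a compile-time constant by branching once and instantiating the wrapped lambda for each supported value. A rough Python analogy of the same dispatch pattern (all names below are hypothetical, for illustration only; the real VLLM_DISPATCH_GROUP_SIZE simply falls through for unsupported group sizes and relies on the caller's TORCH_CHECK):

def make_specialization(const_group_size: int):
    # Stand-in for a kernel template instantiated with group_size fixed at
    # "compile time" (here: closure-capture time).
    def kernel(values):
        return [values[i:i + const_group_size]
                for i in range(0, len(values), const_group_size)]
    return kernel

_SPECIALIZED = {gs: make_specialization(gs) for gs in (64, 128)}

def dispatch_group_size(group_size: int, values):
    # Mirrors VLLM_DISPATCH_GROUP_SIZE: only 64 and 128 are supported.
    if group_size not in _SPECIALIZED:
        raise ValueError(f"unsupported group size: {group_size}")
    return _SPECIALIZED[group_size](values)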

@ -128,6 +128,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales, double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);


@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale;
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
} else {
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
}
}
@ -75,14 +76,52 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale;
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
} else {
// FP8 - Do not invert s_token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
}
}
// RMS norm + quant kernel
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
bool is_scale_transposed = false, int32_t group_size = 0>
__global__ void rms_norm_per_block_quant_kernel(
scalar_out_t* __restrict__ out, // [..., hidden_size]
float* __restrict__ scales, // [num_tokens, hidden_size / group_size]
// or
// [hidden_size / group_size, num_tokens]
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
float rms;
// Compute RMS
// Always able to vectorize due to constraints on hidden_size
vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual);
// Compute Scale
// Always able to vectorize due to constraints on hidden_size and group_size
vllm::vectorized::compute_dynamic_per_token_scales<
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual);
// RMS Norm + Quant
// Always able to vectorize due to constraints on hidden_size
// For int8, don't invert token_scale here: do it inside norm_and_quant.
// Individual elements of token_scale can be shared between multiple
// threads, so inverting there avoids extra synchronization overhead.
vllm::vectorized::norm_and_quant<
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
has_residual, is_scale_transposed, group_size>(
out, input, weight, rms, scales, hidden_size, residual);
}
} // namespace vllm
// Residual add + RMS norm + dynamic per token
@ -103,30 +142,19 @@ void rms_norm_dynamic_per_token_quant_dispatch(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (residual.has_value()) {
VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
true>
has_residual>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
var_epsilon, hidden_size,
has_residual ? residual->data_ptr<scalar_in_t>() : nullptr);
});
} else {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
false>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, nullptr);
});
}
});
}
void rms_norm_dynamic_per_token_quant(
@ -157,3 +185,79 @@ void rms_norm_dynamic_per_token_quant(
out, input, weight, scales, var_epsilon, scale_ub, residual);
});
}
// Residual add + RMS norm + per-block (groupwise) quant
void rms_norm_per_block_quant_dispatch(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& weight, // [hidden_size]
torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or
// [hidden_size / group_size, num_tokens]
int32_t group_size,
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual, bool is_scale_transposed) {
int32_t hidden_size = input.size(-1);
auto num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
const int max_block_size = (num_tokens <= 256) ? 512 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] {
using scalar_in_t = scalar_t;
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
vllm::rms_norm_per_block_quant_kernel<scalar_in_t, scalar_t,
has_residual,
transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(),
weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
var_epsilon, hidden_size,
has_residual ? residual->data_ptr<scalar_in_t>()
: nullptr);
});
});
});
});
});
}
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales, double const var_epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}
TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
var_epsilon, scale_ub, residual,
is_scale_transposed);
}
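
To make the launch configuration and scale layout above concrete, a small arithmetic sketch (plain Python, sizes illustrative) of what rms_norm_per_block_quant_dispatch computes before launching the kernel:

def launch_config(num_tokens: int, hidden_size: int, group_size: int,
                  is_scale_transposed: bool) -> dict:
    # One block per token; threads capped at 512 for small batches, 256 otherwise.
    max_block_size = 512 if num_tokens <= 256 else 256
    block = min(hidden_size, max_block_size)
    num_groups = hidden_size // group_size
    scales_shape = ((num_groups, num_tokens) if is_scale_transposed
                    else (num_tokens, num_groups))
    return {"grid": num_tokens, "block": block, "scales_shape": scales_shape}

print(launch_config(64, 4096, 128, is_scale_transposed=True))
# {'grid': 64, 'block': 512, 'scales_shape': (32, 64)}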


@ -9,6 +9,7 @@
#include "quant_conversions.cuh"
#include "../../cub_helpers.h"
#include "../../cuda_compat.h"
namespace vllm {
@ -43,62 +44,150 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
*rms = s_rms;
}
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid,
int64_t thread_in_warp,
int64_t reduced_elems) {
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64);
if constexpr (WARP_SIZE == 64) {
if (thread_in_warp + 64 < reduced_elems)
val[tid] = fmaxf(val[tid], val[tid + 64]);
}
if (thread_in_warp + 32 < reduced_elems)
val[tid] = fmaxf(val[tid], val[tid + 32]);
if (thread_in_warp + 16 < reduced_elems)
val[tid] = fmaxf(val[tid], val[tid + 16]);
if (thread_in_warp + 8 < reduced_elems)
val[tid] = fmaxf(val[tid], val[tid + 8]);
if (thread_in_warp + 4 < reduced_elems)
val[tid] = fmaxf(val[tid], val[tid + 4]);
if (thread_in_warp + 2 < reduced_elems)
val[tid] = fmaxf(val[tid], val[tid + 2]);
if (thread_in_warp + 1 < reduced_elems)
val[tid] = fmaxf(val[tid], val[tid + 1]);
return val[tid];
}
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
bool is_scale_transposed = false>
__device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
;
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
int32_t const group_size = 0) {
float block_absmax_val_maybe = 0.0f;
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]);
if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]);
}
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
__shared__ float s_token_scale;
if (threadIdx.x == 0) {
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
s_token_scale = scale; // Shared memory store
all_token_scales[blockIdx.x] = scale; // Global output store
}
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
__syncthreads();
if (group_size > 0) {
__shared__ float s_max_vals[1024];
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
int64_t num_groups = hidden_size / group_size;
int64_t const threads_per_group = blockDim.x / num_groups;
int64_t const thread_in_group = threadIdx.x % threads_per_group;
int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
int64_t const thread_offset = group_offset + thread_in_group;
int64_t const thread_end =
min(group_offset + group_size, static_cast<int64_t>(hidden_size));
for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
float x = static_cast<float>(input[token_offset + i]);
if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]);
}
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
}
s_max_vals[threadIdx.x] = block_absmax_val_maybe;
__syncthreads();
*token_scale = s_token_scale;
int64_t const warp_size = WARP_SIZE;
int64_t const num_warps = blockDim.x / warp_size;
int64_t const warp_id = threadIdx.x / warp_size;
int64_t const thread_in_warp = threadIdx.x % warp_size;
int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
for (auto i = 0; i < groups_per_warp; ++i) {
int64_t const group_id = i * num_warps + warp_id;
if (group_id < num_groups) {
int64_t warp_start = group_id * threads_per_group;
int64_t const start = warp_start + thread_in_warp;
int64_t const warp_end = min(warp_start + threads_per_group,
static_cast<int64_t>(hidden_size));
for (auto j = start; j + warp_size < warp_end; j += warp_size) {
s_max_vals[start] =
fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
}
warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
min(warp_end - warp_start, warp_size));
}
}
__syncthreads();
if (thread_in_group == 0 && thread_offset < thread_end) {
block_absmax_val_maybe = s_max_vals[threadIdx.x];
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store
if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
blockIdx.x] = scale;
} else {
all_token_scales[blockIdx.x * num_groups +
threadIdx.x / threads_per_group] = scale;
}
}
__syncthreads();
} else {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]);
if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]);
}
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
__shared__ float s_token_scale;
if (threadIdx.x == 0) {
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
s_token_scale = scale; // Shared memory store
all_token_scales[blockIdx.x] = scale; // Global output store
}
__syncthreads();
*token_scale = s_token_scale;
}
}
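
As a reference for the groupwise branch above, what the kernel computes for a single token can be written in a few lines of eager PyTorch. The warp-level shared-memory reduction is an optimized way of taking these per-group maxima; the scale_ub clamp is omitted, and min_scale below merely stands in for the kernel's min_scaling_factor.

import torch

def per_group_scales_reference(x_norm: torch.Tensor, group_size: int,
                               qmax: float = 448.0,       # FP8 e4m3 max
                               min_scale: float = 1e-10):  # placeholder clamp
    # x_norm: normalized-and-weighted values of one token, shape [hidden_size].
    absmax = x_norm.abs().view(-1, group_size).max(dim=-1).values  # [num_groups]
    return torch.clamp(absmax / qmax, min=min_scale)               # one scale per group

print(per_group_scales_reference(torch.randn(1024), 128).shape)  # torch.Size([8])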
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false>
bool has_residual = false, bool is_scale_transposed = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight,
float const rms, float const scale,
float const rms, float* const scale,
int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
scalar_t* __restrict__ residual = nullptr,
int32_t const group_size = 0) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
;
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]);
@ -109,8 +198,21 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
// Norm
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
// Quant
// For groupwise quant the scale is not pre-inverted, so when
// is_scale_inverted is true we invert it here.
int64_t scale_idx = 0;
if (group_size > 0) {
if constexpr (is_scale_transposed) {
scale_idx = (i / group_size) * gridDim.x + blockIdx.x;
} else {
scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
}
}
auto scale_val =
(group_size > 0
? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx])
: *scale);
output[token_offset + i] =
ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale);
ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale_val);
}
}
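
The only part that differs between the row-major and transposed scale layouts is the index computed above; a small sketch of that indexing (plain Python, function name hypothetical; gridDim.x corresponds to num_tokens because the kernel launches one block per token):

def scale_index(token_idx: int, i: int, hidden_size: int, group_size: int,
                num_tokens: int, is_scale_transposed: bool) -> int:
    group = i // group_size
    if is_scale_transposed:
        # scales laid out [num_groups, num_tokens]
        return group * num_tokens + token_idx
    # scales laid out [num_tokens, num_groups]
    return token_idx * (hidden_size // group_size) + group

# hidden_size 1024, group_size 128 -> 8 groups; 16 tokens; element 300 of token 3:
print(scale_index(3, 300, 1024, 128, 16, False))  # 3 * 8 + 2 = 26
print(scale_index(3, 300, 1024, 128, 16, True))   # 2 * 16 + 3 = 35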
@ -178,95 +280,191 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
// Vectorized version of vllm::compute_dynamic_per_token_scales
// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
bool is_scale_transposed = false, int32_t group_size = 0>
__device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
;
// Vectorized input/weight/residual to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
vec4_t<scalar_t> const* vec_weight =
reinterpret_cast<vec4_t<scalar_t> const*>(weight);
vec4_t<scalar_t> const* vec_residual = nullptr;
if constexpr (has_residual) {
vec_residual =
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
}
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
const int VEC_SIZE = 4;
int32_t const num_vec_elems = hidden_size >> 2;
float block_absmax_val_maybe = 0.0f;
#pragma unroll 4
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
vec4_t<scalar_t> in = vec_input[i];
vec4_t<scalar_t> const w = vec_weight[i];
// Vectorized input/weight/residual to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input = nullptr;
vec4_t<scalar_t> const* vec_weight = nullptr;
vec4_t<scalar_t> const* vec_residual = nullptr;
vec4_t<float> x;
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] = static_cast<float>(in.val[j]);
}
if constexpr (group_size > 0) {
__shared__ float s_max_vals[1024];
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
int64_t const num_groups = hidden_size / group_size;
int64_t const threads_per_group = blockDim.x / num_groups;
int64_t const thread_in_group = threadIdx.x % threads_per_group;
int64_t const group_offset =
threadIdx.x / threads_per_group * (group_size >> 2);
int64_t const thread_offset = group_offset + thread_in_group;
int64_t const thread_end = min(group_offset + (group_size >> 2),
static_cast<int64_t>(hidden_size >> 2));
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
vec_residual =
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
}
int32_t const num_vec_elems = thread_end;
#pragma unroll 4
for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) {
vec4_t<scalar_t> in = vec_input[i];
vec4_t<scalar_t> const w = vec_weight[i];
vec4_t<float> x;
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] += static_cast<float>(r.val[j]);
x.val[j] = static_cast<float>(in.val[j]);
}
if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] += static_cast<float>(r.val[j]);
}
}
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
block_absmax_val_maybe =
fmaxf(block_absmax_val_maybe,
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
}
}
s_max_vals[threadIdx.x] = block_absmax_val_maybe;
__syncthreads();
int64_t const warp_size = WARP_SIZE;
int64_t const num_warps = blockDim.x / warp_size;
int64_t const warp_id = threadIdx.x / warp_size;
int64_t const thread_in_warp = threadIdx.x % warp_size;
int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
for (auto i = 0; i < groups_per_warp; ++i) {
int64_t const group_id = i * num_warps + warp_id;
if (group_id < num_groups) {
int64_t warp_start = group_id * threads_per_group;
int64_t const start = warp_start + thread_in_warp;
int64_t const warp_end = min(warp_start + threads_per_group,
static_cast<int64_t>(hidden_size));
for (auto j = start; j + warp_size < warp_end; j += warp_size) {
s_max_vals[start] =
fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
}
warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
min(warp_end - warp_start, warp_size));
}
}
__syncthreads();
if (thread_in_group == 0 && thread_offset < thread_end) {
block_absmax_val_maybe = s_max_vals[threadIdx.x];
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store
if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
blockIdx.x] = scale;
} else {
all_token_scales[blockIdx.x * num_groups +
threadIdx.x / threads_per_group] = scale;
}
}
__syncthreads();
} else {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
if constexpr (has_residual) {
vec_residual =
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
}
int32_t const num_vec_elems = (hidden_size >> 2);
#pragma unroll 4
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
vec4_t<scalar_t> in = vec_input[i];
vec4_t<scalar_t> const w = vec_weight[i];
vec4_t<float> x;
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
block_absmax_val_maybe =
fmaxf(block_absmax_val_maybe,
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] = static_cast<float>(in.val[j]);
}
if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
x.val[j] += static_cast<float>(r.val[j]);
}
}
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
block_absmax_val_maybe =
fmaxf(block_absmax_val_maybe,
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
}
}
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
__shared__ float s_token_scale;
if (threadIdx.x == 0) {
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
__shared__ float s_token_scale;
if (threadIdx.x == 0) {
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
s_token_scale = scale; // shared memory store
all_token_scales[blockIdx.x] = scale; // global output store
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
s_token_scale = scale; // shared memory store
all_token_scales[blockIdx.x] = scale; // global output store
}
__syncthreads();
__syncthreads();
*token_scale = s_token_scale;
*token_scale = s_token_scale;
}
}
// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false>
bool has_residual = false, bool is_scale_transposed = false,
int32_t group_size = 0>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight,
float const rms, float const scale,
float const rms, float* const scale,
int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
;
// Vectorized input/output/weight/residual to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input =
@ -311,10 +509,26 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
}
q8x4_t<scalar_out_t> out;
float scale_val;
if constexpr (group_size > 0) {
int64_t const num_groups = hidden_size / group_size;
int64_t scale_idx = 0;
if constexpr (is_scale_transposed) {
scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x;
} else {
scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
}
scale_val =
is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx];
} else {
scale_val = *scale;
}
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale_val);
}
vec_output[i] = out;
}


@ -215,6 +215,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
&rms_norm_dynamic_per_token_quant);
// Fused Layernorm + Block quant kernels
ops.def(
"rms_norm_per_block_quant(Tensor! result, Tensor input, "
"Tensor weight, Tensor! scale, float epsilon, "
"Tensor? scale_ub, Tensor!? residual, int group_size, "
"bool is_scale_transposed) -> ()");
ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(

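For completeness, a hedged sketch of calling the op registered above directly through torch.ops, with the caller pre-allocating the mutable result and scale tensors as the schema requires (shapes illustrative; in practice the vllm._custom_ops wrapper further down in this diff does this allocation). A CUDA build with the extension compiled is assumed.

import torch

num_tokens, hidden_size, group_size = 4, 1024, 128
x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
weight = torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)

result = torch.empty_like(x, dtype=torch.float8_e4m3fn)      # Tensor! result
scale = torch.empty(num_tokens, hidden_size // group_size,   # Tensor! scale
                    device="cuda", dtype=torch.float32)
torch.ops._C.rms_norm_per_block_quant(
    result, x, weight, scale,
    1e-6,        # float epsilon
    None,        # Tensor? scale_ub
    None,        # Tensor!? residual
    group_size,  # int group_size (64 or 128)
    False,       # bool is_scale_transposed
)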

@ -18,6 +18,9 @@ from vllm.config import (
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
@ -25,10 +28,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from ..utils import override_cutlass_fp8_supported
from .backend import TestBackend
@ -44,7 +49,7 @@ class TestModel(torch.nn.Module):
self,
hidden_size: int,
eps: float,
static: bool,
group_shape: GroupShape,
cuda_force_torch: bool,
*args,
**kwargs,
@ -52,8 +57,17 @@ class TestModel(torch.nn.Module):
super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
if group_shape.is_per_group():
self.wscale = [
torch.rand(
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
dtype=torch.float32,
)
for _ in range(3)
]
else:
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
static = group_shape == GroupShape.PER_TENSOR
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static:
@ -61,18 +75,29 @@ class TestModel(torch.nn.Module):
else:
self.scale = [None for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
]
if not group_shape.is_per_group():
self.w = [self.w[0].t() for _ in range(3)]
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=static,
if group_shape.is_per_group():
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
)
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
else:
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=static,
act_quant_group_shape=group_shape,
)
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.group_shape = group_shape
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
@ -119,11 +144,19 @@ class TestModel(torch.nn.Module):
)
GROUP_SHAPES = [
GroupShape.PER_TOKEN,
GroupShape.PER_TENSOR,
GroupShape(1, 128),
GroupShape(1, 64),
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
@ -139,7 +172,7 @@ def test_fusion_rmsnorm_quant(
hidden_size,
num_tokens,
eps,
static,
group_shape,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
@ -149,6 +182,15 @@ def test_fusion_rmsnorm_quant(
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
# Skip test for group shape (1, 64) when running with CUTLASS or DeepGEMM
if group_shape == GroupShape(1, 64) and (
cutlass_block_fp8_supported() or is_deep_gemm_supported()
):
pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
@ -172,8 +214,7 @@ def test_fusion_rmsnorm_quant(
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)
model = TestModel(hidden_size, eps, group_shape, cuda_force_torch)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)

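The GROUP_SHAPES parametrization above is what routes the test either through the existing Fp8LinearOp path or through W8A8BlockFp8LinearOp; a quick sketch of how the shapes differ (only the per-group ones take the new blockwise path):

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

for shape in (GroupShape.PER_TENSOR, GroupShape.PER_TOKEN,
              GroupShape(1, 128), GroupShape(1, 64)):
    # The per-group shapes are the ones TestModel routes to W8A8BlockFp8LinearOp.
    print(shape, shape.is_per_group())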

@ -8,6 +8,12 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8,
)
DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
GROUP_SIZES = [None, [1, 64], [1, 128]]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@ -45,12 +52,13 @@ def ref_rms_norm(
return out, residual
def ref_dynamic_per_token_quant(
def ref_dynamic_per_token_or_block_quant(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
@ -59,13 +67,24 @@ def ref_dynamic_per_token_quant(
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
# Quant
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
if group_size is not None:
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False
)
else:
assert quant_dtype == torch.int8
torch_out, scales = per_token_group_quant_int8(
torch_out, group_size=group_size[1]
)
else:
assert quant_dtype == torch.int8
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
else:
assert quant_dtype == torch.int8
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
return torch_out, scales, residual
@ -76,24 +95,32 @@ def ref_impl(
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ref_dynamic_per_token_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub
return ref_dynamic_per_token_or_block_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size
)
def ops_dynamic_per_token_quant(
def ops_dynamic_per_token_or_block_quant(
weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if residual is not None:
residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(
x, weight, EPS, quant_dtype, scale_ub, residual
)
if group_size is not None:
out, scales = ops.rms_norm_per_block_quant(
x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True
)
scales = scales.contiguous()
else:
out, scales = ops.rms_norm_dynamic_per_token_quant(
x, weight, EPS, quant_dtype, scale_ub, residual
)
return out, scales, residual
@ -103,8 +130,11 @@ def ops_impl(
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
return ops_dynamic_per_token_or_block_quant(
weight, x, quant_dtype, residual, scale_ub, group_size
)
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@ -112,6 +142,7 @@ def ops_impl(
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
@ -122,6 +153,7 @@ def test_rms_norm(
has_scale_ub: bool,
dtype: torch.dtype,
quant_dtype: torch.dtype,
group_size: list[int] | None,
seed: int,
device: str,
) -> None:
@ -130,6 +162,14 @@ def test_rms_norm(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
if group_size is not None and hidden_size % group_size[1] != 0:
# skip: hidden_size must be divisible by the group size
return
if group_size is not None and has_scale_ub:
# blockwise baseline doesn't support scale_ub
return
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
# skip
return
@ -150,10 +190,10 @@ def test_rms_norm(
scale_ub = None
ref_out, ref_scales, ref_residual = ref_impl(
layer, x, quant_dtype, residual, scale_ub
layer, x, quant_dtype, residual, scale_ub, group_size
)
ops_out, ops_scales, ops_residual = ops_impl(
layer.weight, x, quant_dtype, residual, scale_ub
layer.weight, x, quant_dtype, residual, scale_ub, group_size
)
assert ref_out.dtype == quant_dtype
@ -166,11 +206,15 @@ def test_rms_norm(
assert torch.allclose(ref_scales, ops_scales)
a = ref_out.to(dtype=torch.float32)
b = ops_out.to(dtype=torch.float32)
ok = torch.allclose(a, b)
ok = torch.allclose(a, b, atol=1e-6)
if not ok:
# fallback: compare dequantized values with relaxed tolerance
a_deq = a * ref_scales.view(-1, 1)
b_deq = b * ops_scales.view(-1, 1)
if group_size is None:
a_deq = a * ref_scales.view(-1, 1)
b_deq = b * ops_scales.view(-1, 1)
else:
a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1)
b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1)
# NOTE: It is possible that some future test cases trigger this
# max diff due to precision issues. If such an error is
# encountered, it's recommended to inspect the differences between

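The group-wise dequantization check above relies on repeat_interleave to broadcast one scale per group over its group_size columns; a short shape-bookkeeping sketch (sizes illustrative):

import torch

num_tokens, hidden_size, group_size = 4, 256, 64
q = torch.randint(-8, 8, (num_tokens, hidden_size)).float()
scales = torch.rand(num_tokens, hidden_size // group_size)   # [4, 4], one per group

per_element = scales.repeat_interleave(group_size, dim=1)    # [4, 256]
dequant = q * per_element
assert dequant.shape == q.shape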

@ -436,6 +436,46 @@ def rms_norm_dynamic_per_token_quant(
return output, scales
# fused layer norm + blockwise quant op
def rms_norm_per_block_quant(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
group_size: list[int],
scale_ub: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
is_scale_transposed: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert len(group_size) == 2
output = torch.empty_like(input, dtype=quant_dtype)
if is_scale_transposed:
scales = torch.empty(
(input.shape[-1] // group_size[1], input.numel() // input.shape[-1]),
device=input.device,
dtype=torch.float32,
).transpose(0, 1)
else:
scales = torch.empty(
(input.numel() // input.shape[-1], input.shape[-1] // group_size[1]),
device=input.device,
dtype=torch.float32,
)
torch.ops._C.rms_norm_per_block_quant(
output,
input,
weight,
scales,
epsilon,
scale_ub,
residual,
group_size[1],
is_scale_transposed,
)
return output, scales
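
Note the allocation trick for the transposed case above: the scales are allocated as [num_groups, num_tokens] and then viewed through transpose(0, 1), so callers see the usual [num_tokens, num_groups] shape while the underlying storage stays column-major, which is the scale layout the fusion pass targets when DeepGEMM or CUTLASS block FP8 is in use. A small sketch of the resulting strides (sizes illustrative):

import torch

num_tokens, num_groups = 8, 32
scales = torch.empty(num_groups, num_tokens, dtype=torch.float32).transpose(0, 1)
print(scales.shape)            # torch.Size([8, 32]) -- logical [num_tokens, num_groups]
print(scales.stride())         # (1, 8)              -- column-major storage
print(scales.is_contiguous())  # False; .contiguous() (as in the kernel test) copies to row-major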
# quantization ops
# awq
def awq_dequantize(

View File

@ -15,13 +15,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear_for_nk,
)
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
@ -58,6 +67,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
@ -90,6 +102,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(
kFp8DynamicTokenSym, True
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
}
@ -100,6 +124,15 @@ class RMSNormQuantPattern:
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
# groupwise FP8 linear uses col-major scales if DeepGEMM or CUTLASS is used
using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
self.model_dtype,
config.model_config.hf_config.intermediate_size,
config.model_config.hf_config.hidden_size,
)
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
@ -108,7 +141,9 @@ class RMSNormQuantPattern:
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(key.quant)
self.quant_matcher = MatcherQuantFP8(
key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
@ -218,6 +253,120 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
)
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
)
# result, residual, scale
return at[1], at[3], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(input: torch.Tensor, weight: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
)
# result, scale
return at[1], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
self,
@ -340,6 +489,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant (group size 128)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant (group size 128)
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse fused_add_rms_norm + fp8 group quant (group size 64)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant (group size 64)
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
@ -366,9 +534,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
def uuid(self) -> Any:
return self.hash_source(
self,
RMSNormGroupQuantPattern,
RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern,
FusedAddRMSNormGroupQuantPattern,
)

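The new pattern classes above key their QUANT_OPS and FUSED_OPS lookups on the groupwise QuantKeys added in quant_utils; a short sketch of how such a key is built (the final equality holds on platforms whose FP8 dtype is float8_e4m3fn):

import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, QuantKey, ScaleDesc, kFp8Dynamic128Sym,
)

# "dynamic, symmetric FP8 with 1x128 group scales" -- the key used to pick
# torch.ops._C.per_token_group_fp8_quant on the pattern side and
# torch.ops._C.rms_norm_per_block_quant on the replacement side.
scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
key = QuantKey(dtype=torch.float8_e4m3fn, scale=scale, symmetric=True)
print(key == kFp8Dynamic128Sym)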

@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
_normalize_quant_group_shape,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
@ -35,6 +37,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
@ -224,12 +230,20 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
class MatcherQuantFP8(MatcherCustomOp):
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
def __init__(
self,
quant_key: QuantKey,
enabled: bool | None = None,
use_col_major_scales: bool = False,
use_e8m0: bool = False,
):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
self.use_col_major_scales = use_col_major_scales
self.use_e8m0 = use_e8m0
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
@ -248,6 +262,27 @@ class MatcherQuantFP8(MatcherCustomOp):
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.group_shape.is_per_group():
assert scale is None
scale = self.make_scale(input, transposed=self.use_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.QUANT_OP,
input=input,
output_q=result,
output_s=scale,
group_size=self.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.use_e8m0,
)
return result, scale
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
@ -269,7 +304,7 @@ class MatcherQuantFP8(MatcherCustomOp):
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
def make_scale(self, input: torch.Tensor):
def make_scale(self, input: torch.Tensor, transposed: bool = False):
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
@ -277,6 +312,11 @@ class MatcherQuantFP8(MatcherCustomOp):
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
if transposed:
scale_shape = tuple(reversed(scale_shape))
return torch.empty(
scale_shape, device=input.device, dtype=torch.float32
).permute(-1, -2)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)


@ -733,7 +733,7 @@ def per_token_group_quant_fp8(
assert out_q is None or out_q.shape == x.shape
x_q = out_q
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty(x.shape, device=x.device, dtype=dtype)
# Allocate the scale tensor in either row- or column-major format.
if column_major_scales:

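The one-line change above (torch.empty_like to torch.empty) matters because empty_like preserves the input's strides for dense non-contiguous inputs, whereas torch.empty(x.shape, ...) always allocates row-major; presumably the quantized output should be contiguous regardless of how x is laid out. A short sketch of the difference:

import torch

x = torch.randn(128, 16).t()                   # dense but transposed, shape (16, 128)
a = torch.empty_like(x, dtype=torch.float32)   # preserves strides
b = torch.empty(x.shape, dtype=torch.float32)  # always row-major
print(a.stride(), a.is_contiguous())           # (1, 16) False
print(b.stride(), b.is_contiguous())           # (128, 1) True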

@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)
kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):


@ -381,6 +381,22 @@ def should_use_deepgemm_for_fp8_linear(
)
def should_use_deepgemm_for_fp8_linear_for_nk(
output_dtype: torch.dtype,
shape0: int,
shape1: int,
supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (
supports_deep_gemm
and output_dtype == torch.bfloat16
and shape0 % 128 == 0
and shape1 % 128 == 0
)
__all__ = [
"calc_diff",
"fp8_gemm_nt",
@ -394,6 +410,7 @@ __all__ = [
"is_deep_gemm_supported",
"get_num_sms",
"should_use_deepgemm_for_fp8_linear",
"should_use_deepgemm_for_fp8_linear_for_nk",
"get_col_major_tma_aligned_tensor",
"get_mk_alignment_for_contiguous_layout",
]
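
Finally, a quick illustration of the gating in the new should_use_deepgemm_for_fp8_linear_for_nk helper. Passing supports_deep_gemm=True explicitly isolates the pure shape/dtype logic, so the example does not depend on the local DeepGEMM installation:

import torch
from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear_for_nk

print(should_use_deepgemm_for_fp8_linear_for_nk(torch.bfloat16, 4096, 4096, True))
# True: bf16 output and both dims divisible by 128
print(should_use_deepgemm_for_fp8_linear_for_nk(torch.bfloat16, 4096, 4160, True))
# False: 4160 % 128 != 0
print(should_use_deepgemm_for_fp8_linear_for_nk(torch.float16, 4096, 4096, True))
# False: output dtype is not bf16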