From af0444bf40b7db2f3fb9fe1508d25ceba24cac87 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Sun, 7 Dec 2025 17:38:04 +0100 Subject: [PATCH] [Performance] Fused blockwise quant RMS norm (#27883) Signed-off-by: ElizaWszola Signed-off-by: yewentao256 Co-authored-by: yewentao256 --- .../fused_kernels/layernorm_rms_benchmarks.py | 89 +++- csrc/dispatch_utils.h | 18 + csrc/ops.h | 7 + ...fused_layernorm_dynamic_per_token_quant.cu | 144 ++++++- .../fused_kernels/layernorm_utils.cuh | 408 +++++++++++++----- csrc/torch_bindings.cpp | 8 + tests/compile/test_fusion.py | 69 ++- .../core/test_fused_quant_layernorm.py | 82 +++- vllm/_custom_ops.py | 40 ++ vllm/compilation/fusion.py | 172 +++++++- vllm/compilation/matcher_utils.py | 44 +- .../layers/quantization/utils/fp8_utils.py | 2 +- .../layers/quantization/utils/quant_utils.py | 6 + vllm/utils/deep_gemm.py | 17 + 14 files changed, 949 insertions(+), 157 deletions(-) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index d809bf1db8cbc..fb3329975cee3 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -14,6 +14,9 @@ from tqdm import tqdm import vllm._custom_ops as ops from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) @dataclass @@ -22,6 +25,7 @@ class bench_params_t: hidden_size: int add_residual: bool dtype: torch.dtype + group_size: list[int] def description(self): return ( @@ -29,6 +33,7 @@ class bench_params_t: f"x D {self.hidden_size} " f"x R {self.add_residual} " f"x DT {self.dtype}" + f"x GS {self.group_size}" ) @@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]: HIDDEN_SIZES = list(range(1024, 8129, 1024)) ADD_RESIDUAL = [True, False] DTYPES = [torch.bfloat16, torch.float] + GROUP_SIZES = [[1, 64], [1, 128]] - combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES) bench_params = list( - map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations) + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations) ) return bench_params @@ -52,6 +58,7 @@ def unfused_int8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -69,6 +76,7 @@ def unfused_fp8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -81,23 +89,63 @@ def unfused_fp8_impl( torch_out, _ = ops.scaled_fp8_quant(torch_out) +def unfused_groupwise_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = per_token_group_quant_fp8( + torch_out, group_size=group_size[1], use_ue8m0=False + ) + + def fused_impl( rms_norm_layer: RMSNorm, # this stores the weights x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): out, _ = ops.rms_norm_dynamic_per_token_quant( x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual ) +def fused_groupwise_impl( + rms_norm_layer: RMSNorm, # this stores the weights + 
x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + out, _ = ops.rms_norm_per_block_quant( + x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + group_size, + residual=residual, + is_scale_transposed=True, + ) + + # Bench functions def bench_fn( rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, quant_dtype: torch.dtype, + group_size: list[int], label: str, sub_label: str, fn: Callable, @@ -110,10 +158,11 @@ def bench_fn( "x": x, "residual": residual, "quant_dtype": quant_dtype, + "group_size": group_size, "fn": fn, } return TBenchmark.Timer( - stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)", globals=globals, label=label, sub_label=sub_label, @@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, unfused_int8_impl, @@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, unfused_fp8_impl, @@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, fused_impl, @@ -189,6 +241,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, fused_impl, @@ -196,6 +249,36 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu ) ) + # unfused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_groupwise_fp8_impl, + "unfused_groupwise_fp8_impl", + ) + ) + + # fused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + fused_groupwise_impl, + "fused_groupwise_fp8_impl", + ) + ) + print_timers(timers) return timers diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index e1d131e4a7851..de0c505b7a62f 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -118,6 +118,24 @@ } \ } +#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \ + if (expr) { \ + constexpr bool const_expr = true; \ + __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + __VA_ARGS__(); \ + } + +#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \ + if (group_size == 128) { \ + constexpr int const_group_size = 128; \ + __VA_ARGS__(); \ + } else if (group_size == 64) { \ + constexpr int const_group_size = 64; \ + __VA_ARGS__(); \ + } + #define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) 
\ switch (NUM_DIMS) { \ case 2: { \ diff --git a/csrc/ops.h b/csrc/ops.h index d302f04913266..9617d6358e18a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,6 +128,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out, std::optional scale_ub, std::optional residual); +void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, double const epsilon, + std::optional scale_ub, + std::optional residual, + int64_t group_size, bool is_scale_transposed); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 92d6c2f402a24..2080ef3cd39b5 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( // RMS Norm + Quant if constexpr (std::is_same_v) { + token_scale = 1.0f / token_scale; vllm::vectorized::norm_and_quant( - out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } else { // FP8 - Do not invert token_scale for exact match with FBGemm vllm::vectorized::norm_and_quant( - out, input, weight, rms, token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } } @@ -75,14 +76,52 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( // RMS Norm + Quant if constexpr (std::is_same_v) { + token_scale = 1.0f / token_scale; vllm::norm_and_quant( - out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } else { // FP8 - Do not invert s_token_scale for exact match with FBGemm vllm::norm_and_quant( - out, input, weight, rms, token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } } + +// RMS norm + quant kernel +template +__global__ void rms_norm_per_block_quant_kernel( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens, hidden_size / group_size] + // or + // [hidden_size / group_size, num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + float rms; + // Compute RMS + // Always able to vectorize due to constraints on hidden_size + vllm::vectorized::compute_rms( + &rms, input, hidden_size, var_epsilon, residual); + + // Compute Scale + // Always able to vectorize due to constraints on hidden_size and group_size + vllm::vectorized::compute_dynamic_per_token_scales< + scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( + nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual); + + // RMS Norm + Quant + // Always able to vectorize due to constraints on hidden_size + // For int8, don't invert token_scale here: do it inside the norm_and_quant + // kernel. We do it because particular elements of token_scale can be shared + // between multiple threads, so this way, we avoid extra synchronization + // overhead. 
+ vllm::vectorized::norm_and_quant< + scalar_t, scalar_out_t, std::is_same_v, + has_residual, is_scale_transposed, group_size>( + out, input, weight, rms, scales, hidden_size, residual); +} + } // namespace vllm // Residual add + RMS norm + dynamic per token @@ -103,30 +142,19 @@ void rms_norm_dynamic_per_token_quant_dispatch( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (residual.has_value()) { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { vllm::rms_norm_dynamic_per_token_quant_kernel + has_residual> <<>>( out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr()); + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() : nullptr); }); - - } else { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { - vllm::rms_norm_dynamic_per_token_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr); - }); - } + }); } void rms_norm_dynamic_per_token_quant( @@ -157,3 +185,79 @@ void rms_norm_dynamic_per_token_quant( out, input, weight, scales, var_epsilon, scale_ub, residual); }); } + +// Residual add + RMS norm + dynamic per token +void rms_norm_per_block_quant_dispatch( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or + // [hidden_size / group_size, num_tokens] + int32_t group_size, + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual, bool is_scale_transposed) { + int32_t hidden_size = input.size(-1); + auto num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + const int max_block_size = (num_tokens <= 256) ? 512 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] { + using scalar_in_t = scalar_t; + VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { + VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() + : nullptr); + }); + }); + }); + }); + }); +} + +void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, double const var_epsilon, + std::optional scale_ub, + std::optional residual, + int64_t group_size, bool is_scale_transposed) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? 
c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == kFp8Type); + } + TORCH_CHECK(weight.dtype() == input.dtype()); + TORCH_CHECK(scales.dtype() == torch::kFloat32); + if (residual) { + TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + } + + TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); + + rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, + var_epsilon, scale_ub, residual, + is_scale_transposed); +} \ No newline at end of file diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 2d2fd771205c7..cb7adc3125734 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -9,6 +9,7 @@ #include "quant_conversions.cuh" #include "../../cub_helpers.h" +#include "../../cuda_compat.h" namespace vllm { @@ -43,62 +44,150 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, *rms = s_rms; } -template +__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, + int64_t thread_in_warp, + int64_t reduced_elems) { + static_assert(WARP_SIZE == 32 || WARP_SIZE == 64); + if constexpr (WARP_SIZE == 64) { + if (thread_in_warp + 64 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 64]); + } + if (thread_in_warp + 32 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 32]); + if (thread_in_warp + 16 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 16]); + if (thread_in_warp + 8 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 8]); + if (thread_in_warp + 4 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 4]); + if (thread_in_warp + 2 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 2]); + if (thread_in_warp + 1 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 1]); + return val[tid]; +} + +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, - scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; - constexpr scalar_out_t qmax{quant_type_max_v}; - + int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const group_size = 0) { float block_absmax_val_maybe = 0.0f; - for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); - if constexpr (has_residual) { - x += static_cast(residual[token_offset + i]); - } - - x = static_cast(static_cast(x * rms) * weight[i]); - block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); - } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; - } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - 
s_token_scale = scale; // Shared memory store - all_token_scales[blockIdx.x] = scale; // Global output store - } + constexpr scalar_out_t qmax{quant_type_max_v}; __syncthreads(); + if (group_size > 0) { + __shared__ float s_max_vals[1024]; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = threadIdx.x / threads_per_group * group_size; + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = + min(group_offset + group_size, static_cast(hidden_size)); + for (auto i = thread_offset; i < thread_end; i += threads_per_group) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + __syncthreads(); - *token_scale = s_token_scale; + int64_t const warp_size = WARP_SIZE; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); + } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size)); + } + } + __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } + } + __syncthreads(); + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + + for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, 
min_scaling_factor::val()); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store + } + __syncthreads(); + + *token_scale = s_token_scale; + } } template + bool has_residual = false, bool is_scale_transposed = false> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, - float const rms, float const scale, + float const rms, float* const scale, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr) { + scalar_t* __restrict__ residual = nullptr, + int32_t const group_size = 0) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -109,8 +198,21 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, // Norm x = static_cast(static_cast(x * rms) * weight[i]); // Quant + // If groupwise is_scale_inverted is true, so we invert the scale here. + int64_t scale_idx = 0; + if (group_size > 0) { + if constexpr (is_scale_transposed) { + scale_idx = (i / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size; + } + } + auto scale_val = + (group_size > 0 + ? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx]) + : *scale); output[token_offset + i] = - ScaledQuant::quant_fn(x, scale); + ScaledQuant::quant_fn(x, scale_val); } } @@ -178,95 +280,191 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, // Vectorized version of vllm::compute_dynamic_per_token_scales // hidden_size must be a multiple of 4 -template +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; - - // Vectorized input/weight/residual to better utilize memory bandwidth. - vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); - vec4_t const* vec_weight = - reinterpret_cast const*>(weight); - vec4_t const* vec_residual = nullptr; - if constexpr (has_residual) { - vec_residual = - reinterpret_cast const*>(&residual[token_offset]); - } - constexpr scalar_out_t qmax{quant_type_max_v}; const int VEC_SIZE = 4; - int32_t const num_vec_elems = hidden_size >> 2; float block_absmax_val_maybe = 0.0f; -#pragma unroll 4 - for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { - vec4_t in = vec_input[i]; - vec4_t const w = vec_weight[i]; + // Vectorized input/weight/residual to better utilize memory bandwidth. 
+ vec4_t const* vec_input = nullptr; + vec4_t const* vec_weight = nullptr; + vec4_t const* vec_residual = nullptr; - vec4_t x; -#pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] = static_cast(in.val[j]); - } + if constexpr (group_size > 0) { + __shared__ float s_max_vals[1024]; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t const num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = + threadIdx.x / threads_per_group * (group_size >> 2); + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = min(group_offset + (group_size >> 2), + static_cast(hidden_size >> 2)); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { - vec4_t r = vec_residual[i]; + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + int32_t const num_vec_elems = thread_end; + +#pragma unroll 4 + for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] += static_cast(r.val[j]); + x.val[j] = static_cast(in.val[j]); + } + + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } + +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); } } + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + __syncthreads(); + + int64_t const warp_size = WARP_SIZE; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); + } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size)); + } + } + __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } + } + __syncthreads(); + + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); + if constexpr (has_residual) { + vec_residual = + 
reinterpret_cast const*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = (hidden_size >> 2); + +#pragma unroll 4 + for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; #pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - block_absmax_val_maybe = - fmaxf(block_absmax_val_maybe, - fabs(static_cast(x.val[j] * rms) * w.val[j])); + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] = static_cast(in.val[j]); + } + + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } + +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); + } } - } - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // shared memory store - all_token_scales[blockIdx.x] = scale; // global output store - } - __syncthreads(); + __syncthreads(); - *token_scale = s_token_scale; + *token_scale = s_token_scale; + } } // hidden_size must be a multiple of 4 template + bool has_residual = false, bool is_scale_transposed = false, + int32_t group_size = 0> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, - float const rms, float const scale, + float const rms, float* const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = @@ -311,10 +509,26 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, } q8x4_t out; + + float scale_val; + + if constexpr (group_size > 0) { + int64_t const num_groups = hidden_size / group_size; + int64_t scale_idx = 0; + if constexpr (is_scale_transposed) { + scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size; + } + scale_val = + is_scale_inverted ? 
1.0f / scale[scale_idx] : scale[scale_idx]; + } else { + scale_val = *scale; + } #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { out.val[j] = ScaledQuant::quant_fn( - static_cast(x.val[j] * rms) * w.val[j], scale); + static_cast(x.val[j] * rms) * w.val[j], scale_val); } vec_output[i] = out; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 23ac1d9abeea9..db37a9b9b88e3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -215,6 +215,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, &rms_norm_dynamic_per_token_quant); + // Fused Layernorm + Block quant kernels + ops.def( + "rms_norm_per_block_quant(Tensor! result, Tensor input, " + "Tensor weight, Tensor! scale, float epsilon, " + "Tensor? scale_ub, Tensor!? residual, int group_size, " + "bool is_scale_transposed) -> ()"); + ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index d0ba8385f4a01..2ad34a79859a3 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,6 +18,9 @@ from vllm.config import ( VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -25,10 +28,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, + cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, ) from vllm.platforms import current_platform +from vllm.utils.deep_gemm import is_deep_gemm_supported from ..utils import override_cutlass_fp8_supported from .backend import TestBackend @@ -44,7 +49,7 @@ class TestModel(torch.nn.Module): self, hidden_size: int, eps: float, - static: bool, + group_shape: GroupShape, cuda_force_torch: bool, *args, **kwargs, @@ -52,8 +57,17 @@ class TestModel(torch.nn.Module): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + if group_shape.is_per_group(): + self.wscale = [ + torch.rand( + (hidden_size // group_shape[1], hidden_size // group_shape[1]), + dtype=torch.float32, + ) + for _ in range(3) + ] + else: + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + static = group_shape == GroupShape.PER_TENSOR quant_scale = ScaleDesc(torch.float32, static, group_shape) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: @@ -61,18 +75,29 @@ class TestModel(torch.nn.Module): else: self.scale = [None for _ in range(3)] self.w = [ - torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(3) + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) ] + if not group_shape.is_per_group(): + self.w = [self.w[0].t() for _ in range(3)] - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=static, + if group_shape.is_per_group(): + self.fp8_linear = W8A8BlockFp8LinearOp( + 
weight_group_shape=GroupShape(group_shape[1], group_shape[1]), act_quant_group_shape=group_shape, + cutlass_block_fp8_supported=cutlass_block_fp8_supported(), + use_aiter_and_is_supported=False, ) + self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() + else: + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=static, + act_quant_group_shape=group_shape, + ) + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.group_shape = group_shape def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -119,11 +144,19 @@ class TestModel(torch.nn.Module): ) +GROUP_SHAPES = [ + GroupShape.PER_TOKEN, + GroupShape.PER_TENSOR, + GroupShape(1, 128), + GroupShape(1, 64), +] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -@pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("group_shape", GROUP_SHAPES) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that @@ -139,7 +172,7 @@ def test_fusion_rmsnorm_quant( hidden_size, num_tokens, eps, - static, + group_shape, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, cuda_force_torch, @@ -149,6 +182,15 @@ def test_fusion_rmsnorm_quant( torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths + if not enable_quant_fp8_custom_op and group_shape.is_per_group(): + pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") + + # Skip test for 64-bit group shape when running with cutlass or deepgemm + if group_shape == GroupShape(1, 64) and ( + cutlass_block_fp8_supported() or is_deep_gemm_supported() + ): + pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm") + custom_ops = [] if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") @@ -172,8 +214,7 @@ def test_fusion_rmsnorm_quant( backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) - model = TestModel(hidden_size, eps, static, cuda_force_torch) - + model = TestModel(hidden_size, eps, group_shape, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index b5fc653ca7353..094073f5d3f92 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -8,6 +8,12 @@ import torch import vllm._custom_ops as ops from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, +) DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] @@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [ ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] +GROUP_SIZES = [None, [1, 64], [1, 128]] 
SEEDS = [0] CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @@ -45,12 +52,13 @@ def ref_rms_norm( return out, residual -def ref_dynamic_per_token_quant( +def ref_dynamic_per_token_or_block_quant( rms_norm_layer: RMSNorm, x: torch.Tensor, quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -59,13 +67,24 @@ def ref_dynamic_per_token_quant( torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual) # Quant - if quant_dtype == torch.float8_e4m3fn: - torch_out, scales = ops.scaled_fp8_quant( - torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True - ) + if group_size is not None: + if quant_dtype == torch.float8_e4m3fn: + torch_out, scales = per_token_group_quant_fp8( + torch_out, group_size=group_size[1], use_ue8m0=False + ) + else: + assert quant_dtype == torch.int8 + torch_out, scales = per_token_group_quant_int8( + torch_out, group_size=group_size[1] + ) else: - assert quant_dtype == torch.int8 - torch_out, scales, _ = ops.scaled_int8_quant(torch_out) + if quant_dtype == torch.float8_e4m3fn: + torch_out, scales = ops.scaled_fp8_quant( + torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True + ) + else: + assert quant_dtype == torch.int8 + torch_out, scales, _ = ops.scaled_int8_quant(torch_out) return torch_out, scales, residual @@ -76,24 +95,32 @@ def ref_impl( quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - return ref_dynamic_per_token_quant( - rms_norm_layer, x, quant_dtype, residual, scale_ub + return ref_dynamic_per_token_or_block_quant( + rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size ) -def ops_dynamic_per_token_quant( +def ops_dynamic_per_token_or_block_quant( weight: torch.Tensor, x: torch.Tensor, quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: if residual is not None: residual = residual.clone() - out, scales = ops.rms_norm_dynamic_per_token_quant( - x, weight, EPS, quant_dtype, scale_ub, residual - ) + if group_size is not None: + out, scales = ops.rms_norm_per_block_quant( + x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True + ) + scales = scales.contiguous() + else: + out, scales = ops.rms_norm_dynamic_per_token_quant( + x, weight, EPS, quant_dtype, scale_ub, residual + ) return out, scales, residual @@ -103,8 +130,11 @@ def ops_impl( quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub) + return ops_dynamic_per_token_or_block_quant( + weight, x, quant_dtype, residual, scale_ub, group_size + ) @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) @@ -112,6 +142,7 @@ def ops_impl( @pytest.mark.parametrize("has_scale_ub", SCALE_UBS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("group_size", GROUP_SIZES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() @@ -122,6 +153,7 @@ 
def test_rms_norm( has_scale_ub: bool, dtype: torch.dtype, quant_dtype: torch.dtype, + group_size: list[int] | None, seed: int, device: str, ) -> None: @@ -130,6 +162,14 @@ def test_rms_norm( torch.cuda.manual_seed(seed) torch.set_default_device(device) + if group_size is not None and hidden_size % group_size[1] != 0: + # skip + return + + if group_size is not None and has_scale_ub: + # blockwise baseline doesn't support scale_ub + return + if has_scale_ub and quant_dtype != torch.float8_e4m3fn: # skip return @@ -150,10 +190,10 @@ def test_rms_norm( scale_ub = None ref_out, ref_scales, ref_residual = ref_impl( - layer, x, quant_dtype, residual, scale_ub + layer, x, quant_dtype, residual, scale_ub, group_size ) ops_out, ops_scales, ops_residual = ops_impl( - layer.weight, x, quant_dtype, residual, scale_ub + layer.weight, x, quant_dtype, residual, scale_ub, group_size ) assert ref_out.dtype == quant_dtype @@ -166,11 +206,15 @@ def test_rms_norm( assert torch.allclose(ref_scales, ops_scales) a = ref_out.to(dtype=torch.float32) b = ops_out.to(dtype=torch.float32) - ok = torch.allclose(a, b) + ok = torch.allclose(a, b, atol=1e-6) if not ok: # fallback: compare dequantized values with relaxed tolerance - a_deq = a * ref_scales.view(-1, 1) - b_deq = b * ops_scales.view(-1, 1) + if group_size is None: + a_deq = a * ref_scales.view(-1, 1) + b_deq = b * ops_scales.view(-1, 1) + else: + a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1) + b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1) # NOTE: It is possible that some future test cases trigger this # max diff due to precision issues. If such an error is # encountered, it's recommended to inspect the differences between diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 94e27545239f4..77d5453291e3c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -436,6 +436,46 @@ def rms_norm_dynamic_per_token_quant( return output, scales +# fused quant layer norm ops blocked +def rms_norm_per_block_quant( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, + group_size: list[int], + scale_ub: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + is_scale_transposed: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + assert len(group_size) == 2 + output = torch.empty_like(input, dtype=quant_dtype) + if is_scale_transposed: + scales = torch.empty( + (input.shape[-1] // group_size[1], input.numel() // input.shape[-1]), + device=input.device, + dtype=torch.float32, + ).transpose(0, 1) + else: + scales = torch.empty( + (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]), + device=input.device, + dtype=torch.float32, + ) + + torch.ops._C.rms_norm_per_block_quant( + output, + input, + weight, + scales, + epsilon, + scale_ub, + residual, + group_size[1], + is_scale_transposed, + ) + return output, scales + + # quantization ops # awq def awq_dequantize( diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 1d6e297b495eb..de083a2e5e3c7 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -15,13 +15,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, + kFp8Dynamic64Sym, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale, ) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_block_fp8_supported, +) from vllm.platforms import current_platform +from 
vllm.utils.deep_gemm import ( + is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear_for_nk, +) from .inductor_pass import enable_fake_mode from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm @@ -58,6 +67,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default +if current_platform.is_cuda(): + QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): @@ -90,6 +102,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { FusedRMSQuantKey( kFp8DynamicTokenSym, True ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, False + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, True + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic64Sym, False + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic64Sym, True + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 } @@ -100,6 +124,15 @@ class RMSNormQuantPattern: config = get_current_vllm_config() self.model_dtype = config.model_config.dtype if config.model_config else None + # groupwise FP8 linear uses col major scales if deepgemm and cutlass + using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk( + self.model_dtype, + config.model_config.hf_config.intermediate_size, + config.model_config.hf_config.hidden_size, + ) + use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported() + use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False + assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -108,7 +141,9 @@ class RMSNormQuantPattern: if not key.fused_add else MatcherFusedAddRMSNorm(epsilon) ) - self.quant_matcher = MatcherQuantFP8(key.quant) + self.quant_matcher = MatcherQuantFP8( + key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0 + ) class RMSNormStaticQuantPattern(RMSNormQuantPattern): @@ -218,6 +253,120 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ) +class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + self.group_shape = group_shape + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) + return result, residual, scale + + def replacement( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. 
+ input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale( + input, transposed=self.quant_matcher.use_col_major_scales + ) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual, + group_size=self.group_shape[1], + is_scale_transposed=self.quant_matcher.use_col_major_scales, + ) + + # result, residual, scale + return at[1], at[3], at[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + +class RMSNormGroupQuantPattern(RMSNormQuantPattern): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + self.group_shape = group_shape + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) + result, scale = self.quant_matcher(result_rms) + return result, scale + + def replacement(input: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale( + input, transposed=self.quant_matcher.use_col_major_scales + ) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None, + group_size=self.group_shape[1], + is_scale_transposed=self.quant_matcher.use_col_major_scales, + ) + + # result, scale + return at[1], at[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + class RMSNormDynamicQuantPattern(RMSNormQuantPattern): def __init__( self, @@ -340,6 +489,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): # Make sure fused add patterns are before simple rms norm, # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: + # Fuse fused_add_rms_norm + fp8 group quant + FusedAddRMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) + ).register(self.patterns) + + # Fuse rms_norm + fp8 group quant + RMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) + ).register(self.patterns) + + FusedAddRMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) + ).register(self.patterns) + + # Fuse rms_norm + fp8 group quant + RMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) + ).register(self.patterns) + # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns @@ -366,9 +534,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): def uuid(self) -> Any: return self.hash_source( self, + RMSNormGroupQuantPattern, RMSNormQuantPattern, RMSNormStaticQuantPattern, RMSNormDynamicQuantPattern, FusedAddRMSNormStaticQuantPattern, FusedAddRMSNormDynamicQuantPattern, + FusedAddRMSNormGroupQuantPattern, ) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py 
index e4cd063d2aee1..0c0bece9b3fda 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, _normalize_quant_group_shape, + kFp8Dynamic64Sym, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, @@ -35,6 +37,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 +if current_platform.is_cuda(): + QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 + SILU_MUL_OP = torch.ops._C.silu_and_mul.default @@ -224,12 +230,20 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): class MatcherQuantFP8(MatcherCustomOp): - def __init__(self, quant_key: QuantKey, enabled: bool | None = None): + def __init__( + self, + quant_key: QuantKey, + enabled: bool | None = None, + use_col_major_scales: bool = False, + use_e8m0: bool = False, + ): if enabled is None: enabled = QuantFP8.enabled() super().__init__(enabled) self.quant_key = quant_key + self.use_col_major_scales = use_col_major_scales + self.use_e8m0 = use_e8m0 assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] @@ -248,6 +262,27 @@ class MatcherQuantFP8(MatcherCustomOp): input.shape, device=input.device, dtype=self.quant_key.dtype ) + if self.quant_key.scale.group_shape.is_per_group(): + assert scale is None + scale = self.make_scale(input, transposed=self.use_col_major_scales) + + finfo = torch.finfo(self.quant_key.dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + _, result, scale = auto_functionalized( + self.QUANT_OP, + input=input, + output_q=result, + output_s=scale, + group_size=self.quant_key.scale.group_shape[1], + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + scale_ue8m0=self.use_e8m0, + ) + return result, scale + if self.quant_key.scale.static: assert scale is not None _, result = auto_functionalized( @@ -269,7 +304,7 @@ class MatcherQuantFP8(MatcherCustomOp): ) -> tuple[torch.Tensor, torch.Tensor]: return self.quant_fp8(input, scale) - def make_scale(self, input: torch.Tensor): + def make_scale(self, input: torch.Tensor, transposed: bool = False): normalized_group_shape = _normalize_quant_group_shape( input, self.quant_key.scale.group_shape ) @@ -277,6 +312,11 @@ class MatcherQuantFP8(MatcherCustomOp): input.shape[0] // normalized_group_shape[0], input.shape[1] // normalized_group_shape[1], ) + if transposed: + scale_shape = tuple(reversed(scale_shape)) + return torch.empty( + scale_shape, device=input.device, dtype=torch.float32 + ).permute(-1, -2) return torch.empty(scale_shape, device=input.device, dtype=torch.float32) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7e1bda8639ac7..ad92f4ec63c34 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -733,7 +733,7 @@ def per_token_group_quant_fp8( assert out_q is None or out_q.shape == x.shape x_q = out_q if x_q is None: - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_q = torch.empty(x.shape, device=x.device, dtype=dtype) # 
Allocate the scale tensor in either row- or column-major format. if column_major_scales: diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d056d3404385a..92ee8c498e01f 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) +kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) +kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) + +kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) +kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index b25c1e3e1ece3..8545108a02666 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -381,6 +381,22 @@ def should_use_deepgemm_for_fp8_linear( ) +def should_use_deepgemm_for_fp8_linear_for_nk( + output_dtype: torch.dtype, + shape0: int, + shape1: int, + supports_deep_gemm: bool | None = None, +): + if supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return ( + supports_deep_gemm + and output_dtype == torch.bfloat16 + and shape0 % 128 == 0 + and shape1 % 128 == 0 + ) + + __all__ = [ "calc_diff", "fp8_gemm_nt", @@ -394,6 +410,7 @@ __all__ = [ "is_deep_gemm_supported", "get_num_sms", "should_use_deepgemm_for_fp8_linear", + "should_use_deepgemm_for_fp8_linear_for_nk", "get_col_major_tma_aligned_tensor", "get_mk_alignment_for_contiguous_layout", ]
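
Usage note (not part of the patch): the new fused op is exercised above only through the benchmark and test diffs. Below is a minimal stand-alone sketch, assuming a CUDA device with FP8 (e4m3) support and the illustrative shapes chosen here, that calls ops.rms_norm_per_block_quant and cross-checks it against the unfused RMSNorm + per_token_group_quant_fp8 path, mirroring what tests/kernels/core/test_fused_quant_layernorm.py does with dequantized outputs.

import torch

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
)

torch.set_default_device("cuda")

# Illustrative shapes: hidden_size must be divisible by the group size,
# and only group sizes 64 and 128 are supported by the new kernel.
num_tokens, hidden_size, group_size = 32, 1024, [1, 128]

layer = RMSNorm(hidden_size, eps=1e-6).to(dtype=torch.bfloat16)
x = torch.rand(num_tokens, hidden_size, dtype=torch.bfloat16)

# Fused path: RMS norm + blockwise FP8 quant in a single kernel.
# is_scale_transposed=True asks for column-major (transposed) scales,
# as the fusion pass does when CUTLASS/DeepGEMM consume the output.
fused_out, fused_scales = ops.rms_norm_per_block_quant(
    x,
    layer.weight,
    1e-6,
    torch.float8_e4m3fn,
    group_size,
    residual=None,
    is_scale_transposed=True,
)

# Unfused reference: RMS norm, then per-token-group FP8 quant.
ref = layer.forward_cuda(x, None)
ref_out, ref_scales = per_token_group_quant_fp8(
    ref, group_size=group_size[1], use_ue8m0=False
)

# Compare dequantized values, since per-group scales can differ slightly
# between the two implementations.
deq_fused = fused_out.float() * fused_scales.contiguous().repeat_interleave(
    group_size[1], dim=1
)
deq_ref = ref_out.float() * ref_scales.repeat_interleave(group_size[1], dim=1)
print("max abs diff:", torch.max(torch.abs(deq_fused - deq_ref)).item())

Passing a residual tensor instead of None folds the residual add into the same kernel, matching the FusedAddRMSNormGroupQuantPattern registered in vllm/compilation/fusion.py.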