From 06d490282f2bab6922137eb5230be9df5ebbe9c4 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sun, 21 Dec 2025 12:41:57 -0500 Subject: [PATCH] [NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897) Signed-off-by: mgoin --- benchmarks/kernels/bench_nvfp4_quant.py | 177 ++++++++++++++++++ .../activation_nvfp4_quant_fusion_kernels.cu | 5 +- csrc/quantization/fp4/nvfp4_experts_quant.cu | 31 ++- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 62 ++---- csrc/quantization/fp4/nvfp4_utils.cuh | 65 +++---- 5 files changed, 243 insertions(+), 97 deletions(-) create mode 100644 benchmarks/kernels/bench_nvfp4_quant.py diff --git a/benchmarks/kernels/bench_nvfp4_quant.py b/benchmarks/kernels/bench_nvfp4_quant.py new file mode 100644 index 0000000000000..7517376535925 --- /dev/null +++ b/benchmarks/kernels/bench_nvfp4_quant.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.triton_utils import triton +from vllm.utils.flashinfer import flashinfer_fp4_quantize + +if not current_platform.has_device_capability(100): + raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)") + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +PROVIDER_CFGS = { + "vllm": dict(backend="vllm", enabled=True), + "flashinfer": dict(backend="flashinfer", enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor: + """Compute global scale for FP4 quantization.""" + amax = torch.abs(tensor).max().to(torch.float32) + return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="us (lower is better)", + plot_name="NVFP4 Input Quantization Latency (us)", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + # Create input tensor + a = torch.randn((M, K), device=device, dtype=dtype) + + # Compute global scale for activation + a_global_scale = compute_global_scale(a) + + quantiles = [0.5, 0.2, 0.8] + + cfg = PROVIDER_CFGS[provider] + + if cfg["backend"] == "vllm": + # vLLM's FP4 quantization + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: ops.scaled_fp4_quant(a, a_global_scale), + quantiles=quantiles, + ) + elif cfg["backend"] == "flashinfer": + # FlashInfer's FP4 quantization + # Use is_sf_swizzled_layout=True to match vLLM's output format + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: flashinfer_fp4_quantize( + a, a_global_scale, is_sf_swizzled_layout=True + ), + quantiles=quantiles, + ) + + # Convert ms to us for better readability at small batch sizes + to_us = lambda t_ms: t_ms * 1000 + return to_us(ms), to_us(max_ms), to_us(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +def 
_test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str): + """Test accuracy between vLLM and FlashInfer FP4 quantization.""" + # Create input tensor + a = torch.randn((M, K), device=device, dtype=dtype) + + # Compute global scale + a_global_scale = compute_global_scale(a) + + # vLLM quantization + vllm_fp4, vllm_scale = ops.scaled_fp4_quant(a, a_global_scale) + + # FlashInfer quantization (with swizzled layout to match vLLM's output) + flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize( + a, a_global_scale, is_sf_swizzled_layout=True + ) + flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn) + + # Compare outputs + torch.testing.assert_close( + vllm_fp4, + flashinfer_fp4, + ) + print(f"M={M}, K={K}, dtype={dtype}: PASSED") + + +def test_accuracy(): + """Run accuracy tests across various shapes.""" + print("\n" + "=" * 60) + print("Running accuracy tests: vLLM vs FlashInfer") + print("=" * 60) + + device = "cuda" + dtype = torch.bfloat16 + + # Test various batch sizes and hidden dimensions + Ms = [1, 1024] + Ks = [4096] + + for M in Ms: + for K in Ks: + _test_accuracy_once(M, K, dtype, device) + + print("\nAll accuracy tests passed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark NVFP4 quantization: vLLM vs FlashInfer" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + parser.add_argument( + "--save-path", + type=str, + default=None, + help="Path to save benchmark results", + ) + parser.add_argument( + "--accuracy", + action="store_true", + help="Run accuracy tests", + ) + args = parser.parse_args() + + if args.accuracy: + test_accuracy() + + for K, N, model in prepare_shapes(args): + print(f"\n{model}, N={N} K={K}") + benchmark.run( + print_data=True, + save_path=args.save_path, + N=N, + K=K, + ) + + print("\nBenchmark finished!") diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 7539f836ecf37..e0438556dfe5c 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -74,6 +74,9 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is // (448.f / (Alpha_A / 6.f)). 
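
Aside (not part of the patch): the hoisted numKTiles = (numCols + 63) / 64 added above is the same quantity the experts kernels below previously recomputed per call as numCols_SFout from the padded column count, since one K tile of the scale-factor layout spans CVT_FP4_SF_VEC_SIZE * 4 = 64 input columns. A minimal host-side C++ sketch of that equivalence follows; the file name and the tested widths are arbitrary.

// numktiles_equivalence.cpp -- standalone sketch, not part of this patch.
// Checks that the hoisted numKTiles matches the numCols_SFout value that the
// experts kernels previously recomputed from the padded column count.
#include <cassert>
#include <cstdio>

int main() {
  constexpr int CVT_FP4_SF_VEC_SIZE = 16;  // 16 elements share one scale factor
  for (int numCols = 16; numCols <= 16384; numCols += 16) {
    // Old per-call computation (removed in the hunks below): pad numCols to a
    // multiple of 64, then count packed 4-byte SF words per row.
    int factor = CVT_FP4_SF_VEC_SIZE * 4;
    int numCols_padded = (numCols + factor - 1) / factor * factor;
    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
    // New hoisted computation: K tiles of 64 columns each.
    int numKTiles = (numCols + 63) / 64;
    assert(numCols_SFout == numKTiles);
  }
  std::printf("numKTiles matches numCols_SFout for all tested widths\n");
  return 0;
}
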
@@ -101,7 +104,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( - rowIdx, colIdx, numCols, SFout); + rowIdx, colIdx, numKTiles, SFout); out_pos = cvt_warp_fp16_to_fp4(out_silu_mul, SFScaleVal, sf_out); diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 82c53c2375a31..20191a9bc6160 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -25,6 +25,7 @@ #include #include "dispatch_utils.h" +#include "cuda_utils.h" #include "nvfp4_utils.cuh" #include "launch_bounds_utils.h" @@ -44,6 +45,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + int tid = blockIdx.x * blockDim.x + threadIdx.x; int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; @@ -112,17 +116,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // (448.f / (Alpha_A / 6.f)). float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; - int factor = CVT_FP4_SF_VEC_SIZE * 4; - // The actual output_scales dim is computed from the padded numCols. - int32_t numCols_padded = (numCols + factor - 1) / factor * factor; - int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; uint32_t* SFout_in_expert = - SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( - rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } @@ -140,6 +140,10 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + extern __shared__ uint32_t shared_input_offsets[]; // Load input offsets into shared memory. @@ -202,16 +206,13 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numCols_padded = (numCols + factor - 1) / factor * factor; - int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; uint32_t* SFout_in_expert = - SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( - rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } @@ -222,12 +223,8 @@ void quant_impl(void* output, void* output_scale, void* input, void* input_global_scale, void* input_offset_by_experts, void* output_scale_offset_by_experts, int m_topk, int k, int n_experts, cudaStream_t stream) { - // TODO: this multiProcessorCount should be cached. 
- int device; - cudaGetDevice(&device); - int multiProcessorCount; - cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, - device); + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); // Grid, Block size. // Each thread converts 8 values. diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 6d69852bb4e4f..6acadb4cefd2c 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -38,6 +38,12 @@ __host__ __device__ inline Int round_up(Int x, Int y) { return (x + y - 1) / y * y; } +// Compute effective rows for grid configuration with swizzled SF layouts. +inline int computeEffectiveRows(int m) { + constexpr int ROW_TILE = 128; + return round_up(m, ROW_TILE); +} + // Use UE4M3 by default. template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) @@ -49,6 +55,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + int sf_m = round_up(numRows, 128); int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE; int sf_n_int = round_up(sf_n_unpadded, 4) / 4; @@ -79,7 +88,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( - rowIdx, colIdx, numCols, SFout); + rowIdx, colIdx, numKTiles, SFout); out_pos = cvt_warp_fp16_to_fp4(in_vec, global_scale, sf_out); @@ -87,43 +96,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) } } -template -void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, - int64_t* output, int32_t* SFOuput, bool useUE8M0, - int multiProcessorCount, cudaStream_t stream) { - // Grid, Block size. - // Each thread converts 8 values. - dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); - // Get number of blocks per SM - int const numBlocksPerSM = - vllm_runtime_blocks_per_sm(static_cast(block.x)); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); - - // Launch the cvt kernel. - if (useUE8M0) { - cvt_fp16_to_fp4<<>>( - m, n, input, SFScale, reinterpret_cast(output), - reinterpret_cast(SFOuput)); - } else { - cvt_fp16_to_fp4<<>>( - m, n, input, SFScale, reinterpret_cast(output), - reinterpret_cast(SFOuput)); - } -} - -// Instantiate the function. -template void invokeFP4Quantization(int m, int n, half const* input, - float const* SFScale, int64_t* output, - int32_t* SFOuput, bool useUE8M0, - int multiProcessorCount, - cudaStream_t stream); - -template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, - float const* SFScale, int64_t* output, - int32_t* SFOuput, bool useUE8M0, - int multiProcessorCount, - cudaStream_t stream); - } // namespace vllm void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, @@ -147,13 +119,19 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); - // We don't support e8m0 scales at this moment. - bool useUE8M0 = false; + // Grid, Block size. Each thread converts 8 values. 
+ dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + int effectiveRows = vllm::computeEffectiveRows(m); + dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { using cuda_type = vllm::CUDATypeConverter::Type; auto input_ptr = static_cast(input.data_ptr()); - vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, - sf_out, useUE8M0, multiProcessorCount, stream); + // NOTE: We don't support e8m0 scales at this moment. + vllm::cvt_fp16_to_fp4<<>>( + m, n, input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); }); } diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 48e4959de9793..4c91af85e1514 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -128,51 +128,42 @@ inline __device__ float reciprocal_approximate_ftz(float a) { return b; } +// Compute SF output offset for swizzled tensor core layout. +// SF layout: [numMTiles, numKTiles, 32, 4, 4] +// Caller must precompute: numKTiles = (numCols + 63) / 64 template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { +__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset( + int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) { static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); // One pair of threads write one SF to global memory. // TODO: stage through smem for packed STG.32 // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) { + return nullptr; } - return nullptr; + + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // Decompose indices using bitwise ops (all divisors are powers of 2). 
+  // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+  int32_t mTileIdx = mIdx >> 7;         // mIdx / 128
+  int32_t outerMIdx = mIdx & 31;        // mIdx % 32
+  int32_t innerMIdx = (mIdx >> 5) & 3;  // (mIdx / 32) % 4
+  int32_t kTileIdx = kIdx >> 2;         // kIdx / 4
+  int32_t innerKIdx = kIdx & 3;         // kIdx % 4
+
+  // Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 +
+  // outerMIdx * 16 + innerMIdx * 4 + innerKIdx
+  // Use bitwise OR for non-overlapping lower bits.
+  int64_t SFOffset = (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx)
+                         << 9 |
+                     (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
+
+  return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
 }
 
 // Quantizes the provided PackedVec into the uint32_t output
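
Aside (not part of the patch): a small host-side sanity check that the new shift/mask decomposition reproduces the stride-based offset removed above. This is a minimal sketch; the file name and tested shapes are arbitrary, and the constants come from the [numMTiles, numKTiles, 32, 4, 4] layout described in the comments.

// sf_offset_equivalence.cpp -- standalone sketch, not part of this patch.
// Compares the original stride-based SF offset with the new bitwise form.
#include <cassert>
#include <cstdint>
#include <cstdio>

// Original stride-based computation (as removed from nvfp4_utils.cuh).
static int64_t sf_offset_strided(int32_t mIdx, int32_t kIdx, int32_t numKTiles) {
  int32_t mTileIdx = mIdx / (32 * 4);
  int64_t mTileStride = static_cast<int64_t>(numKTiles) * 32 * 4 * 4;
  int32_t kTileIdx = kIdx / 4;
  int64_t kTileStride = 32 * 4 * 4;
  int32_t outerMIdx = mIdx % 32;
  int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
  int32_t innerKIdx = kIdx % 4;
  return mTileIdx * mTileStride + kTileIdx * kTileStride +
         outerMIdx * (4 * 4) + innerMIdx * 4 + innerKIdx;
}

// New bitwise computation (as added in this patch).
static int64_t sf_offset_bitwise(int32_t mIdx, int32_t kIdx, int32_t numKTiles) {
  int32_t mTileIdx = mIdx >> 7;
  int32_t outerMIdx = mIdx & 31;
  int32_t innerMIdx = (mIdx >> 5) & 3;
  int32_t kTileIdx = kIdx >> 2;
  int32_t innerKIdx = kIdx & 3;
  return (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx) << 9 |
         (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
}

int main() {
  int32_t const numCols = 4096;                   // example hidden size
  int32_t const numKTiles = (numCols + 63) / 64;  // as precomputed in the kernels
  // kIdx indexes 16-element scale-factor vectors along K, so it runs to numCols / 16.
  for (int32_t mIdx = 0; mIdx < 1024; ++mIdx) {
    for (int32_t kIdx = 0; kIdx < numCols / 16; ++kIdx) {
      assert(sf_offset_strided(mIdx, kIdx, numKTiles) ==
             sf_offset_bitwise(mIdx, kIdx, numKTiles));
    }
  }
  std::printf("bitwise SF offset matches the strided formula\n");
  return 0;
}

The single shift by 9 is safe because outerMIdx * 16 + innerMIdx * 4 + innerKIdx is at most 31 * 16 + 3 * 4 + 3 = 511, so the OR'd terms never overlap the (mTile, kTile) index held in the upper bits.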