diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 24cc57e9dfb9..454aaca0a112 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -668,6 +668,7 @@ steps: # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -677,6 +678,7 @@ steps: - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/CMakeLists.txt b/CMakeLists.txt index b0eb0f32e03a..e92e08f0d0ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -541,6 +541,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -559,6 +560,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index f7b75c48373f..2728aa81f0c9 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -19,6 +19,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_CASE_HALF_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__)) + // ROCm devices might use either fn or fnuz, so set up dispatch table for both. // A host-based check at runtime will create a preferred FP8 type for ROCm // such that the correct kernel is dispatched. @@ -45,6 +52,15 @@ #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) +#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__) + +#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \ + AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) 
\ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
diff --git a/csrc/ops.h b/csrc/ops.h
index 86fe848e2fd5..78a487201bdd 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -130,6 +130,14 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
                         torch::Tensor& scale);
 
+#ifndef USE_ROCM
+
+void silu_and_mul_nvfp4_quant(torch::Tensor& out,
+                              torch::Tensor& output_block_scale,
+                              torch::Tensor& input,
+                              torch::Tensor& input_global_scale);
+#endif
+
 void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
new file mode 100644
index 000000000000..9bbeb0334fb9
--- /dev/null
+++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
@@ -0,0 +1,368 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include <cuda_fp8.h>
+#include "dispatch_utils.h"
+
+#include "cuda_utils.h"
+
+namespace vllm {
+
+// Get type2 from type or vice versa (applied to half and bfloat16)
+template <typename T>
+struct TypeConverter {
+  using Type = half2;
+};  // keep for generality
+
+template <>
+struct TypeConverter<half2> {
+  using Type = c10::Half;
+};
+
+template <>
+struct TypeConverter<c10::Half> {
+  using Type = half2;
+};
+
+template <>
+struct TypeConverter<__nv_bfloat162> {
+  using Type = c10::BFloat16;
+};
+
+template <>
+struct TypeConverter<c10::BFloat16> {
+  using Type = __nv_bfloat162;
+};
+
+#define ELTS_PER_THREAD 8
+
+constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
+constexpr int CVT_FP4_SF_VEC_SIZE = 16;
+
+// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      ".reg .b8 byte1;\n"
+      ".reg .b8 byte2;\n"
+      ".reg .b8 byte3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+      "}"
+      : "=r"(val)
+      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
+        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
+  return val;
+#else
+  return 0;
+#endif
+}
+
+// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      ".reg .b8 byte1;\n"
+      ".reg .b8 byte2;\n"
+      ".reg .b8 byte3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+      "}"
+      : "=r"(val)
+      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
+        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
+  return val;
+#else
+  return 0;
+#endif
+}
+
+// Fast reciprocal.
+inline __device__ float reciprocal_approximate_ftz(float a) {
+  float b;
+  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
+  return b;
+}
+
+template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF = 1>
+__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
+                                                       int numCols,
+                                                       SFType* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
+                CVT_FP4_NUM_THREADS_PER_SF == 2);
+
+  // One pair of threads write one SF to global memory.
+  // TODO: stage through smem for packed STG.32
+  // is it better than STG.8 from 4 threads ?
+  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
+    // SF vector index (16 elements share one SF in the K dimension).
+    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
+    int32_t mIdx = rowIdx;
+
+    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    int32_t mTileIdx = mIdx / (32 * 4);
+    // SF vector size 16.
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numKTiles = (numCols + factor - 1) / factor;
+    int64_t mTileStride = numKTiles * 32 * 4 * 4;
+
+    int32_t kTileIdx = (kIdx / 4);
+    int64_t kTileStride = 32 * 4 * 4;
+
+    // M tile layout [32, 4] is column-major.
+    int32_t outerMIdx = (mIdx % 32);
+    int64_t outerMStride = 4 * 4;
+
+    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
+    int64_t innerMStride = 4;
+
+    int32_t innerKIdx = (kIdx % 4);
+    int64_t innerKStride = 1;
+
+    // Compute the global offset.
+    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
+                       outerMIdx * outerMStride + innerMIdx * innerMStride +
+                       innerKIdx * innerKStride;
+
+    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
+  }
+#endif
+  return nullptr;
+}
+
+// Define a 16 bytes packed data type.
+template <class Type>
+struct PackedVec {
+  typename TypeConverter<Type>::Type elts[4];
+};
+
+template <>
+struct PackedVec<__nv_fp8_e4m3> {
+  __nv_fp8x2_e4m3 elts[8];
+};
+
+template <class Type>
+__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
+                                                   PackedVec<Type>& vec2) {
+  PackedVec<Type> result;
+#pragma unroll
+  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
+    if constexpr (std::is_same_v<Type, c10::Half>) {
+      half2 val(0.5f, 0.5f);
+      half2 t0 = __hmul2(vec.elts[i], val);
+      half2 t1 = __hfma2(h2tanh(t0), val, val);
+      half2 t2 = __hmul2(vec.elts[i], t1);
+      result.elts[i] = __hmul2(t2, vec2.elts[i]);
+    } else {
+      __nv_bfloat162 val(0.5f, 0.5f);
+      __nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
+      __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
+      __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
+      result.elts[i] = __hmul2(t2, vec2.elts[i]);
+    }
+  }
+  return result;
+}
+
+// Quantizes the provided PackedVec into the uint32_t output
+template <class Type, bool UE8M0_SF = false>
+__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
+                                                  PackedVec<Type>& vec2,
+                                                  float SFScaleVal,
+                                                  uint8_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  PackedVec<Type> out_silu = compute_silu(vec, vec2);
+  // Get absolute maximum values among the local 8 values.
+  auto localMax = __habs2(out_silu.elts[0]);
+
+  // Local maximum value.
+  #pragma unroll
+  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
+  }
+
+  // Get the absolute maximum among all 16 values (two threads).
+  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
+  // Get the final absolute maximum values.
+  float vecMax = float(__hmax(localMax.x, localMax.y));
+
+  // Get the SF (max value of the vector / max value of e2m1).
+  // maximum value of e2m1 = 6.0.
+  // TODO: use half as compute data type.
+  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
+  // 8 bits representation of the SF.
+  uint8_t fp8SFVal;
+  // Write the SF to global memory (STG.8).
+  if constexpr (UE8M0_SF) {
+    // Extract the 8 exponent bits from float32.
+    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
+    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
+    fp8SFVal = tmp & 0xff;
+    // Convert back to fp32.
+    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
+  } else {
+    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
+    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
+    // Convert back to fp32.
+    SFValue = float(tmp);
+  }
+  // Get the output scale.
+  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
+  //                       reciprocal(SFScaleVal))
+  float outputScale =
+      SFValue != 0 ? reciprocal_approximate_ftz(
+                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
+                   : 0.0f;
+
+  if (SFout) {
+    // Write the SF to global memory (STG.8).
+    *SFout = fp8SFVal;
+  }
+
+  // Convert the input to float.
+  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
+
+  #pragma unroll
+  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    if constexpr (std::is_same_v<Type, c10::Half>) {
+      fp2Vals[i] = __half22float2(out_silu.elts[i]);
+    } else {
+      fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
+    }
+    fp2Vals[i].x *= outputScale;
+    fp2Vals[i].y *= outputScale;
+  }
+
+  // Convert to e2m1 values.
+  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+
+  // Write the e2m1 values to global memory.
+  return e2m1Vec;
+#else
+  return 0;
+#endif
+}
+
+// Use UE4M3 by default.
+template <class Type, bool UE8M0_SF = false>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4(
+#else
+silu_and_cvt_fp16_to_fp4(
+#endif
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+    uint32_t* out, uint32_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
+      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
+
+  // Get the global scaling factor, which will be applied to the SF.
+  // Note SFScale is the same as next GEMM's alpha, which is
+  // (448.f / (Alpha_A / 6.f)).
+  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+
+  // Input tensor row/col loops.
+  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
+    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
+         colIdx += blockDim.x) {
+      int64_t inOffset =
+          rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx;
+      int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) +
+                          numCols / CVT_FP4_ELTS_PER_THREAD + colIdx;
+      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+      PackedVec in_vec2 = reinterpret_cast<PackedVec const*>(in)[inOffset2];
+
+      // Get the output tensor offset.
+      // Same as inOffset because 8 elements are packed into one uint32_t.
+      int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
+      auto& out_pos = out[outOffset];
+
+      auto sf_out =
+          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                             CVT_FP4_NUM_THREADS_PER_SF>(
+              rowIdx, colIdx, numCols, SFout);
+
+      out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
+          in_vec, in_vec2, SFScaleVal, sf_out);
+    }
+  }
+#endif
+}
+
+}  // namespace vllm
+
+void silu_and_mul_nvfp4_quant(torch::Tensor& output,  // [..., d]
+                              torch::Tensor& output_sf,
+                              torch::Tensor& input,  // [..., 2 * d]
+                              torch::Tensor& input_sf) {
+  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
+              input.dtype() == torch::kBFloat16);
+  int32_t m = input.size(0);
+  int32_t n = input.size(1) / 2;
+
+  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
+
+  int multiProcessorCount =
+      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
+
+  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
+  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
+  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
+  dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
+  int const numBlocksPerSM = 2048 / block.x;
+  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  VLLM_DISPATCH_HALF_TYPES(
+      input.scalar_type(), "act_and_mul_quant_kernel", [&] {
+        auto input_ptr = reinterpret_cast<scalar_t const*>(input.data_ptr());
+        VLLM_DISPATCH_BYTE_TYPES(
+            output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type",
+            [&] {
+              vllm::silu_and_cvt_fp16_to_fp4<scalar_t>
+                  <<<grid, block, 0, stream>>>(
+                      m, n, input_ptr, input_sf_ptr,
+                      reinterpret_cast<uint32_t*>(output_ptr),
+                      reinterpret_cast<uint32_t*>(sf_out));
+            });
+      });
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 608b72440307..b769c09adc0f 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -115,6 +115,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
   ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
 
+#ifndef USE_ROCM
+  ops.def(
+      "silu_and_mul_nvfp4_quant(Tensor!
result, Tensor! result_block_scale, " + "Tensor input, Tensor input_global_scale) -> ()"); + ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant); +#endif + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 0e1059e65447..fcc2589e4211 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -4,32 +4,41 @@ import pytest import torch import vllm.envs as envs -from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +# yapf conflicts with isort for this block +# yapf: disable +from vllm.compilation.activation_quant_fusion import ( + FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass) +# yapf: enable +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, kFp8StaticTensorSym, kNvfp4Quant) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform from .backend import TestBackend +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 -class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args, - **kwargs): - super().__init__(*args, **kwargs) +def is_nvfp4_supported(): + return current_platform.has_device_capability(100) + + +class TestSiluMulFp8QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs): + super().__init__() self.silu_and_mul = SiluAndMul() self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) - self.w = (torch.rand( - hidden_size, - hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = Fp8LinearOp( force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, @@ -45,14 +54,56 @@ class TestModel(torch.nn.Module): input_scale=self.wscale) return x2 + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]] -@pytest.mark.parametrize("num_tokens", [256]) -@pytest.mark.parametrize("hidden_size", [64]) + def ops_in_model_after(self): + return [FUSED_OPS[kFp8StaticTensorSym]] + + +class TestSiluMulNvfp4QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.w = torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE) + self.wscale = torch.randn(hidden_size, + hidden_size // 16).to(dtype=FP8_DTYPE) + self.wscale2 = torch.rand(1, dtype=torch.float32) + self.scale = torch.rand(1, dtype=torch.float32) + + def forward(self, x): + y = self.silu_and_mul(x) + y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale) + out = cutlass_scaled_fp4_mm(a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.wscale, + alpha=self.scale * self.wscale2, + out_dtype=y.dtype) + return out + + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]] + + def 
ops_in_model_after(self): + return [FUSED_OPS[kNvfp4Quant]] + + +@pytest.mark.parametrize("num_tokens", [64]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize( + "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] + if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]) @pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class, force_fp8_e4m3fnuz): + if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz: + pytest.skip("Duplicate tests for NVFP4") + torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -63,7 +114,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(NoOpEliminationPass(config), fusion_pass) - model = TestModel(hidden_size, force_fp8_e4m3fnuz) + model = model_class(hidden_size=hidden_size, + force_fp8_e4m3fnuz=force_fp8_e4m3fnuz) # First dimension dynamic x = torch.rand(num_tokens, hidden_size * 2) @@ -80,17 +132,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, atol=1e-3, rtol=1e-3) - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes + # In pre-nodes, quant op should be present and fused kernels should not + backend.check_before_ops(model.ops_in_model_before()) - silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default - - # In pre-nodes, fp8 quant should be present and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None - find_auto_fn(pre_nodes, fp8_quant) - - # In post-nodes, fused kernels should be present and fp8 quant should not - find_auto_fn(post_nodes, silu_and_mul_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + # In post-nodes, fused kernels should be present and quant op should not + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py new file mode 100644 index 000000000000..969f14cc3fe6 --- /dev/null +++ b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] +SEEDS = [42] +CUDA_DEVICES = ['cuda:0'] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +BLOCK_SIZE = 16 + + +def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, + global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + silu_and_mul_out = silu_and_mul.forward_native(x) + assert not current_platform.is_rocm() + assert silu_and_mul_out.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got 
{silu_and_mul_out.ndim}.') + other_dims = 1 if silu_and_mul_out.ndim == 1 else -1 + silu_and_mul_out = silu_and_mul_out.reshape(other_dims, + silu_and_mul_out.shape[-1]) + m, n = silu_and_mul_out.shape + device = silu_and_mul_out.device + + # Two fp4 values will be packed into an uint8. + out = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + output_scale = ref_output_scale + + torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale, + global_scale) + + return out, output_scale + + +def ops_impl(x: torch.Tensor, global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + out_shape = (x.shape[0], x.shape[1] // 4) + output_scale = ref_output_scale + out = torch.empty(out_shape, dtype=torch.uint8, device=x.device) + torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale) + return out, output_scale + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_quantize_to_fp4( + dtype: torch.dtype, + shape: tuple[int, int], + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + + m, n = shape + + x = torch.randn((m, n), dtype=dtype) + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + block_size = 16 + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert x.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.') + + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(x.shape[0], 128) + scale_n = x.shape[1] // (2 * block_size) + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=x.device, + dtype=torch.int32) + + layer = SiluAndMul() + + ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale) + + fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale) + + assert ref_out.dtype == torch.uint8 + assert fusion_out.dtype == torch.uint8 + assert ref_out.shape == fusion_out.shape + + assert ref_out_scale.dtype == torch.int32 + assert fusion_out_scale.dtype == torch.int32 + assert ref_out_scale.shape == fusion_out_scale.shape + + # Allow up to 2% of mismatched values since BF16 has accuracy issues. 
+ mis_threshold = 0.02 + atol = 0.4 + rtol = 0.4 + ref_logits = ref_out[-1] + fusion_logits = fusion_out[-1] + + mis_count = torch.sum( + torch.abs(fusion_logits - ref_logits) > (atol + + rtol * torch.abs(ref_logits))) + mis_ratio = mis_count / fusion_logits.numel() + + assert mis_ratio < mis_threshold, \ + f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}" + + torch.testing.assert_close(ref_out_scale, fusion_out_scale) + + opcheck(torch.ops._C.silu_and_mul_nvfp4_quant, + (fusion_out, fusion_out_scale, x, global_scale)) diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 826014f770df..40e124a03eb0 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -1,55 +1,154 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, register_replacement) +from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 -def silu_mul_pattern_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) - return at2[1] +SILU_MUL_OP = torch.ops._C.silu_and_mul.default + +FUSED_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501 +} +if current_platform.is_cuda() and hasattr(torch.ops._C, + "silu_and_mul_nvfp4_quant"): + FUSED_OPS[ + kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -def silu_mul_replacement_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, - result=result, - input=input, - scale=scale) - return at[1] +class ActivationQuantPattern(ABC): + """ + The base class for Activation+Quant fusions. + Should not be used directly. 
+ """ + + def __init__( + self, + quant_key: QuantKey, + ): + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + + assert self.quant_key in QUANT_OPS, \ + f"unsupported quantization scheme {self.quant_key}" + self.QUANT_OP = QUANT_OPS[self.quant_key] + + assert self.quant_key in FUSED_OPS, \ + f"unsupported fusion scheme {self.quant_key}" + self.FUSED_OP = FUSED_OPS[self.quant_key] + + def empty_quant(self, *args, **kwargs): + kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + + @abstractmethod + def register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError -def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") +class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Fp8StaticQuant Pattern + """ + + def __init__(self, symmetric: bool = True): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(quant_key) + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at1 = auto_functionalized(SILU_MUL_OP, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale) + return at2[1] + + def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + scale=scale) + return at[1] + + inputs = [ + self.empty_quant(5, 4), # result + empty_bf16(5, 4), # result_silu_mul + empty_bf16(5, 4), # input + empty_fp32(1, 1) # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) -def empty_fp8(*args, **kwargs): - fp8 = current_platform.fp8_dtype() - return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") +class SiluMulNvfp4QuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Nvfp4Quant Pattern + """ + def __init__(self): + super().__init__(kNvfp4Quant) -def empty_fp32(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + def register(self, pm_pass: PatternMatcherPass): + + def pattern(result: torch.Tensor, output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(SILU_MUL_OP, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(self.QUANT_OP, + output=result, + input=at1[1], + output_scale=output_scale, + input_scale=scale) + return at2[1], at2[2] + + def replacement(result: torch.Tensor, output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + result_block_scale=output_scale, + input=input, + input_global_scale=scale) + return at[1], at[2] + + inputs = [ + self.empty_quant(5, 32), # result + empty_i32(128, 4), # output_scale + empty_bf16(5, 64), # result_silu_mul + empty_bf16(5, 64), # input + empty_fp32(1, 1) # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) class ActivationQuantFusionPass(VllmInductorPass): @@ -69,15 +168,11 @@ class ActivationQuantFusionPass(VllmInductorPass): self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="activation_quant_fusion_pass") - inputs = [ - empty_fp8(5, 4), # Quant output - empty_bf16(5, 4), # Silu_and_mul output - 
empty_bf16(5, 4), # Input - empty_fp32(1, 1) # Scale - ] - register_replacement(silu_mul_pattern_static, - silu_mul_replacement_static, inputs, fwd_only, - self.patterns) + pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern() + pattern_silu_mul_fp8.register(self.patterns) + + pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() + pattern_silu_mul_nvfp4.register(self.patterns) def __call__(self, graph: torch.fx.Graph): self.begin() @@ -89,3 +184,8 @@ class ActivationQuantFusionPass(VllmInductorPass): self.dump_graph(graph, "after_act_quant_fusion") self.end_and_log() + + def uuid(self): + return VllmInductorPass.hash_source(self, ActivationQuantPattern, + SiluMulFp8StaticQuantPattern, + SiluMulNvfp4QuantPattern) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 60ae14331879..a36dd8b845f1 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -97,6 +97,13 @@ class FixFunctionalizationPass(VllmInductorPass): node, mutated_args, args=('result', 'input', 'scale')) + elif at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default: + mutated_args = {1: 'result', 2: 'result_block_scale'} + self.defunctionalize(graph, + node, + mutated_args, + args=('result', 'result_block_scale', + 'input', 'input_global_scale')) else: continue # skip the count diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 9d4e453ffc54..1fbb2e3bb6f2 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -885,6 +885,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, requires_grad=False) + # Calculate `1 / input_scale` so that we don't need to do so at runtime + layer.input_scale_inv = Parameter( + (1 / layer.input_scale).to(torch.float32), requires_grad=False) + # Swizzle the weight blockscale. # contracting dimension is input dimension # block_size = 16; @@ -941,8 +945,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - s_quant = 1 / layer.input_scale - x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale
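
Note: the snippet below is not part of the diff. It is a minimal usage sketch of the new fused op, mirroring ops_impl() in tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py. It assumes a CUDA build of vLLM that registers torch.ops._C.silu_and_mul_nvfp4_quant (compute capability 10.0+); the helper name silu_mul_nvfp4_quant_example is illustrative only.

# Usage sketch for the fused SiLU-mul + NVFP4 quant op (assumed SM100+ build).
import torch

def silu_mul_nvfp4_quant_example(x: torch.Tensor, global_scale: torch.Tensor):
    # x: [m, 2 * d] fp16/bf16 gate-up activations.
    # global_scale: 0-dim fp32 tensor, (448.0 * 6.0) / amax, as in the test.
    m, two_d = x.shape
    d = two_d // 2
    round_up = lambda v, k: (v + k - 1) // k * k
    # Packed FP4 output: two e2m1 values per uint8, so d // 2 bytes per row.
    out = torch.empty((m, d // 2), dtype=torch.uint8, device=x.device)
    # Swizzled block scales, allocated with the same shape/dtype as the test.
    block_scale = torch.empty((round_up(m, 128), round_up(d // 16, 4) // 4),
                              dtype=torch.int32,
                              device=x.device)
    torch.ops._C.silu_and_mul_nvfp4_quant(out, block_scale, x, global_scale)
    return out, block_scale

x = torch.randn(128, 256, dtype=torch.float16, device="cuda")
global_scale = (448.0 * 6.0) / x.abs().max().to(torch.float32)
fp4_packed, block_scales = silu_mul_nvfp4_quant_example(x, global_scale)

This is the same call sequence that SiluMulNvfp4QuantPattern emits after fusion: the unfused graph runs silu_and_mul followed by scaled_fp4_quant, while the replacement invokes the single silu_and_mul_nvfp4_quant kernel with the pre-inverted input_scale stored by ModelOptNvFp4LinearMethod.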