diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index be7044c41a732..55349e0ac9321 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -666,7 +666,7 @@ steps: # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py - # - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py + - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -676,7 +676,7 @@ steps: - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py - # - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 2728aa81f0c9f..995374a50b037 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -52,15 +52,6 @@ #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) -#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \ - AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__) - -#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \ - AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__) - -#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__)) - #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) diff --git a/csrc/ops.h b/csrc/ops.h index 7a176a5c00322..a288112e21000 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -130,8 +130,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +#ifndef USE_ROCM void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& output_block_scale, torch::Tensor& input, diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 9bbeb0334fb9a..b4eb141cb4883 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -26,164 +26,17 @@ #include "dispatch_utils.h" #include "cuda_utils.h" +#include "nvfp4_utils.cuh" namespace vllm { -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = c10::Half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = c10::BFloat16; -}; - -template <> -struct TypeConverter { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). 
-inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. 
-template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - template __inline__ __device__ PackedVec compute_silu(PackedVec& vec, PackedVec& vec2) { PackedVec result; #pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { half2 val(0.5f, 0.5f); half2 t0 = __hmul2(vec.elts[i], val); half2 t1 = __hfma2(h2tanh(t0), val, val); @@ -206,13 +59,12 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec& vec, PackedVec& vec2, float SFScaleVal, uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) PackedVec out_silu = compute_silu(vec, vec2); // Get absolute maximum values among the local 8 values. auto localMax = __habs2(out_silu.elts[0]); - // Local maximum value. - #pragma unroll +// Local maximum value. +#pragma unroll for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { localMax = __hmax2(localMax, __habs2(out_silu.elts[i])); } @@ -259,9 +111,9 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec& vec, // Convert the input to float. float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - #pragma unroll +#pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { fp2Vals[i] = __half22float2(out_silu.elts[i]); } else { fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]); @@ -275,22 +127,14 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec& vec, // Write the e2m1 values to global memory. return e2m1Vec; -#else - return 0; -#endif } // Use UE4M3 by default. template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4( -#else -silu_and_cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(1024, 4) + silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, + uint32_t* SFout) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -328,22 +172,25 @@ silu_and_cvt_fp16_to_fp4( in_vec, in_vec2, SFScaleVal, sf_out); } } -#endif } } // namespace vllm -void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] - torch::Tensor& output_sf, - torch::Tensor& input, // [..., 2 * d] - torch::Tensor& input_sf) { - TORCH_CHECK(input.dtype() == torch::kFloat16 || - input.dtype() == torch::kBFloat16); +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { int32_t m = input.size(0); int32_t n = input.size(1) / 2; + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + auto input_sf_ptr = static_cast(input_sf.data_ptr()); auto sf_out = static_cast(output_sf.data_ptr()); auto output_ptr = static_cast(output.data_ptr()); @@ -352,17 +199,14 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); int const 
numBlocksPerSM = 2048 / block.x; dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "act_and_mul_quant_kernel", [&] { - auto input_ptr = reinterpret_cast(input.data_ptr()); - VLLM_DISPATCH_BYTE_TYPES( - output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type", - [&] { - vllm::silu_and_cvt_fp16_to_fp4 - <<>>( - m, n, input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); + input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::silu_and_cvt_fp16_to_fp4<<>>( + m, n, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); }); } diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 03db5cc196d59..2c8df6144bf4d 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include #include diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 190d66f318a83..ce3ba2c19b9eb 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -1,247 +1,42 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include +#include +#include + #include #include -#include #include +#include "dispatch_utils.h" -template -struct TypeConverter { - using Type = half2; -}; // keep for generality +#include "nvfp4_utils.cuh" -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). 
-inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. 
-template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. 
template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts, + bool low_latency) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -299,8 +94,8 @@ cvt_fp16_to_fp4( &input_offset_by_experts[chunk_start + 12])); local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); - // Check against the 16 loaded offsets - #pragma unroll +// Check against the 16 loaded offsets +#pragma unroll for (int i = 0; i < 16; i++) { if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { rowIdx_in_expert = rowIdx - local_offsets[i]; @@ -330,21 +125,15 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } // Kernel for LARGE_M_TOPK = true (large m_topk optimized version) template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(1024, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(1024, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -425,7 +214,6 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } template @@ -501,6 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input, } } +} // namespace vllm + /*Quantization entry for fp4 experts quantization*/ #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ @@ -560,23 +350,17 @@ void scaled_fp4_experts_quant_sm100a( // 4 means 4 fp8 values are packed into one int32 TORCH_CHECK(output_scale.size(1) * 4 == padded_k); - auto in_dtype = input.dtype(); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); - if (in_dtype == at::ScalarType::Half) { - quant_impl(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, - n_experts, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, - k, n_experts, stream); - } else { - 
TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); - } + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 1b61bd4519fc3..c2b39e5438805 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a( torch::Tensor const& output_scale_offset_by_experts); #endif +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, + torch::Tensor& output_sf, + torch::Tensor& input, + torch::Tensor& input_sf); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ @@ -54,3 +62,13 @@ void scaled_fp4_experts_quant( TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); } + +void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, + torch::Tensor& input, torch::Tensor& input_sf) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf); +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled silu_and_mul nvfp4 quantization kernel"); +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 4e080de151648..0c1b9ef0664d7 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -23,245 +23,18 @@ #include #include +#include "dispatch_utils.h" #include "cuda_utils.h" +#include "nvfp4_utils.cuh" -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). 
-inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. 
-template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. 
template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -293,7 +66,6 @@ cvt_fp16_to_fp4( cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } } -#endif } template @@ -332,6 +104,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, int multiProcessorCount, cudaStream_t stream); +} // namespace vllm + void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, @@ -340,6 +114,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, int32_t n = input.size(1); TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); @@ -353,24 +130,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, // We don't support e8m0 scales at this moment. bool useUE8M0 = false; - switch (input.scalar_type()) { - case torch::kHalf: { - auto input_ptr = reinterpret_cast(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - case torch::kBFloat16: { - auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - default: { - std::cerr << "Observing: " << input.scalar_type() - << " for the input datatype which is invalid"; - throw std::runtime_error( - "Unsupported input data type for quantize_to_fp4."); - } - } + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, + sf_out, useUE8M0, multiProcessorCount, stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh new file mode 100644 index 0000000000000..48e4959de9793 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +namespace vllm { + +// Convert PyTorch cpp type to CUDA type +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. 
+ int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } + return nullptr; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +} + +} // namespace vllm diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 56626a02c0277..b769c09adc0f0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -115,8 +115,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! 
result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +#ifndef USE_ROCM ops.def( "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " "Tensor input, Tensor input_global_scale) -> ()"); diff --git a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py index 4325162ae94a9..969f14cc3fe62 100644 --- a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py +++ b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py @@ -8,8 +8,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -if not (current_platform.has_device_capability(100) - and hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")): +if not current_platform.has_device_capability(100): pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True)