/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"

namespace vllm {

// Computes silu(vec) * vec2, two packed elements at a time. SiLU is
// evaluated via the identity sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)), which
// maps directly onto h2tanh and __hfma2.
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
                                                   PackedVec<Type>& vec2) {
  PackedVec<Type> result;
#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
    if constexpr (std::is_same_v<Type, half>) {
      half2 val(0.5f, 0.5f);
      half2 t0 = __hmul2(vec.elts[i], val);
      half2 t1 = __hfma2(h2tanh(t0), val, val);
      half2 t2 = __hmul2(vec.elts[i], t1);
      result.elts[i] = __hmul2(t2, vec2.elts[i]);
    } else {
      __nv_bfloat162 val(0.5f, 0.5f);
      __nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
      __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
      __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
      result.elts[i] = __hmul2(t2, vec2.elts[i]);
    }
  }
  return result;
}

// Quantizes the provided PackedVec into the uint32_t output.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
                                                  PackedVec<Type>& vec2,
                                                  float SFScaleVal,
                                                  uint8_t* SFout) {
  PackedVec<Type> out_silu = compute_silu(vec, vec2);

  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(out_silu.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }
  // Get the output scale.
  // Recipe: outputScale = SFScaleVal / SFValue
  //                     = reciprocal(SFValue * reciprocal(SFScaleVal)).
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(
                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
                   : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
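  // Worked example of the scaling recipe above (illustrative values, not
  // taken from the original source): with SFScaleVal = 1.0f and
  // vecMax = 3.0f, SFValue = 3.0f / 6.0f = 0.5f, which E4M3 represents
  // exactly, so outputScale = 1.0f / 0.5f = 2.0f and the block maximum
  // maps to 3.0f * 2.0f = 6.0f, the largest e2m1 magnitude. The loop below
  // applies outputScale to all 8 SiLU outputs and packs them into eight
  // 4-bit e2m1 values (1 sign + 2 exponent + 1 mantissa bits) in one
  // uint32_t.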
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(out_silu.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Write the e2m1 values to global memory.
  return e2m1Vec;
}

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(1024, 4)
    silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                             float const* SFScale, uint32_t* out,
                             uint32_t* SFout) {
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size mismatch.");

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

  // Input tensor row/col loops. Each input row stores the two operands back
  // to back: the first numCols elements go through the SiLU, the second
  // numCols elements are the multiplicand.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
         colIdx += blockDim.x) {
      int64_t inOffset =
          rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx;
      int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) +
                          numCols / CVT_FP4_ELTS_PER_THREAD + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      PackedVec in_vec2 = reinterpret_cast<PackedVec const*>(in)[inOffset2];

      // Get the output tensor offset. The output row stride is
      // numCols / CVT_FP4_ELTS_PER_THREAD because 8 elements are packed
      // into one uint32_t.
      int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx, colIdx, numCols, SFout);

      out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
          in_vec, in_vec2, SFScaleVal, sf_out);
    }
  }
}

}  // namespace vllm

void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,  // [..., d]
                                     torch::Tensor& output_sf,
                                     torch::Tensor& input,  // [..., 2 * d]
                                     torch::Tensor& input_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1) / 2;

  TORCH_CHECK(n % 16 == 0, "The N dimension must be a multiple of 16.");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported input data type for quantize_to_fp4.");

  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  // Each thread converts ELTS_PER_THREAD values, so one block covers up to
  // 1024 * ELTS_PER_THREAD columns.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
  // Size the grid to keep roughly 2048 resident threads per SM.
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  VLLM_DISPATCH_HALF_TYPES(
      input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
        auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
        vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
            m, n, input_ptr, input_sf_ptr,
            reinterpret_cast<uint32_t*>(output_ptr),
            reinterpret_cast<uint32_t*>(sf_out));
      });
}
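
// Illustrative call-site sketch (assumptions, not taken from this file):
// shapes follow the checks above. For an input of shape [m, 2 * n] in half
// or bfloat16, output must hold m * n packed e2m1 values (two per byte),
// and output_sf one E4M3 scale factor per 16-element block in the swizzled
// layout computed by cvt_quant_to_fp4_get_sf_out_offset. A hypothetical
// libtorch caller might look like:
//
//   auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto input = torch::randn({m, 2 * n}, opts);
//   auto input_sf = torch::ones({1}, opts.dtype(torch::kFloat));
//   auto output = torch::empty({m, n / 2}, opts.dtype(torch::kUInt8));
//   auto output_sf = /* swizzled SF buffer allocated by the caller */;
//   silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);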