/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <type_traits>

#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
#include "nvfp4_utils.cuh"

namespace vllm {

template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "round_up argument must be integral type");
  return (x + y - 1) / y * y;
}

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                    float const* SFScale, uint32_t* out, uint32_t* SFout) {
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // The swizzled SF layout pads rows to a multiple of 128 and packs 4 FP8
  // scales per uint32_t; zero-initialize the padded region.
  int sf_m = round_up(numRows, 128);
  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
  int sf_n_int = round_up(sf_n_unpadded, 4) / 4;
  for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
    // Each thread writes 4 uint32_t elements.
    for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
         col += blockDim.x * 4) {
      SFout[row * sf_n_int + col] = 0x00;
    }
  }

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
         colIdx += blockDim.x) {
      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx, colIdx, numCols, SFout);

      out_pos =
          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
    }
  }
}

template <typename T>
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
                           int64_t* output, int32_t* SFOuput, bool useUE8M0,
                           int multiProcessorCount, cudaStream_t stream) {
  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
  // Get number of blocks per SM.
  int const numBlocksPerSM =
      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
  if (useUE8M0) {
    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
        reinterpret_cast<uint32_t*>(SFOuput));
  } else {
    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
        reinterpret_cast<uint32_t*>(SFOuput));
  }
}

// Instantiate the function.
template void invokeFP4Quantization(int m, int n, half const* input,
                                    float const* SFScale, int64_t* output,
                                    int32_t* SFOuput, bool useUE8M0,
                                    int multiProcessorCount,
                                    cudaStream_t stream);

template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
                                    float const* SFScale, int64_t* output,
                                    int32_t* SFOuput, bool useUE8M0,
                                    int multiProcessorCount,
                                    cudaStream_t stream);

}  // namespace vllm

void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
                             torch::Tensor const& input,
                             torch::Tensor const& output_sf,
                             torch::Tensor const& input_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported input data type for quantize_to_fp4.");

  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  // We don't support e8m0 scales at this moment.
  bool useUE8M0 = false;

  VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
    using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
    auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
    vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
                                sf_out, useUE8M0, multiProcessorCount,
                                stream);
  });
}
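
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the vLLM API). It shows the
// tensor shapes and dtypes the entry point above expects, as inferred from the
// kernel: the packed FP4 output holds two values per byte ([m, n/2] uint8),
// and the swizzled scale-factor tensor covers round_up(m, 128) rows of
// round_up(n/16, 4) FP8 scales, stored here as int32 (4 scales per element).
// The helper name `example_fp4_quant` is hypothetical; the real caller in
// vLLM allocates these buffers on the Python side.
[[maybe_unused]] static torch::Tensor example_fp4_quant(
    torch::Tensor const& input, torch::Tensor const& global_scale) {
  TORCH_CHECK(input.dim() == 2, "expected a 2-D [m, n] input");
  int64_t m = input.size(0);
  int64_t n = input.size(1);
  auto opts = torch::TensorOptions().device(input.device());
  // Two FP4 values are packed into each output byte, so n/2 bytes per row.
  auto output = torch::empty({m, n / 2}, opts.dtype(torch::kUInt8));
  // Swizzled scale factors, padded to 128 rows and 4 scales per int32.
  int64_t sf_m = (m + 127) / 128 * 128;
  int64_t sf_n = (n / 16 + 3) / 4 * 4;
  auto output_sf = torch::empty({sf_m, sf_n / 4}, opts.dtype(torch::kInt32));
  scaled_fp4_quant_sm1xxa(output, input, output_sf, global_scale);
  return output;
}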