From b94f80ffb8f279b97931bdc2d39bd9328503d7d1 Mon Sep 17 00:00:00 2001
From: danielafrimi <45691845+danielafrimi@users.noreply.github.com>
Date: Tue, 23 Dec 2025 18:45:18 +0200
Subject: [PATCH] [FIX] FP4 quantization kernel padding initialization bug
 (#31097)

Signed-off-by: <>
Co-authored-by: root
Co-authored-by: root
---
 csrc/quantization/fp4/nvfp4_quant_kernels.cu | 46 ++++++++++++--------
 1 file changed, 28 insertions(+), 18 deletions(-)

diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
index 6acadb4cefd2c..8e38deeb6607f 100644
--- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
@@ -35,7 +35,7 @@ template <typename Int>
 __host__ __device__ inline Int round_up(Int x, Int y) {
   static_assert(std::is_integral_v<Int>,
                 "round_up argument must be integral type");
-  return (x + y - 1) / y * y;
+  return ((x + y - 1) / y) * y;
 }
 
 // Compute effective rows for grid configuration with swizzled SF layouts.
@@ -61,37 +61,47 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
   int sf_m = round_up(numRows, 128);
   int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
   int sf_n_int = round_up(sf_n_unpadded, 4) / 4;
-  for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
-    // Each thread writes 4 uint32_t elements.
-    for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
-         col += blockDim.x * 4) {
-      SFout[row * sf_n_int + col] = 0x00;
-    }
-  }
+  int num_padded_cols = sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE;
 
   // Get the global scaling factor, which will be applied to the SF.
   // Note SFScale is the same as next GEMM's alpha, which is
   // (448.f / (Alpha_A / 6.f)).
   float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
 
-  // Input tensor row/col loops.
-  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
-    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
+  // Iterate over all rows and columns, including the padded ones, so that
+  // every scale factor address is visited and initialized.
+  for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) {
+    for (int colIdx = threadIdx.x;
+         colIdx < num_padded_cols / CVT_FP4_ELTS_PER_THREAD;
          colIdx += blockDim.x) {
+      int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
+
+      PackedVec in_vec;
       int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
-      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-      // Get the output tensor offset.
-      // Same as inOffset because 8 elements are packed into one uint32_t.
-      int64_t outOffset = inOffset;
-      auto& out_pos = out[outOffset];
+
+      // Outside the valid rows or columns: feed zeros so the padded scale
+      // factors are still written deterministically.
+      if (rowIdx >= numRows || elem_idx >= numCols) {
+        memset(&in_vec, 0, sizeof(PackedVec));
+      } else {
+        // Valid region: load the actual input data.
+        in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+      }
 
       auto sf_out = cvt_quant_to_fp4_get_sf_out_offset(
           rowIdx, colIdx, numKTiles, SFout);
 
-      out_pos =
+      auto out_val =
           cvt_warp_fp16_to_fp4(in_vec, global_scale, sf_out);
+
+      // Do NOT write the output for padding, because the 'out' tensor is not
+      // padded.
+      if (rowIdx < numRows && elem_idx < numCols) {
+        // Same as inOffset because 8 elements are packed into one uint32_t.
+        out[inOffset] = out_val;
+      }
     }
   }
 }
@@ -134,4 +144,4 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
         m, n, input_ptr, input_sf_ptr, reinterpret_cast(output_ptr),
         reinterpret_cast(sf_out));
   });
-}
+}
\ No newline at end of file
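
Reviewer note (not part of the patch): the padded extents that the new loop
bounds rely on can be sanity-checked on the host. The sketch below re-derives
sf_m and num_padded_cols using the patch's round_up; the tensor shape and the
values assumed for CVT_FP4_SF_VEC_SIZE and CVT_FP4_ELTS_PER_THREAD are
illustrative only, not taken from the vLLM headers.

#include <cstdio>
#include <type_traits>

template <typename Int>
constexpr Int round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "round_up argument must be integral type");
  return ((x + y - 1) / y) * y;
}

int main() {
  // Assumed constants, for illustration only.
  constexpr int CVT_FP4_SF_VEC_SIZE = 16;     // elements covered by one scale factor
  constexpr int CVT_FP4_ELTS_PER_THREAD = 8;  // elements handled per thread
  int numRows = 200, numCols = 288;           // hypothetical tensor shape

  int sf_m = round_up(numRows, 128);                         // 256 padded rows
  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;         // 18 scale factors per row
  int sf_n_int = round_up(sf_n_unpadded, 4) / 4;             // 5 packed uint32 words per row
  int num_padded_cols = sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE;  // 320 columns incl. padding

  // The kernel's column loop now runs to num_padded_cols / CVT_FP4_ELTS_PER_THREAD
  // (40 iterations here), so every scale-factor word, padding included, is written.
  printf("sf_m=%d sf_n_int=%d num_padded_cols=%d col_iters=%d\n", sf_m, sf_n_int,
         num_padded_cols, num_padded_cols / CVT_FP4_ELTS_PER_THREAD);
  return 0;
}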
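
A second note, also outside the patch: the control flow introduced here
(zero-fill the input vector in the padded region, always compute and write the
scale factor, but only store the quantized output inside the valid region) can
be exercised in isolation. The toy CUDA kernel below mirrors that pattern on
plain float4 data; quantize_with_padding, pad_rows / pad_cols_vec and the
max-abs "scale" are invented for illustration and are not the real
cvt_fp16_to_fp4 helpers.

#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>

__global__ void quantize_with_padding(const float4* in, float4* out, float* sf,
                                      int rows, int cols_vec, int pad_rows,
                                      int pad_cols_vec) {
  for (int r = blockIdx.x; r < pad_rows; r += gridDim.x) {
    for (int c = threadIdx.x; c < pad_cols_vec; c += blockDim.x) {
      float4 v;
      bool valid = (r < rows) && (c < cols_vec);
      if (valid) {
        v = in[r * cols_vec + c];       // valid region: load real data
      } else {
        memset(&v, 0, sizeof(float4));  // padding: feed zeros instead
      }
      // The scale factor is written for every (r, c), so the padded part of
      // sf is deterministically initialized instead of holding whatever the
      // allocation happened to contain.
      float scale = fmaxf(fmaxf(fabsf(v.x), fabsf(v.y)),
                          fmaxf(fabsf(v.z), fabsf(v.w)));
      sf[r * pad_cols_vec + c] = scale;
      // The output tensor is not padded, so only store inside the valid region.
      if (valid) {
        out[r * cols_vec + c] = v;  // stand-in for the real quantized value
      }
    }
  }
}

int main() {
  const int rows = 3, cols_vec = 5;          // valid extent (in float4 units)
  const int pad_rows = 4, pad_cols_vec = 8;  // padded extent, as with swizzled SF layouts
  float4 *in, *out;
  float* sf;
  cudaMallocManaged(&in, rows * cols_vec * sizeof(float4));
  cudaMallocManaged(&out, rows * cols_vec * sizeof(float4));
  cudaMallocManaged(&sf, pad_rows * pad_cols_vec * sizeof(float));
  for (int i = 0; i < rows * cols_vec; ++i) in[i] = make_float4(1.f, -2.f, 3.f, -4.f);
  quantize_with_padding<<<2, 32>>>(in, out, sf, rows, cols_vec, pad_rows, pad_cols_vec);
  cudaDeviceSynchronize();
  // sf[0] comes from real data (4.0); the last, fully padded entry is 0.0.
  printf("sf[0]=%.1f sf[last]=%.1f\n", sf[0], sf[pad_rows * pad_cols_vec - 1]);
  cudaFree(in); cudaFree(out); cudaFree(sf);
  return 0;
}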