[Bugfix] Fix accuracy issue for silu_mul + nvfp4 quant fusion kernel (#24833)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
elvischenv 2025-09-18 07:37:23 +08:00 committed by GitHub
parent 2a4d6412e6
commit e6585ddb45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 110 additions and 226 deletions

View File

@ -796,7 +796,7 @@ steps:
# Quantization
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
- pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
- pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py

View File

@ -30,109 +30,41 @@
namespace vllm {
// silu in float32
__device__ __forceinline__ float silu(float x) {
return __fdividef(x, (1.f + __expf(-x)));
}
__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
PackedVec<Type>& vec2) {
__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
PackedVec<Type>& vec2) {
PackedVec<Type> result;
using packed_type = typename TypeConverter<Type>::Type;
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
// silu_mul in float32
if constexpr (std::is_same_v<Type, half>) {
half2 val(0.5f, 0.5f);
half2 t0 = __hmul2(vec.elts[i], val);
half2 t1 = __hfma2(h2tanh(t0), val, val);
half2 t2 = __hmul2(vec.elts[i], t1);
result.elts[i] = __hmul2(t2, vec2.elts[i]);
float2 silu_vec = silu2(__half22float2(vec.elts[i]));
result.elts[i] =
__float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
} else {
__nv_bfloat162 val(0.5f, 0.5f);
__nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
__nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
__nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
result.elts[i] = __hmul2(t2, vec2.elts[i]);
float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
result.elts[i] = __float22bfloat162_rn(
__fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
}
}
return result;
}
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
PackedVec<Type>& vec2,
float SFScaleVal,
uint8_t* SFout) {
PackedVec<Type> out_silu = compute_silu(vec, vec2);
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(out_silu.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(out_silu.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(1024, 4)
silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out,
uint32_t* SFout) {
using PackedVec = PackedVec<Type>;
@ -160,16 +92,18 @@ __global__ void __launch_bounds__(1024, 4)
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
;
auto& out_pos = out[outOffset];
// Compute silu and mul
PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2);
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numCols, SFout);
out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
in_vec, in_vec2, SFScaleVal, sf_out);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
sf_out);
}
}
}
@ -204,7 +138,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
m, n, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));

View File

@ -98,8 +98,9 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return [FUSED_OPS[kNvfp4Quant]]
@pytest.mark.parametrize("num_tokens", [64])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"model_class",
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
@ -110,13 +111,13 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
cuda_force_torch):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4")
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
torch.set_default_dtype(dtype)
x = torch.rand(num_tokens, hidden_size * 2)
@ -145,8 +146,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
torch.testing.assert_close(result[0].to(dtype=torch.float16),
result2[0].to(dtype=torch.float16),
torch.testing.assert_close(result[0].to(dtype=dtype),
result2[0].to(dtype=dtype),
atol=atol,
rtol=rtol)

View File

@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm._custom_ops import scaled_fp4_quant
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
FP4_DTYPE = torch.uint8
FP8_DTYPE = current_platform.fp8_dtype()
DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)]
BLOCK_SIZE = 16
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_silu_mul_nvfp4_quant(
dtype: torch.dtype,
shape: tuple[int, int],
) -> None:
current_platform.seed_everything(42)
device = 'cuda:0'
torch.set_default_device(device)
x = torch.randn(shape, dtype=dtype)
# ref op
ref_output = SiluAndMul().forward_native(x)
ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(ref_output).max().to(torch.float32))
ref_output_quant, ref_block_scale = scaled_fp4_quant(
ref_output, ref_global_scale)
# fused op
fused_output_quant = torch.empty_like(ref_output_quant)
fused_block_scale = torch.empty_like(ref_block_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
fused_block_scale, x,
ref_global_scale)
# check dtype
assert ref_output_quant.dtype == FP4_DTYPE
assert fused_output_quant.dtype == FP4_DTYPE
assert ref_output_quant.shape == fused_output_quant.shape
assert ref_block_scale.dtype == FP8_DTYPE
assert fused_block_scale.dtype == FP8_DTYPE
assert ref_block_scale.shape == fused_block_scale.shape
# check dequantized output
ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
ref_block_scale,
ref_global_scale, dtype,
device)
fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
fused_block_scale,
ref_global_scale, dtype,
device)
atol, rtol = 3e-1, 3e-1
torch.testing.assert_close(ref_output_dequant,
fused_output_dequant,
atol=atol,
rtol=rtol)

View File

@ -1,126 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
BLOCK_SIZE = 16
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
global_scale: torch.Tensor,
ref_output_scale: torch.Tensor) -> torch.Tensor:
silu_and_mul_out = silu_and_mul.forward_native(x)
assert not current_platform.is_rocm()
assert silu_and_mul_out.ndim >= 1, (
f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.')
other_dims = 1 if silu_and_mul_out.ndim == 1 else -1
silu_and_mul_out = silu_and_mul_out.reshape(other_dims,
silu_and_mul_out.shape[-1])
m, n = silu_and_mul_out.shape
device = silu_and_mul_out.device
# Two fp4 values will be packed into an uint8.
out = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
output_scale = ref_output_scale
torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale,
global_scale)
return out, output_scale
def ops_impl(x: torch.Tensor, global_scale: torch.Tensor,
ref_output_scale: torch.Tensor) -> torch.Tensor:
out_shape = (x.shape[0], x.shape[1] // 4)
output_scale = ref_output_scale
out = torch.empty(out_shape, dtype=torch.uint8, device=x.device)
torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale)
return out, output_scale
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_quantize_to_fp4(
dtype: torch.dtype,
shape: tuple[int, int],
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
m, n = shape
x = torch.randn((m, n), dtype=dtype)
tensor_amax = torch.abs(x).max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
block_size = 16
assert n % block_size == 0, (
f'last dim has to be multiple of 16, but got {n}.')
assert x.dtype in (torch.float16, torch.bfloat16), (
f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.')
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(x.shape[0], 128)
scale_n = x.shape[1] // (2 * block_size)
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty((rounded_m, rounded_n // 4),
device=x.device,
dtype=torch.int32)
layer = SiluAndMul()
ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale)
fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale)
assert ref_out.dtype == torch.uint8
assert fusion_out.dtype == torch.uint8
assert ref_out.shape == fusion_out.shape
assert ref_out_scale.dtype == torch.int32
assert fusion_out_scale.dtype == torch.int32
assert ref_out_scale.shape == fusion_out_scale.shape
# Allow up to 2% of mismatched values since BF16 has accuracy issues.
mis_threshold = 0.02
atol = 0.4
rtol = 0.4
ref_logits = ref_out[-1]
fusion_logits = fusion_out[-1]
mis_count = torch.sum(
torch.abs(fusion_logits - ref_logits) > (atol +
rtol * torch.abs(ref_logits)))
mis_ratio = mis_count / fusion_logits.numel()
assert mis_ratio < mis_threshold, \
f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}"
torch.testing.assert_close(ref_out_scale, fusion_out_scale)
opcheck(torch.ops._C.silu_and_mul_nvfp4_quant,
(fusion_out, fusion_out_scale, x, global_scale))