mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:07:13 +08:00
[Bugfix] Fix accuracy issue for silu_mul + nvfp4 quant fusion kernel (#24833)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent
2a4d6412e6
commit
e6585ddb45
@ -796,7 +796,7 @@ steps:
|
||||
# Quantization
|
||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||
- pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
|
||||
- pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||
|
||||
@ -30,109 +30,41 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// silu in float32
|
||||
__device__ __forceinline__ float silu(float x) {
|
||||
return __fdividef(x, (1.f + __expf(-x)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 silu2(float2 x) {
|
||||
return make_float2(silu(x.x), silu(x.y));
|
||||
}
|
||||
|
||||
template <class Type>
|
||||
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2) {
|
||||
__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2) {
|
||||
PackedVec<Type> result;
|
||||
using packed_type = typename TypeConverter<Type>::Type;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
||||
// silu_mul in float32
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
half2 val(0.5f, 0.5f);
|
||||
half2 t0 = __hmul2(vec.elts[i], val);
|
||||
half2 t1 = __hfma2(h2tanh(t0), val, val);
|
||||
half2 t2 = __hmul2(vec.elts[i], t1);
|
||||
result.elts[i] = __hmul2(t2, vec2.elts[i]);
|
||||
float2 silu_vec = silu2(__half22float2(vec.elts[i]));
|
||||
result.elts[i] =
|
||||
__float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
|
||||
} else {
|
||||
__nv_bfloat162 val(0.5f, 0.5f);
|
||||
__nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
|
||||
__nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
|
||||
__nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
|
||||
result.elts[i] = __hmul2(t2, vec2.elts[i]);
|
||||
float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
|
||||
result.elts[i] = __float22bfloat162_rn(
|
||||
__fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2,
|
||||
float SFScaleVal,
|
||||
uint8_t* SFout) {
|
||||
PackedVec<Type> out_silu = compute_silu(vec, vec2);
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(out_silu.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
|
||||
}
|
||||
|
||||
// Get the absolute maximum among all 16 values (two threads).
|
||||
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
// Get the final absolute maximum values.
|
||||
float vecMax = float(__hmax(localMax.x, localMax.y));
|
||||
|
||||
// Get the SF (max value of the vector / max value of e2m1).
|
||||
// maximum value of e2m1 = 6.0.
|
||||
// TODO: use half as compute data type.
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
// 8 bits representation of the SF.
|
||||
uint8_t fp8SFVal;
|
||||
// Encode the SF into its 8-bit representation (it is stored to global memory further below).
|
||||
if constexpr (UE8M0_SF) {
|
||||
// Extract the 8 exponent bits from float32.
|
||||
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
|
||||
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
|
||||
fp8SFVal = tmp & 0xff;
|
||||
// Convert back to fp32.
|
||||
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
|
||||
} else {
|
||||
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
|
||||
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
|
||||
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
|
||||
// Convert back to fp32.
|
||||
SFValue = float(tmp);
|
||||
}
|
||||
// Get the output scale.
|
||||
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) *
|
||||
// reciprocal(SFScaleVal))
|
||||
float outputScale =
|
||||
SFValue != 0 ? reciprocal_approximate_ftz(
|
||||
SFValue * reciprocal_approximate_ftz(SFScaleVal))
|
||||
: 0.0f;
|
||||
|
||||
if (SFout) {
|
||||
// Write the SF to global memory (STG.8).
|
||||
*SFout = fp8SFVal;
|
||||
}
|
||||
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(out_silu.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
// Convert to e2m1 values.
|
||||
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
|
||||
|
||||
// Return the packed e2m1 values; the caller writes them to global memory.
|
||||
return e2m1Vec;
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void __launch_bounds__(1024, 4)
|
||||
silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out,
|
||||
uint32_t* SFout) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
@ -160,16 +92,18 @@ __global__ void __launch_bounds__(1024, 4)
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
|
||||
;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Compute silu and mul
|
||||
PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2);
|
||||
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx, colIdx, numCols, SFout);
|
||||
|
||||
out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
|
||||
in_vec, in_vec2, SFScaleVal, sf_out);
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
|
||||
sf_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -204,7 +138,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
||||
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
|
||||
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
|
||||
m, n, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
|
||||
@ -98,8 +98,9 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
return [FUSED_OPS[kNvfp4Quant]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [64])
|
||||
@pytest.mark.parametrize("hidden_size", [128])
|
||||
@pytest.mark.parametrize("num_tokens", [32, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [128, 256])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"model_class",
|
||||
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||
@ -110,13 +111,13 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
[True, False] if cutlass_fp8_supported() else [True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
cuda_force_torch):
|
||||
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
||||
pytest.skip("Duplicate tests for NVFP4")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(torch.float16)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
|
||||
@ -145,8 +146,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
|
||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
|
||||
torch.testing.assert_close(result[0].to(dtype=torch.float16),
|
||||
result2[0].to(dtype=torch.float16),
|
||||
torch.testing.assert_close(result[0].to(dtype=dtype),
|
||||
result2[0].to(dtype=dtype),
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
|
||||
|
||||
75
tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
Normal file
75
tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
Normal file
@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
from vllm._custom_ops import scaled_fp4_quant
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
FP4_DTYPE = torch.uint8
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)]
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_mul_nvfp4_quant(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int],
|
||||
) -> None:
|
||||
current_platform.seed_everything(42)
|
||||
device = 'cuda:0'
|
||||
torch.set_default_device(device)
|
||||
|
||||
x = torch.randn(shape, dtype=dtype)
|
||||
|
||||
# ref op
|
||||
ref_output = SiluAndMul().forward_native(x)
|
||||
ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.abs(ref_output).max().to(torch.float32))
|
||||
ref_output_quant, ref_block_scale = scaled_fp4_quant(
|
||||
ref_output, ref_global_scale)
|
||||
|
||||
# fused op
|
||||
fused_output_quant = torch.empty_like(ref_output_quant)
|
||||
fused_block_scale = torch.empty_like(ref_block_scale)
|
||||
torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
|
||||
fused_block_scale, x,
|
||||
ref_global_scale)
|
||||
|
||||
# check dtype
|
||||
assert ref_output_quant.dtype == FP4_DTYPE
|
||||
assert fused_output_quant.dtype == FP4_DTYPE
|
||||
assert ref_output_quant.shape == fused_output_quant.shape
|
||||
|
||||
assert ref_block_scale.dtype == FP8_DTYPE
|
||||
assert fused_block_scale.dtype == FP8_DTYPE
|
||||
assert ref_block_scale.shape == fused_block_scale.shape
|
||||
|
||||
# check dequantized output
|
||||
ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
|
||||
ref_block_scale,
|
||||
ref_global_scale, dtype,
|
||||
device)
|
||||
fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
|
||||
fused_block_scale,
|
||||
ref_global_scale, dtype,
|
||||
device)
|
||||
|
||||
atol, rtol = 3e-1, 3e-1
|
||||
torch.testing.assert_close(ref_output_dequant,
|
||||
fused_output_dequant,
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
@ -1,126 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
|
||||
global_scale: torch.Tensor,
|
||||
ref_output_scale: torch.Tensor) -> torch.Tensor:
|
||||
silu_and_mul_out = silu_and_mul.forward_native(x)
|
||||
assert not current_platform.is_rocm()
|
||||
assert silu_and_mul_out.ndim >= 1, (
|
||||
f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.')
|
||||
other_dims = 1 if silu_and_mul_out.ndim == 1 else -1
|
||||
silu_and_mul_out = silu_and_mul_out.reshape(other_dims,
|
||||
silu_and_mul_out.shape[-1])
|
||||
m, n = silu_and_mul_out.shape
|
||||
device = silu_and_mul_out.device
|
||||
|
||||
# Two fp4 values will be packed into a uint8.
|
||||
out = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
|
||||
output_scale = ref_output_scale
|
||||
|
||||
torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale,
|
||||
global_scale)
|
||||
|
||||
return out, output_scale
|
||||
|
||||
|
||||
def ops_impl(x: torch.Tensor, global_scale: torch.Tensor,
|
||||
ref_output_scale: torch.Tensor) -> torch.Tensor:
|
||||
out_shape = (x.shape[0], x.shape[1] // 4)
|
||||
output_scale = ref_output_scale
|
||||
out = torch.empty(out_shape, dtype=torch.uint8, device=x.device)
|
||||
torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale)
|
||||
return out, output_scale
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
m, n = shape
|
||||
|
||||
x = torch.randn((m, n), dtype=dtype)
|
||||
tensor_amax = torch.abs(x).max().to(torch.float32)
|
||||
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
|
||||
block_size = 16
|
||||
|
||||
assert n % block_size == 0, (
|
||||
f'last dim has to be multiple of 16, but got {n}.')
|
||||
assert x.dtype in (torch.float16, torch.bfloat16), (
|
||||
f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.')
|
||||
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(x.shape[0], 128)
|
||||
scale_n = x.shape[1] // (2 * block_size)
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
output_scale = torch.empty((rounded_m, rounded_n // 4),
|
||||
device=x.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
layer = SiluAndMul()
|
||||
|
||||
ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale)
|
||||
|
||||
fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale)
|
||||
|
||||
assert ref_out.dtype == torch.uint8
|
||||
assert fusion_out.dtype == torch.uint8
|
||||
assert ref_out.shape == fusion_out.shape
|
||||
|
||||
assert ref_out_scale.dtype == torch.int32
|
||||
assert fusion_out_scale.dtype == torch.int32
|
||||
assert ref_out_scale.shape == fusion_out_scale.shape
|
||||
|
||||
# Allow up to 2% of mismatched values since BF16 has accuracy issues.
|
||||
mis_threshold = 0.02
|
||||
atol = 0.4
|
||||
rtol = 0.4
|
||||
ref_logits = ref_out[-1]
|
||||
fusion_logits = fusion_out[-1]
|
||||
|
||||
mis_count = torch.sum(
|
||||
torch.abs(fusion_logits - ref_logits) > (atol +
|
||||
rtol * torch.abs(ref_logits)))
|
||||
mis_ratio = mis_count / fusion_logits.numel()
|
||||
|
||||
assert mis_ratio < mis_threshold, \
|
||||
f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}"
|
||||
|
||||
torch.testing.assert_close(ref_out_scale, fusion_out_scale)
|
||||
|
||||
opcheck(torch.ops._C.silu_and_mul_nvfp4_quant,
|
||||
(fusion_out, fusion_out_scale, x, global_scale))
|
||||
Loading…
x
Reference in New Issue
Block a user