[Bugfix] Fix fused quant layernorm tests (#27865)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
ElizaWszola committed via GitHub on 2025-11-08 23:31:33 +01:00
commit 171133f929 (parent 32787d0644)
2 changed files with 28 additions and 10 deletions

File 1 of 2 — CUDA int8 quantization kernels:

@@ -1,5 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cmath>
@@ -275,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -306,6 +308,7 @@ void dynamic_scaled_int8_quant(
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
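Why the guard: both dispatch functions query at::cuda::getCurrentCUDAStream() and launch on it, so without pinning the current device to the input's device the kernel can land on the wrong GPU's stream whenever the input lives off the default device. A minimal Python-side sketch of the case the OptionalCUDAGuard covers (assumes a machine with at least two GPUs and vLLM's compiled custom ops; the import path is the one commonly used in vLLM's kernel tests):

import torch
from vllm import _custom_ops as ops

# Current device is cuda:0, but the input lives on cuda:1. With the
# device guard in place, the kernel launch follows the input's device.
torch.cuda.set_device(0)
x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda:1")
out, scales, _ = ops.scaled_int8_quant(x)  # dynamic per-token int8 quant
assert out.device == x.device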

File 2 of 2 — fused quant layernorm tests:

@@ -11,7 +11,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
-VEC_HIDDEN_SIZES = range(1024, 1030)
+VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
 # Avoid combinatorial explosion with full Cartesian product
 NUM_TOKENS_HIDDEN_SIZES = [
     *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
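Presumably the explicit list keeps one aligned size (1024) plus odd sizes (1025, 1027, 1029) that hit the unaligned/remainder paths of the vectorized kernels, while dropping 1026 and 1028 to shrink the (num_tokens, hidden_size) product flagged in the comment above.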
@@ -65,7 +65,7 @@ def ref_dynamic_per_token_quant(
         )
     else:
         assert quant_dtype == torch.int8
-        torch_out, scales = ops.scaled_int8_quant(torch_out)
+        torch_out, scales, _ = ops.scaled_int8_quant(torch_out)

     return torch_out, scales, residual
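The extra _ reflects that ops.scaled_int8_quant returns a triple (output, scales, azp), where azp is the asymmetric zero point and is None under the default symmetric quantization used here (hedged from the call site above; the keyword name below is an assumption about the op's signature). A small sketch:

import torch
from vllm import _custom_ops as ops

x = torch.randn(2, 8, dtype=torch.float32, device="cuda")
out, scales, azp = ops.scaled_int8_quant(x)  # symmetric by default
assert out.dtype == torch.int8
assert azp is None  # a zero point only exists for asymmetric quantization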
@@ -109,7 +109,7 @@ def ops_impl(
 @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
-@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
@@ -119,7 +119,7 @@ def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
-    scale_ub: bool,
+    has_scale_ub: bool,
     dtype: torch.dtype,
     quant_dtype: torch.dtype,
     seed: int,
@@ -130,7 +130,7 @@ def test_rms_norm(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)

-    if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
+    if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
         # skip
         return
@@ -143,9 +143,11 @@ def test_rms_norm(
     scale = 1 / (hidden_size)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
     residual = torch.randn_like(x) * scale if add_residual else None
-    if scale_ub is not None:
+    if has_scale_ub:
         rms_x, _ = ref_rms_norm(layer, x, residual)
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
+    else:
+        scale_ub = None

     ref_out, ref_scales, ref_residual = ref_impl(
         layer, x, quant_dtype, residual, scale_ub
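Likely the bug this hunk fixes (inferred from the diff; the SCALE_UBS list itself is not shown, but the decorator suggests it holds booleans): a parametrized bool is never None, so the old scale_ub is not None check took the compute-an-upper-bound branch for both parameter values, and a None upper bound was never exercised:

scale_ub = False             # old parametrization bound a bool here
assert scale_ub is not None  # holds! -- the no-ub path was unreachable

Renaming the parameter to has_scale_ub and adding an explicit else: scale_ub = None restores both code paths.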
@@ -156,14 +158,27 @@ def test_rms_norm(
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
-    assert torch.allclose(ref_scales, ops_scales)
     if quant_dtype == torch.int8:
+        assert torch.allclose(ref_scales, ops_scales, atol=1e-6)
         # big atol to account for round-off errors.
         assert torch.allclose(ref_out, ops_out, atol=1)
     else:
-        assert torch.allclose(
-            ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
-        )
+        assert torch.allclose(ref_scales, ops_scales)
+        a = ref_out.to(dtype=torch.float32)
+        b = ops_out.to(dtype=torch.float32)
+        ok = torch.allclose(a, b)
+        if not ok:
+            # fallback: compare dequantized values with relaxed tolerance
+            a_deq = a * ref_scales.view(-1, 1)
+            b_deq = b * ops_scales.view(-1, 1)
+            # NOTE: It is possible that some future test cases trigger this
+            # max diff due to precision issues. If such an error is
+            # encountered, it's recommended to inspect the differences
+            # between corresponding elements of the two tensors (e.g. by
+            # looping over them) and to check how many elements the max
+            # diff shows up on (just a few bad elements should still be
+            # considered acceptable).
+            ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
+        assert ok

     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
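On the relaxed fp8 fallback: two correct implementations can round a borderline value to adjacent fp8 codes, so allclose on the raw quantized values may fail spuriously even when both outputs are valid. Comparing in dequantized space (code times per-token scale) with a loose tolerance accepts a one-code difference while still flagging systematic errors. A standalone sketch of the fallback (the helper name is illustrative, not from the repo):

import torch

def close_after_dequant(
    ref_q: torch.Tensor,
    ops_q: torch.Tensor,
    ref_scales: torch.Tensor,
    ops_scales: torch.Tensor,
) -> bool:
    # Per-token dequantization: each row was quantized with its own scale.
    a = ref_q.to(torch.float32) * ref_scales.view(-1, 1)
    b = ops_q.to(torch.float32) * ops_scales.view(-1, 1)
    # Loose tolerance: adjacent-code round-off passes, real bugs fail.
    return torch.allclose(a, b, rtol=5e-2, atol=5e-2)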