From 08f8a5627e6306312202d5c0fda8b7255d2f27fc Mon Sep 17 00:00:00 2001
From: rasmith
Date: Fri, 12 Dec 2025 17:41:56 -0600
Subject: [PATCH] [CI/Build][Kernel][BugFix][AMD] Fix per_token_group_quant_fp8
 to use correct fp8 min/max values and update atol/rtol in
 test_quantfp8_group_functionality (#30292)

Signed-off-by: Randall Smith
Co-authored-by: Randall Smith
---
 tests/kernels/quantization/test_fp8_quant_group.py         | 4 ++--
 vllm/model_executor/layers/quantization/utils/fp8_utils.py | 7 +++++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py
index 6628ac650fd5f..f5e1cde94b6e9 100644
--- a/tests/kernels/quantization/test_fp8_quant_group.py
+++ b/tests/kernels/quantization/test_fp8_quant_group.py
@@ -62,7 +62,7 @@ def test_quantfp8_group_functionality(
     assert scales_col.stride(1) == batch_size
 
     # Test column-major scales consistency
-    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
+    torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)
 
     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:
@@ -71,7 +71,7 @@
         assert scales_cuda.shape == (batch_size, expected_num_groups)
 
         # Verify CUDA/native consistency
-        assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)
+        torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)
 
         # Quantized values should mostly match
         diff_count = (x_quant_cuda != x_quant_native).sum().item()
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index e12fe61bf3d97..9eeb6e266c34e 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -762,9 +762,12 @@ def per_token_group_quant_fp8(
     )
     assert x.stride(-1) == 1, "`x` groups must be contiguous"
 
+    # Using the default value (240.0) from PyTorch causes accuracy issues
+    # with dynamic quantization models, so use 224.0 for fnuz on ROCm
+    # platforms that use the torch.float8_e4m3fnuz dtype.
     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
-    fp8_max = finfo.max
+    fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
+    fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
 
     assert out_q is None or out_q.shape == x.shape
     x_q = out_q
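
Below is a minimal, self-contained sketch (not the actual vLLM Triton kernel) of the behavior this patch encodes: on ROCm platforms that use torch.float8_e4m3fnuz, torch.finfo reports a max of 240.0, but the quantization path now clamps to +/-224.0 instead. The helper names `fnuz_safe_range` and `quant_group` are illustrative only and are not part of the vLLM API.

```python
# Illustrative sketch only -- not the vLLM implementation. The helpers below
# mirror the patched min/max selection in per_token_group_quant_fp8.
import torch


def fnuz_safe_range(dtype: torch.dtype, is_fp8_fnuz: bool) -> tuple[float, float]:
    """Pick the (min, max) used for scaling/clamping a group to fp8.

    torch.finfo(torch.float8_e4m3fnuz).max is 240.0, but the patch clamps
    fnuz (ROCm) quantization to +/-224.0 to avoid accuracy issues with
    dynamic quantization.
    """
    finfo = torch.finfo(dtype)
    if is_fp8_fnuz:
        return -224.0, 224.0
    return finfo.min, finfo.max


def quant_group(x: torch.Tensor, dtype: torch.dtype, is_fp8_fnuz: bool):
    """Per-group dynamic quantization: scale the group so its max magnitude
    maps to fp8_max, then clamp and cast."""
    fp8_min, fp8_max = fnuz_safe_range(dtype, is_fp8_fnuz)
    amax = x.abs().amax().clamp(min=1e-10)
    scale = amax / fp8_max
    x_q = (x / scale).clamp(fp8_min, fp8_max).to(dtype)
    return x_q, scale


if __name__ == "__main__":
    x = torch.randn(1, 128)
    # float8_e4m3fn (CUDA) keeps the full finfo range (max 448.0);
    # float8_e4m3fnuz (ROCm fnuz) would instead be clamped to +/-224.0.
    x_q, scale = quant_group(x, torch.float8_e4m3fn, is_fp8_fnuz=False)
    print(x_q.dtype, x_q.shape, scale.item())
```

On the test side, torch.testing.assert_close is a drop-in replacement for `assert torch.allclose(...)`: it accepts the same rtol/atol keywords but raises with a detailed mismatch report rather than a bare AssertionError, which makes the retuned tolerances (rtol=2e-7, atol=2e-8) easier to diagnose when they are exceeded.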