From 961a5ab423f548462a01112fb2bf58f34004445f Mon Sep 17 00:00:00 2001 From: c0de128 Date: Sun, 21 Dec 2025 20:13:16 -0600 Subject: [PATCH 1/3] [Bugfix][Hardware][AMD] Consolidate FP8 min/max values into helper function Add get_fp8_min_max() helper in quant_utils.py to centralize the FP8 min/max value logic for ROCm fnuz dtype handling. On ROCm with torch.float8_e4m3fnuz, using PyTorch's default finfo.max (240.0) causes accuracy issues with dynamic quantization. The correct value is 224.0 for fnuz dtype. This change: - Adds get_fp8_min_max(dtype) helper returning (fp8_min, fp8_max) tuple - Updates input_quant_fp8.py to use the helper - Updates fp8_utils.py per_token_group_quant_fp8() to use the helper - Updates deep_gemm.py per_block_cast_to_fp8() to use the helper - Updates tests/kernels/quant_utils.py to use the helper Fixes #30360 Signed-off-by: c0de128 --- tests/kernels/quant_utils.py | 36 ++++++------------- .../layers/quantization/input_quant_fp8.py | 11 +++--- .../layers/quantization/utils/fp8_utils.py | 12 +++---- .../layers/quantization/utils/quant_utils.py | 22 ++++++++++++ vllm/utils/deep_gemm.py | 6 +++- 5 files changed, 48 insertions(+), 39 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 7927bd0d200d8..479338b990a23 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -4,13 +4,13 @@ import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + group_broadcast, +) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -# Using the default value (240.0) from pytorch will cause accuracy -# issue on dynamic quantization models. Here use 224.0 for rocm. 
-ROCM_FP8FNUZ_MAX = 224.0 FP8_DTYPE = current_platform.fp8_dtype() @@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant( if scale_ub is not None: assert quant_dtype == FP8_DTYPE - qtype_traits = ( - torch.iinfo(quant_dtype) - if quant_dtype == torch.int8 - else torch.finfo(quant_dtype) - ) - use_fp8fnuz = ( - current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype() - ) - qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max - qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min + if quant_dtype == torch.int8: + qtype_traits = torch.iinfo(quant_dtype) + qtype_traits_min = qtype_traits.min + qtype_traits_max = qtype_traits.max + else: + qtype_traits_min, qtype_traits_max = get_fp8_min_max(quant_dtype) qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -72,17 +68,7 @@ def ref_dynamic_per_token_quant( def ref_dynamic_per_tensor_fp8_quant( x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ( - ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else fp8_traits.max - ) - fp8_traits_min = ( - -ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else fp8_traits.min - ) + fp8_traits_min, fp8_traits_max = get_fp8_min_max(FP8_DTYPE) fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 7994c838ad548..f57deceaf6ca5 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -7,15 +7,14 @@ import torch.nn.functional as F from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp -from 
vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + get_fp8_min_max, +) from vllm.platforms import current_platform -# Using the default value (240.0) from pytorch will cause accuracy -# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm. _FP8_DTYPE = current_platform.fp8_dtype() -_FP8_FINFO = torch.finfo(_FP8_DTYPE) -_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max -_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min +_FP8_MIN, _FP8_MAX = get_fp8_min_max(_FP8_DTYPE) _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index de6a1e8c1aa7d..55df073be444b 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -15,7 +15,10 @@ from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + get_fp8_min_max, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, ) @@ -748,12 +751,7 @@ def per_token_group_quant_fp8( ) assert x.stride(-1) == 1, "`x` groups must be contiguous" - # Using the default value (240.0) from pytorch will cause accuracy - # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm - # platforms that use the torch.float8_e4mefnuz dtype. 
- finfo = torch.finfo(dtype) - fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min - fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max + fp8_min, fp8_max = get_fp8_min_max(dtype) assert out_q is None or out_q.shape == x.shape x_q = out_q diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d01263f82007d..9829972cc3657 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -19,6 +19,28 @@ FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 +def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]: + """ + Get the min and max values for FP8 quantization. + + On ROCm with torch.float8_e4m3fnuz (fnuz), the default PyTorch finfo.max + (240.0) causes accuracy issues with dynamic quantization models. + Use 224.0 instead for fnuz dtype. + + Args: + dtype: FP8 dtype (defaults to platform's FP8 dtype if None) + + Returns: + Tuple of (fp8_min, fp8_max) values + """ + if dtype is None: + dtype = FP8_DTYPE + finfo = torch.finfo(dtype) + if current_platform.is_fp8_fnuz(): + return -224.0, 224.0 + return finfo.min, finfo.max + + # Use proxy as NamedTuple direct subclasses cannot have static members class _GroupShape(NamedTuple): row: int diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 56c9ca361eaef..e664442282036 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -16,6 +16,9 @@ import torch import vllm.envs as envs from vllm.logger import logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, +) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_gemm from vllm.utils.math_utils import cdiv @@ -355,7 +358,8 @@ def per_block_cast_to_fp8( x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, 
block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0 + _, fp8_max = get_fp8_min_max(fp8_dtype) + sf = x_amax / fp8_max sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( From 523dd610cbcf7b5494535f56f2422ede1f797a3c Mon Sep 17 00:00:00 2001 From: c0de128 Date: Sun, 21 Dec 2025 21:06:01 -0600 Subject: [PATCH 2/3] Fix dtype check in get_fp8_min_max helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address review feedback: Only apply the 224.0 override when both: 1. Platform supports fnuz (is_fp8_fnuz()) 2. The dtype is actually torch.float8_e4m3fnuz This prevents incorrect min/max values when a non-fnuz dtype is explicitly passed on a platform that supports fnuz. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: c0de128 --- vllm/model_executor/layers/quantization/utils/quant_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 9829972cc3657..d2eb768ede9a3 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -36,7 +36,8 @@ def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]: if dtype is None: dtype = FP8_DTYPE finfo = torch.finfo(dtype) - if current_platform.is_fp8_fnuz(): + # Only apply the 224.0 override for the actual fnuz dtype on fnuz platform + if current_platform.is_fp8_fnuz() and dtype == torch.float8_e4m3fnuz: return -224.0, 224.0 return finfo.min, finfo.max From 719ccfd773d15f1cd494c4b04fa0b727d82954b4 Mon Sep 17 00:00:00 2001 From: c0de128 Date: Tue, 23 Dec 2025 09:37:41 -0600 
Subject: [PATCH 3/3] Add unit tests for get_fp8_min_max() helper function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test_fp8_min_max_helper.py with mocked unit tests that verify: - Standard FP8 dtype uses PyTorch's finfo values - fnuz dtype on fnuz platform (MI300) returns 224.0, not 240.0 - Standard dtype on fnuz platform uses finfo values - fnuz dtype on non-fnuz platform uses finfo values These tests use mocking to verify the logic without requiring actual ROCm hardware. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: c0de128 --- .../quantization/test_fp8_min_max_helper.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/kernels/quantization/test_fp8_min_max_helper.py diff --git a/tests/kernels/quantization/test_fp8_min_max_helper.py b/tests/kernels/quantization/test_fp8_min_max_helper.py new file mode 100644 index 0000000000000..b19637b3adef5 --- /dev/null +++ b/tests/kernels/quantization/test_fp8_min_max_helper.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for the get_fp8_min_max() helper function. + +These tests verify the FP8 min/max value logic for both standard +and fnuz (ROCm MI300) dtype handling. 
+""" + +from unittest.mock import patch + +import pytest +import torch + + +class TestGetFp8MinMax: + """Test cases for get_fp8_min_max() function.""" + + def test_standard_fp8_dtype(self): + """Test that standard FP8 dtype uses PyTorch's finfo values.""" + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + ) + + # For standard float8_e4m3fn, should return finfo values + fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn) + finfo = torch.finfo(torch.float8_e4m3fn) + + # Standard FP8 max is 448.0 for e4m3fn + assert fp8_max == finfo.max, f"Expected finfo.max={finfo.max}, got {fp8_max}" + assert fp8_min == finfo.min, f"Expected finfo.min={finfo.min}, got {fp8_min}" + + @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") + def test_fnuz_fp8_dtype_on_fnuz_platform(self, mock_platform): + """Test that fnuz dtype on fnuz platform returns 224.0.""" + mock_platform.is_fp8_fnuz.return_value = True + mock_platform.fp8_dtype.return_value = torch.float8_e4m3fnuz + + # Do NOT reload quant_utils here: reload() re-executes the module, + # including its import of current_platform, which rebinds the real + # platform over the mock and silently undoes @patch. + import vllm.model_executor.layers.quantization.utils.quant_utils as qu + + # get_fp8_min_max resolves current_platform at call time, so the + # mock installed by @patch is what it sees. + fp8_min, fp8_max = qu.get_fp8_min_max(torch.float8_e4m3fnuz) + + # fnuz on ROCm MI300 should return 224.0, not 240.0 + assert fp8_max == 224.0, ( + f"Expected 224.0 for fnuz on fnuz platform, got {fp8_max}" + ) + assert fp8_min == -224.0, ( + f"Expected -224.0 for fnuz on fnuz platform, got {fp8_min}" + ) + + @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") + def test_standard_dtype_on_fnuz_platform(self, mock_platform): + """Test that standard dtype on fnuz platform uses finfo values.""" + mock_platform.is_fp8_fnuz.return_value = True + + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + ) + + # Standard e4m3fn dtype should use finfo even on fnuz platform + fp8_min, fp8_max =
get_fp8_min_max(torch.float8_e4m3fn) + finfo = torch.finfo(torch.float8_e4m3fn) + + assert fp8_max == finfo.max, ( + f"Standard dtype should use finfo.max={finfo.max}, got {fp8_max}" + ) + + @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") + def test_fnuz_dtype_on_non_fnuz_platform(self, mock_platform): + """Test that fnuz dtype on non-fnuz platform uses finfo values.""" + mock_platform.is_fp8_fnuz.return_value = False + + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + ) + + # fnuz dtype on non-fnuz platform should use finfo + fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz) + finfo = torch.finfo(torch.float8_e4m3fnuz) + + # Should be 240.0, not 224.0 (non-fnuz platform) + assert fp8_max == finfo.max, ( + f"Non-fnuz platform should use finfo.max={finfo.max}, got {fp8_max}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])