[Bugfix][Hardware][AMD] Consolidate FP8 min/max values into helper function

Add get_fp8_min_max() helper in quant_utils.py to centralize the
FP8 min/max value logic for ROCm fnuz dtype handling.

On ROCm with torch.float8_e4m3fnuz, using PyTorch's default finfo.max
(240.0) causes accuracy issues with dynamic quantization. For the fnuz
dtype, 224.0 is used as the maximum instead.

This change:
- Adds get_fp8_min_max(dtype) helper returning (fp8_min, fp8_max) tuple
- Updates input_quant_fp8.py to use the helper
- Updates fp8_utils.py per_token_group_quant_fp8() to use the helper
- Updates deep_gemm.py per_block_cast_to_fp8() to use the helper
- Updates tests/kernels/quant_utils.py to use the helper

Fixes #30360

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
This commit is contained in:
c0de128 2025-12-21 20:13:16 -06:00
parent 09dc7c690c
commit 961a5ab423
5 changed files with 48 additions and 39 deletions

View File

@ -4,13 +4,13 @@
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
group_broadcast,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant(
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
qtype_traits = (
torch.iinfo(quant_dtype)
if quant_dtype == torch.int8
else torch.finfo(quant_dtype)
)
use_fp8fnuz = (
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
)
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
if quant_dtype == torch.int8:
qtype_traits = torch.iinfo(quant_dtype)
qtype_traits_min = qtype_traits.min
qtype_traits_max = qtype_traits.max
else:
qtype_traits_min, qtype_traits_max = get_fp8_min_max(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@ -72,17 +68,7 @@ def ref_dynamic_per_token_quant(
def ref_dynamic_per_tensor_fp8_quant(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.max
)
fp8_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.min
)
fp8_traits_min, fp8_traits_max = get_fp8_min_max(FP8_DTYPE)
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

View File

@ -7,15 +7,14 @@ import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
)
from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
_FP8_MIN, _FP8_MAX = get_fp8_min_max(_FP8_DTYPE)
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)

View File

@ -15,7 +15,10 @@ from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
)
@ -748,12 +751,7 @@ def per_token_group_quant_fp8(
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4m3fnuz dtype.
finfo = torch.finfo(dtype)
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
fp8_min, fp8_max = get_fp8_min_max(dtype)
assert out_q is None or out_q.shape == x.shape
x_q = out_q

View File

@ -19,6 +19,28 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]:
    """
    Get the min and max values for FP8 quantization.

    On platforms where ``current_platform.is_fp8_fnuz()`` is true, the
    default PyTorch ``finfo.max`` for ``torch.float8_e4m3fnuz`` (240.0)
    causes accuracy issues with dynamic quantization models, so 224.0 is
    used instead for that dtype. Every other dtype is reported with its own
    ``torch.finfo`` limits, even on an fnuz platform.

    Args:
        dtype: FP8 dtype (defaults to the platform's FP8 dtype if None).

    Returns:
        Tuple of (fp8_min, fp8_max) values.
    """
    if dtype is None:
        dtype = FP8_DTYPE

    # The 224.0 override only applies to the fnuz dtype itself; applying it
    # to any other dtype requested on an fnuz platform would silently return
    # bounds that do not belong to that dtype.
    if current_platform.is_fp8_fnuz() and dtype == torch.float8_e4m3fnuz:
        return -224.0, 224.0

    finfo = torch.finfo(dtype)
    return finfo.min, finfo.max
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int

View File

@ -16,6 +16,9 @@ import torch
import vllm.envs as envs
from vllm.logger import logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv
@ -355,7 +358,8 @@ def per_block_cast_to_fp8(
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
_, fp8_max = get_fp8_min_max(fp8_dtype)
sf = x_amax / fp8_max
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(