Merge 719ccfd773d15f1cd494c4b04fa0b727d82954b4 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
Kevin McKay 2025-12-25 08:06:38 +08:00 committed by GitHub
commit 3a557ea3da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 142 additions and 39 deletions

View File

@ -4,13 +4,13 @@
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
group_broadcast,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant(
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
qtype_traits = (
torch.iinfo(quant_dtype)
if quant_dtype == torch.int8
else torch.finfo(quant_dtype)
)
use_fp8fnuz = (
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
)
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
if quant_dtype == torch.int8:
qtype_traits = torch.iinfo(quant_dtype)
qtype_traits_min = qtype_traits.min
qtype_traits_max = qtype_traits.max
else:
qtype_traits_min, qtype_traits_max = get_fp8_min_max(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@ -72,17 +68,7 @@ def ref_dynamic_per_token_quant(
def ref_dynamic_per_tensor_fp8_quant(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.max
)
fp8_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.min
)
fp8_traits_min, fp8_traits_max = get_fp8_min_max(FP8_DTYPE)
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

View File

@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for the get_fp8_min_max() helper function.
These tests verify the FP8 min/max value logic for both standard
and fnuz (ROCm MI300) dtype handling.
"""
from unittest.mock import patch
import pytest
import torch
class TestGetFp8MinMax:
"""Test cases for get_fp8_min_max() function."""
def test_standard_fp8_dtype(self):
"""Test that standard FP8 dtype uses PyTorch's finfo values."""
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
# For standard float8_e4m3fn, should return finfo values
fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn)
finfo = torch.finfo(torch.float8_e4m3fn)
# Standard FP8 max is 448.0 for e4m3fn
assert fp8_max == finfo.max, f"Expected finfo.max={finfo.max}, got {fp8_max}"
assert fp8_min == finfo.min, f"Expected finfo.min={finfo.min}, got {fp8_min}"
@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
def test_fnuz_fp8_dtype_on_fnuz_platform(self, mock_platform):
"""Test that fnuz dtype on fnuz platform returns 224.0."""
mock_platform.is_fp8_fnuz.return_value = True
mock_platform.fp8_dtype.return_value = torch.float8_e4m3fnuz
# Re-import to use mocked platform
from importlib import reload
import vllm.model_executor.layers.quantization.utils.quant_utils as qu
reload(qu)
fp8_min, fp8_max = qu.get_fp8_min_max(torch.float8_e4m3fnuz)
# fnuz on ROCm MI300 should return 224.0, not 240.0
assert fp8_max == 224.0, (
f"Expected 224.0 for fnuz on fnuz platform, got {fp8_max}"
)
assert fp8_min == -224.0, (
f"Expected -224.0 for fnuz on fnuz platform, got {fp8_min}"
)
@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
def test_standard_dtype_on_fnuz_platform(self, mock_platform):
"""Test that standard dtype on fnuz platform uses finfo values."""
mock_platform.is_fp8_fnuz.return_value = True
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
# Standard e4m3fn dtype should use finfo even on fnuz platform
fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn)
finfo = torch.finfo(torch.float8_e4m3fn)
assert fp8_max == finfo.max, (
f"Standard dtype should use finfo.max={finfo.max}, got {fp8_max}"
)
@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
def test_fnuz_dtype_on_non_fnuz_platform(self, mock_platform):
"""Test that fnuz dtype on non-fnuz platform uses finfo values."""
mock_platform.is_fp8_fnuz.return_value = False
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
# fnuz dtype on non-fnuz platform should use finfo
fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz)
finfo = torch.finfo(torch.float8_e4m3fnuz)
# Should be 240.0, not 224.0 (non-fnuz platform)
assert fp8_max == finfo.max, (
f"Non-fnuz platform should use finfo.max={finfo.max}, got {fp8_max}"
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@ -7,15 +7,14 @@ import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
)
from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
_FP8_MIN, _FP8_MAX = get_fp8_min_max(_FP8_DTYPE)
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)

View File

@ -15,7 +15,10 @@ from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
)
@ -748,12 +751,7 @@ def per_token_group_quant_fp8(
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4mefnuz dtype.
finfo = torch.finfo(dtype)
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
fp8_min, fp8_max = get_fp8_min_max(dtype)
assert out_q is None or out_q.shape == x.shape
x_q = out_q

View File

@ -19,6 +19,29 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]:
"""
Get the min and max values for FP8 quantization.
On ROCm with torch.float8_e4m3fnuz (fnuz), the default PyTorch finfo.max
(240.0) causes accuracy issues with dynamic quantization models.
Use 224.0 instead for fnuz dtype.
Args:
dtype: FP8 dtype (defaults to platform's FP8 dtype if None)
Returns:
Tuple of (fp8_min, fp8_max) values
"""
if dtype is None:
dtype = FP8_DTYPE
finfo = torch.finfo(dtype)
# Only apply the 224.0 override for the actual fnuz dtype on fnuz platform
if current_platform.is_fp8_fnuz() and dtype == torch.float8_e4m3fnuz:
return -224.0, 224.0
return finfo.min, finfo.max
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int

View File

@ -16,6 +16,9 @@ import torch
import vllm.envs as envs
from vllm.logger import logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv
@ -355,7 +358,8 @@ def per_block_cast_to_fp8(
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
_, fp8_max = get_fp8_min_max(fp8_dtype)
sf = x_amax / fp8_max
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(