Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Merge 719ccfd773d15f1cd494c4b04fa0b727d82954b4 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
Commit 3a557ea3da

@@ -4,13 +4,13 @@
 import torch

-from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    get_fp8_min_max,
+    group_broadcast,
+)
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import round_up

-# Using the default value (240.0) from pytorch will cause accuracy
-# issue on dynamic quantization models. Here use 224.0 for rocm.
-ROCM_FP8FNUZ_MAX = 224.0
 FP8_DTYPE = current_platform.fp8_dtype()

@@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant(
     if scale_ub is not None:
         assert quant_dtype == FP8_DTYPE

-    qtype_traits = (
-        torch.iinfo(quant_dtype)
-        if quant_dtype == torch.int8
-        else torch.finfo(quant_dtype)
-    )
-    use_fp8fnuz = (
-        current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
-    )
-    qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
-    qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
+    if quant_dtype == torch.int8:
+        qtype_traits = torch.iinfo(quant_dtype)
+        qtype_traits_min = qtype_traits.min
+        qtype_traits_max = qtype_traits.max
+    else:
+        qtype_traits_min, qtype_traits_max = get_fp8_min_max(quant_dtype)
     qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)

@@ -72,17 +68,7 @@ def ref_dynamic_per_tensor_fp8_quant(
 def ref_dynamic_per_tensor_fp8_quant(
     x: torch.Tensor,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    fp8_traits = torch.finfo(FP8_DTYPE)
-    fp8_traits_max = (
-        ROCM_FP8FNUZ_MAX
-        if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
-        else fp8_traits.max
-    )
-    fp8_traits_min = (
-        -ROCM_FP8FNUZ_MAX
-        if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
-        else fp8_traits.min
-    )
+    fp8_traits_min, fp8_traits_max = get_fp8_min_max(FP8_DTYPE)
     fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)

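For context on what these two reference helpers compute: a dynamic scale is derived from the tensor's absolute maximum, and the scaled values are clamped to the FP8 representable range. Below is a minimal pure-PyTorch sketch of the per-tensor flow, assuming the standard float8_e4m3fn bounds that get_fp8_min_max returns on non-fnuz platforms; the function name and the epsilon guard are illustrative, not the file's exact code.

    import torch

    def sketch_per_tensor_fp8_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Bounds for torch.float8_e4m3fn; get_fp8_min_max returns these
        # on non-fnuz platforms.
        fp8_min, fp8_max = -448.0, 448.0
        # Dynamic scale maps the tensor's absolute maximum onto fp8_max.
        scale = x.abs().max().to(torch.float32) / fp8_max
        scale = scale.clamp(min=1e-12)  # illustrative guard for all-zero input
        x_q = (x.to(torch.float32) / scale).clamp(fp8_min, fp8_max)
        return x_q.to(torch.float8_e4m3fn), scale

    x_q, scale = sketch_per_tensor_fp8_quant(torch.randn(16, 32))
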
tests/kernels/quantization/test_fp8_min_max_helper.py (new file, +93)
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for the get_fp8_min_max() helper function.
+
+These tests verify the FP8 min/max value logic for both standard
+and fnuz (ROCm MI300) dtype handling.
+"""
+
+from unittest.mock import patch
+
+import pytest
+import torch
+
+
+class TestGetFp8MinMax:
+    """Test cases for get_fp8_min_max() function."""
+
+    def test_standard_fp8_dtype(self):
+        """Test that standard FP8 dtype uses PyTorch's finfo values."""
+        from vllm.model_executor.layers.quantization.utils.quant_utils import (
+            get_fp8_min_max,
+        )
+
+        # For standard float8_e4m3fn, should return finfo values
+        fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn)
+        finfo = torch.finfo(torch.float8_e4m3fn)
+
+        # Standard FP8 max is 448.0 for e4m3fn
+        assert fp8_max == finfo.max, f"Expected finfo.max={finfo.max}, got {fp8_max}"
+        assert fp8_min == finfo.min, f"Expected finfo.min={finfo.min}, got {fp8_min}"
+
+    @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
+    def test_fnuz_fp8_dtype_on_fnuz_platform(self, mock_platform):
+        """Test that fnuz dtype on fnuz platform returns 224.0."""
+        mock_platform.is_fp8_fnuz.return_value = True
+        mock_platform.fp8_dtype.return_value = torch.float8_e4m3fnuz
+
+        # Re-import to use mocked platform
+        from importlib import reload
+
+        import vllm.model_executor.layers.quantization.utils.quant_utils as qu
+
+        reload(qu)
+
+        fp8_min, fp8_max = qu.get_fp8_min_max(torch.float8_e4m3fnuz)
+
+        # fnuz on ROCm MI300 should return 224.0, not 240.0
+        assert fp8_max == 224.0, (
+            f"Expected 224.0 for fnuz on fnuz platform, got {fp8_max}"
+        )
+        assert fp8_min == -224.0, (
+            f"Expected -224.0 for fnuz on fnuz platform, got {fp8_min}"
+        )
+
+    @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
+    def test_standard_dtype_on_fnuz_platform(self, mock_platform):
+        """Test that standard dtype on fnuz platform uses finfo values."""
+        mock_platform.is_fp8_fnuz.return_value = True
+
+        from vllm.model_executor.layers.quantization.utils.quant_utils import (
+            get_fp8_min_max,
+        )
+
+        # Standard e4m3fn dtype should use finfo even on fnuz platform
+        fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn)
+        finfo = torch.finfo(torch.float8_e4m3fn)
+
+        assert fp8_max == finfo.max, (
+            f"Standard dtype should use finfo.max={finfo.max}, got {fp8_max}"
+        )
+
+    @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
+    def test_fnuz_dtype_on_non_fnuz_platform(self, mock_platform):
+        """Test that fnuz dtype on non-fnuz platform uses finfo values."""
+        mock_platform.is_fp8_fnuz.return_value = False
+
+        from vllm.model_executor.layers.quantization.utils.quant_utils import (
+            get_fp8_min_max,
+        )
+
+        # fnuz dtype on non-fnuz platform should use finfo
+        fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz)
+        finfo = torch.finfo(torch.float8_e4m3fnuz)
+
+        # Should be 240.0, not 224.0 (non-fnuz platform)
+        assert fp8_max == finfo.max, (
+            f"Non-fnuz platform should use finfo.max={finfo.max}, got {fp8_max}"
+        )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

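A caveat worth flagging in the reload-based test above: reload(qu) re-executes the module top level, including its own `from vllm.platforms import current_platform`, which can rebind the real platform object over the @patch mock and leave stale module globals for later tests. A hedged alternative (not part of this commit) is to patch the module attribute directly and skip the reload, since get_fp8_min_max() looks up current_platform at call time:

    from unittest.mock import patch

    import torch
    import vllm.model_executor.layers.quantization.utils.quant_utils as qu

    def check_fnuz_override() -> None:
        # Patch the exact name get_fp8_min_max resolves at call time.
        with patch.object(qu, "current_platform") as mock_platform:
            mock_platform.is_fp8_fnuz.return_value = True
            assert qu.get_fp8_min_max(torch.float8_e4m3fnuz) == (-224.0, 224.0)
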
@@ -7,15 +7,14 @@ import torch.nn.functional as F
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    get_fp8_min_max,
+)
 from vllm.platforms import current_platform

-# Using the default value (240.0) from pytorch will cause accuracy
-# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
 _FP8_DTYPE = current_platform.fp8_dtype()
-_FP8_FINFO = torch.finfo(_FP8_DTYPE)
-_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
-_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
+_FP8_MIN, _FP8_MAX = get_fp8_min_max(_FP8_DTYPE)
 _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)

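The derived floor _FP8_MIN_SCALING_FACTOR shifts accordingly wherever the fnuz override applies. Worked values, assuming the standard e4m3fn bound (448.0) and the 224.0 fnuz override:

    # Scale floor 1.0 / (_FP8_MAX * 512.0) under the two possible bounds.
    print(1.0 / (448.0 * 512.0))  # ~4.36e-06 with torch.finfo(torch.float8_e4m3fn).max
    print(1.0 / (224.0 * 512.0))  # ~8.72e-06 with the fnuz override
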
@@ -15,7 +15,10 @@ from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    get_fp8_min_max,
+)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED,
 )

@@ -748,12 +751,7 @@ def per_token_group_quant_fp8(
     )
     assert x.stride(-1) == 1, "`x` groups must be contiguous"

-    # Using the default value (240.0) from pytorch will cause accuracy
-    # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
-    # platforms that use the torch.float8_e4m3fnuz dtype.
-    finfo = torch.finfo(dtype)
-    fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
-    fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
+    fp8_min, fp8_max = get_fp8_min_max(dtype)

     assert out_q is None or out_q.shape == x.shape
     x_q = out_q

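per_token_group_quant_fp8 itself dispatches to a fused kernel; as a rough mental model of what the clamped (fp8_min, fp8_max) range is used for, here is a hedged pure-PyTorch sketch of per-group quantization. The function name and epsilon are illustrative, and group_size is assumed to divide the last dimension.

    import torch

    def sketch_per_group_quant(
        x: torch.Tensor,
        group_size: int,
        fp8_min: float = -448.0,
        fp8_max: float = 448.0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        x_grouped = x.reshape(-1, group_size).to(torch.float32)
        # One scale per group, derived from the group's absolute maximum.
        scales = x_grouped.abs().amax(dim=-1, keepdim=True) / fp8_max
        scales = scales.clamp(min=1e-10)  # illustrative guard for all-zero groups
        x_q = (x_grouped / scales).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn)
        return x_q.reshape(x.shape), scales

    x_q, scales = sketch_per_group_quant(torch.randn(4, 128), group_size=64)
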
@@ -19,6 +19,29 @@ FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8


+def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]:
+    """
+    Get the min and max values for FP8 quantization.
+
+    On ROCm with torch.float8_e4m3fnuz (fnuz), the default PyTorch finfo.max
+    (240.0) causes accuracy issues with dynamic quantization models.
+    Use 224.0 instead for fnuz dtype.
+
+    Args:
+        dtype: FP8 dtype (defaults to platform's FP8 dtype if None)
+
+    Returns:
+        Tuple of (fp8_min, fp8_max) values
+    """
+    if dtype is None:
+        dtype = FP8_DTYPE
+    finfo = torch.finfo(dtype)
+    # Only apply the 224.0 override for the actual fnuz dtype on fnuz platform
+    if current_platform.is_fp8_fnuz() and dtype == torch.float8_e4m3fnuz:
+        return -224.0, 224.0
+    return finfo.min, finfo.max
+
+
 # Use proxy as NamedTuple direct subclasses cannot have static members
 class _GroupShape(NamedTuple):
     row: int

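On a non-fnuz platform the helper simply mirrors torch.finfo for both dtypes; the 224.0 override fires only when the platform reports fnuz and the dtype is exactly torch.float8_e4m3fnuz. For example:

    import torch
    from vllm.model_executor.layers.quantization.utils.quant_utils import get_fp8_min_max

    # On a non-fnuz platform both calls fall through to torch.finfo:
    print(get_fp8_min_max(torch.float8_e4m3fn))    # (-448.0, 448.0)
    print(get_fp8_min_max(torch.float8_e4m3fnuz))  # (-240.0, 240.0)
    # On a fnuz platform (e.g. ROCm MI300), the second call returns (-224.0, 224.0).
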
@@ -16,6 +16,9 @@ import torch

 import vllm.envs as envs
 from vllm.logger import logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    get_fp8_min_max,
+)
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import has_deep_gemm
 from vllm.utils.math_utils import cdiv

@@ -355,7 +358,8 @@ def per_block_cast_to_fp8(
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
+    _, fp8_max = get_fp8_min_max(fp8_dtype)
+    sf = x_amax / fp8_max
     sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
     x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
     return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(

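The replaced branch hard-coded 448.0 / 224.0; with the helper, the per-block scale is simply amax / fp8_max for whatever dtype is passed. A standalone sketch of the block-scale computation, assuming 128x128 blocks and an input whose dimensions are already multiples of the block size:

    import torch

    block_m, block_n = 128, 128
    x = torch.randn(256, 256)
    x_view = x.view(-1, block_m, x.size(1) // block_n, block_n)
    # One absolute-maximum (and hence one scale) per 128x128 block.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 on non-fnuz platforms
    sf = x_amax / fp8_max
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)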