[Misc] Removed force_fp8_e4m3fnuz from FP8LinearOp (#23725)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
nvjullin 2025-09-04 21:25:40 +08:00 committed by GitHub
parent c9f7081f9c
commit 37241077d5
5 changed files with 45 additions and 30 deletions
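In practice the change looks like this for callers: a minimal before/after sketch (not part of the diff itself), using the GroupShape and Fp8LinearOp imports that appear in the hunks below.

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

# Before (removed by this commit): passing force_fp8_e4m3fnuz=True made
# Fp8LinearOp skip the CUTLASS branch on CUDA (see the w8a8_utils.py hunk
# at the end of this diff).
#   fp8_linear = Fp8LinearOp(act_quant_static=True,
#                            act_quant_group_shape=GroupShape.PER_TENSOR,
#                            force_fp8_e4m3fnuz=True)

# After: the constructor only takes quantization arguments; the backend is
# chosen from the platform checks in w8a8_utils, and tests that want the
# torch path patch cutlass_fp8_supported() instead (new helper below).
fp8_linear = Fp8LinearOp(act_quant_static=True,
                         act_quant_group_shape=GroupShape.PER_TENSOR)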

View File

@@ -15,9 +15,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, maybe_create_device_identity)
+    Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
 from vllm.platforms import current_platform
 
+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -26,9 +27,9 @@ FP8_DTYPE = current_platform.fp8_dtype()
 class TestModel(torch.nn.Module):
 
     def __init__(self, hidden_size: int, eps: float, static: bool,
-                 force_fp8_e4m3fnuz: bool, *args, **kwargs):
+                 cuda_force_torch: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
+        self.cuda_force_torch = cuda_force_torch
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
@@ -42,11 +43,12 @@ class TestModel(torch.nn.Module):
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]
-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=static,
-            act_quant_group_shape=group_shape,
-        )
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=static,
+                act_quant_group_shape=group_shape,
+            )
 
     def forward(self, x):
         resid = torch.sqrt(x)
@ -81,11 +83,14 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
force_fp8_e4m3fnuz):
cuda_force_torch):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
@@ -102,7 +107,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     fusion_pass = FusionPass.instance(vllm_config)
 
     backend = TestBackend(noop_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
+    model = TestModel(hidden_size, eps, static, cuda_force_torch)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size)

View File

@@ -17,9 +17,10 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp)
+    Fp8LinearOp, cutlass_fp8_supported)
 from vllm.platforms import current_platform
 
+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -32,7 +33,7 @@ def is_nvfp4_supported():
 class TestSiluMulFp8QuantModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
+    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
         super().__init__()
         self.silu_and_mul = SiluAndMul()
         self.wscale = torch.rand(1, dtype=torch.float32)
@@ -40,11 +41,11 @@
         self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
 
-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=True,
-            act_quant_group_shape=GroupShape.PER_TENSOR,
-        )
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
 
     def forward(self, x):
         y = self.silu_and_mul(x)
@ -96,12 +97,15 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize(
"model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
force_fp8_e4m3fnuz):
if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
cuda_force_torch):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4")
torch.set_default_device("cuda")
@@ -114,8 +118,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
     fusion_pass = ActivationQuantFusionPass(config)
 
     backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
-    model = model_class(hidden_size=hidden_size,
-                        force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)
+    model = model_class(hidden_size, cuda_force_torch)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size * 2)

View File

@@ -17,6 +17,7 @@ from contextlib import contextmanager, suppress
 from multiprocessing import Process
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from unittest.mock import patch
 
 import cloudpickle
 import httpx
@@ -1077,3 +1078,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         return attn_backend_list
     else:
         raise ValueError("Unsupported platform")
+
+
+@contextmanager
+def override_cutlass_fp8_supported(value: bool):
+    with patch(
+            "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
+            return_value=value):
+        yield
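The helper patches the cutlass_fp8_supported name inside w8a8_utils, which is the module where Fp8LinearOp.__init__ looks it up, so the override only needs to be in effect while the op is constructed. A standalone sketch of the same mechanism (the test files above use the helper via `from ..utils import override_cutlass_fp8_supported` instead; the construction arguments mirror the test hunks):

from unittest.mock import patch

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

# Equivalent to override_cutlass_fp8_supported(False): replace the name in
# the module where it is looked up, for the duration of construction only.
with patch(
        "vllm.model_executor.layers.quantization.utils.w8a8_utils."
        "cutlass_fp8_supported",
        return_value=False):
    fp8_linear = Fp8LinearOp(act_quant_static=True,
                             act_quant_group_shape=GroupShape.PER_TENSOR)
# The patch is undone on exit; fp8_linear keeps whatever backend it selected
# during __init__, so it can be used normally afterwards.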

View File

@@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
     """
 
     def __init__(self, quant_config: PTPCFp8Config):
         assert current_platform.is_rocm(), \
             "PTPCFp8LinearMethod is only supported on ROCm."
         super().__init__(quant_config=quant_config)
         # Force weight quantization
         self.quant_config.is_checkpoint_fp8_serialized = False
         self.fp8_linear = Fp8LinearOp(
-            act_quant_static=False,
-            act_quant_group_shape=GroupShape.PER_TOKEN,
-            force_fp8_e4m3fnuz=True)
+            act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
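(Since PTPCFp8LinearMethod asserts current_platform.is_rocm() just above, and the only use of the removed flag inside Fp8LinearOp shown in this diff is the CUDA-only backend check, where ROCm always selects the "rocm" backend, dropping force_fp8_e4m3fnuz=True here should not change behavior on ROCm; see the w8a8_utils.py hunk below.)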

View File

@@ -355,12 +355,10 @@ class Fp8LinearOp:
     def __init__(self,
                  act_quant_static: bool,
                  act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
-                 pad_output: Optional[bool] = None,
-                 force_fp8_e4m3fnuz: bool = False):
+                 pad_output: Optional[bool] = None):
         if current_platform.is_rocm():
             self.preferred_backend = "rocm"
-        elif current_platform.is_cuda(
-        ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
+        elif current_platform.is_cuda() and cutlass_fp8_supported():
             if has_flashinfer() and current_platform.has_device_capability(
                     100):
                 self.preferred_backend = "flashinfer"
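For orientation, a condensed sketch of the backend selection as it reads after this change. Only the "rocm" and "flashinfer" branches appear in the hunk above; the "cutlass" and final "torch" fallbacks are assumptions from context (the tests reach the latter via cuda_force_torch=True), and the predicates are passed in as plain booleans to keep the sketch self-contained:

def preferred_fp8_backend(is_rocm: bool, is_cuda: bool, cutlass_ok: bool,
                          flashinfer_ok_on_sm100: bool) -> str:
    """Condensed view of Fp8LinearOp.__init__ backend selection (sketch)."""
    if is_rocm:
        return "rocm"
    if is_cuda and cutlass_ok:
        if flashinfer_ok_on_sm100:
            return "flashinfer"
        return "cutlass"  # assumed, not shown in the hunk
    return "torch"  # assumed fallback; what patching cutlass_fp8_supported()
                    # to False exercises on CUDA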