[Misc] Removed force_fp8_e4m3fnuz from FP8LinearOp (#23725)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
nvjullin 2025-09-04 21:25:40 +08:00 committed by GitHub
parent c9f7081f9c
commit 37241077d5
5 changed files with 45 additions and 30 deletions
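
In short: Fp8LinearOp no longer accepts a force_fp8_e4m3fnuz argument. Backend selection now depends only on the platform and on cutlass_fp8_supported(), and tests force the torch fallback by patching that helper instead of passing the flag. A minimal before/after sketch of the constructor call (argument values are illustrative, not taken from any specific caller):

    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

    # Before: force_fp8_e4m3fnuz=True steered CUDA away from the CUTLASS path.
    # fp8_linear = Fp8LinearOp(act_quant_static=True,
    #                          act_quant_group_shape=GroupShape.PER_TENSOR,
    #                          force_fp8_e4m3fnuz=True)

    # After: no flag; the backend is picked from the platform at construction.
    fp8_linear = Fp8LinearOp(act_quant_static=True,
                             act_quant_group_shape=GroupShape.PER_TENSOR)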

View File

@@ -15,9 +15,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, maybe_create_device_identity)
+    Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
 from vllm.platforms import current_platform

+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend

 FP8_DTYPE = current_platform.fp8_dtype()
@@ -26,9 +27,9 @@ FP8_DTYPE = current_platform.fp8_dtype()
 class TestModel(torch.nn.Module):

     def __init__(self, hidden_size: int, eps: float, static: bool,
-                 force_fp8_e4m3fnuz: bool, *args, **kwargs):
+                 cuda_force_torch: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
+        self.cuda_force_torch = cuda_force_torch
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
@@ -42,11 +43,12 @@ class TestModel(torch.nn.Module):
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]
-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=static,
-            act_quant_group_shape=group_shape,
-        )
+
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=static,
+                act_quant_group_shape=group_shape,
+            )

     def forward(self, x):
         resid = torch.sqrt(x)
@@ -81,11 +83,14 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
+# cuda_force_torch used to test torch code path on platforms that
+# cutlass_fp8_supported() == True.
+@pytest.mark.parametrize("cuda_force_torch",
+                         [True, False] if cutlass_fp8_supported() else [True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              force_fp8_e4m3fnuz):
+                              cuda_force_torch):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
@@ -102,7 +107,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     fusion_pass = FusionPass.instance(vllm_config)

     backend = TestBackend(noop_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
+    model = TestModel(hidden_size, eps, static, cuda_force_torch)

     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size)

View File

@@ -17,9 +17,10 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp)
+    Fp8LinearOp, cutlass_fp8_supported)
 from vllm.platforms import current_platform

+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend

 FP8_DTYPE = current_platform.fp8_dtype()
@@ -32,7 +33,7 @@ def is_nvfp4_supported():
 class TestSiluMulFp8QuantModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
+    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
         super().__init__()
         self.silu_and_mul = SiluAndMul()
         self.wscale = torch.rand(1, dtype=torch.float32)
@@ -40,11 +41,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
         self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=True,
-            act_quant_group_shape=GroupShape.PER_TENSOR,
-        )
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )

     def forward(self, x):
         y = self.silu_and_mul(x)
@@ -96,12 +97,15 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
 @pytest.mark.parametrize(
     "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
     if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
-@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
+# cuda_force_torch used to test torch code path on platforms that
+# cutlass_fp8_supported() == True.
+@pytest.mark.parametrize("cuda_force_torch",
+                         [True, False] if cutlass_fp8_supported() else [True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
-                                   force_fp8_e4m3fnuz):
-    if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
+                                   cuda_force_torch):
+    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
         pytest.skip("Duplicate tests for NVFP4")

     torch.set_default_device("cuda")
@@ -114,8 +118,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
     fusion_pass = ActivationQuantFusionPass(config)

     backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
-    model = model_class(hidden_size=hidden_size,
-                        force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)
+    model = model_class(hidden_size, cuda_force_torch)

     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size * 2)

View File

@@ -17,6 +17,7 @@ from contextlib import contextmanager, suppress
 from multiprocessing import Process
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from unittest.mock import patch

 import cloudpickle
 import httpx
@@ -1077,3 +1078,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         return attn_backend_list
     else:
         raise ValueError("Unsupported platform")
+
+
+@contextmanager
+def override_cutlass_fp8_supported(value: bool):
+    with patch(
+            "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
+            return_value=value):
+        yield
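
The helper relies on unittest.mock.patch: since no replacement object is supplied, patch swaps the module-level cutlass_fp8_supported for a MagicMock whose return_value is the given flag, so Fp8LinearOp sees the overridden answer while choosing a backend. A hedged usage sketch mirroring the test changes above (the relative import only resolves inside the tests package):

    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

    from ..utils import override_cutlass_fp8_supported

    # cuda_force_torch=True exercises the torch code path on CUDA by making
    # Fp8LinearOp believe CUTLASS FP8 is unavailable during construction.
    cuda_force_torch = True
    with override_cutlass_fp8_supported(not cuda_force_torch):
        fp8_linear = Fp8LinearOp(act_quant_static=True,
                                 act_quant_group_shape=GroupShape.PER_TENSOR)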

View File

@@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
     """

     def __init__(self, quant_config: PTPCFp8Config):
+        assert current_platform.is_rocm(), \
+            "PTPCFp8LinearMethod is only supported on ROCm."
         super().__init__(quant_config=quant_config)
         # Force weight quantization
         self.quant_config.is_checkpoint_fp8_serialized = False
-        self.fp8_linear = Fp8LinearOp(
-            act_quant_static=False,
-            act_quant_group_shape=GroupShape.PER_TOKEN,
-            force_fp8_e4m3fnuz=True)
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,

View File

@@ -355,12 +355,10 @@ class Fp8LinearOp:
     def __init__(self,
                  act_quant_static: bool,
                  act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
-                 pad_output: Optional[bool] = None,
-                 force_fp8_e4m3fnuz: bool = False):
+                 pad_output: Optional[bool] = None):
         if current_platform.is_rocm():
             self.preferred_backend = "rocm"
-        elif current_platform.is_cuda(
-        ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
+        elif current_platform.is_cuda() and cutlass_fp8_supported():
             if has_flashinfer() and current_platform.has_device_capability(
                     100):
                 self.preferred_backend = "flashinfer"