diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index c4229f93464ac..eedb9bdcd5299 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -15,9 +15,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, maybe_create_device_identity)
+    Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
 from vllm.platforms import current_platform
 
+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -26,9 +27,9 @@ FP8_DTYPE = current_platform.fp8_dtype()
 class TestModel(torch.nn.Module):
 
     def __init__(self, hidden_size: int, eps: float, static: bool,
-                 force_fp8_e4m3fnuz: bool, *args, **kwargs):
+                 cuda_force_torch: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
+        self.cuda_force_torch = cuda_force_torch
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
@@ -42,11 +43,12 @@ class TestModel(torch.nn.Module):
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]
-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=static,
-            act_quant_group_shape=group_shape,
-        )
+
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=static,
+                act_quant_group_shape=group_shape,
+            )
 
     def forward(self, x):
         resid = torch.sqrt(x)
@@ -81,11 +83,14 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
+# cuda_force_torch is used to test the torch code path on platforms where
+# cutlass_fp8_supported() == True.
+@pytest.mark.parametrize("cuda_force_torch",
+                         [True, False] if cutlass_fp8_supported() else [True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              force_fp8_e4m3fnuz):
+                              cuda_force_torch):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
@@ -102,7 +107,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
         fusion_pass = FusionPass.instance(vllm_config)
 
         backend = TestBackend(noop_pass, fusion_pass)
-        model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
+        model = TestModel(hidden_size, eps, static, cuda_force_torch)
 
         # First dimension dynamic
         x = torch.rand(num_tokens, hidden_size)
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index fcc2589e42116..e16d1725e6add 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -17,9 +17,10 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp)
+    Fp8LinearOp, cutlass_fp8_supported)
 from vllm.platforms import current_platform
 
+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -32,7 +33,7 @@ def is_nvfp4_supported():
 
 class TestSiluMulFp8QuantModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
+    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
         super().__init__()
         self.silu_and_mul = SiluAndMul()
         self.wscale = torch.rand(1, dtype=torch.float32)
@@ -40,11 +41,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
 
         self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
 
-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=True,
-            act_quant_group_shape=GroupShape.PER_TENSOR,
-        )
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
 
     def forward(self, x):
         y = self.silu_and_mul(x)
@@ -96,12 +97,15 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
 @pytest.mark.parametrize(
     "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
     if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
-@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
+# cuda_force_torch is used to test the torch code path on platforms where
+# cutlass_fp8_supported() == True.
+@pytest.mark.parametrize("cuda_force_torch",
+                         [True, False] if cutlass_fp8_supported() else [True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
-                                   force_fp8_e4m3fnuz):
+                                   cuda_force_torch):
-    if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
+    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
         pytest.skip("Duplicate tests for NVFP4")
 
     torch.set_default_device("cuda")
@@ -114,8 +118,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
     fusion_pass = ActivationQuantFusionPass(config)
 
     backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
-    model = model_class(hidden_size=hidden_size,
-                        force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)
+    model = model_class(hidden_size, cuda_force_torch)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size * 2)
diff --git a/tests/utils.py b/tests/utils.py
index 9d2073f3c1036..e47235002657d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -17,6 +17,7 @@ from contextlib import contextmanager, suppress
 from multiprocessing import Process
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from unittest.mock import patch
 
 import cloudpickle
 import httpx
@@ -1077,3 +1078,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         return attn_backend_list
     else:
         raise ValueError("Unsupported platform")
+
+
+@contextmanager
+def override_cutlass_fp8_supported(value: bool):
+    with patch(
+            "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
+            return_value=value):
+        yield
diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py
index 466fd5fba7685..45ea8e3520f1d 100644
--- a/vllm/model_executor/layers/quantization/ptpc_fp8.py
+++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py
@@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
     """
 
     def __init__(self, quant_config: PTPCFp8Config):
+        assert current_platform.is_rocm(), \
+            "PTPCFp8LinearMethod is only supported on ROCm."
         super().__init__(quant_config=quant_config)
         # Force weight quantization
         self.quant_config.is_checkpoint_fp8_serialized = False
         self.fp8_linear = Fp8LinearOp(
-            act_quant_static=False,
-            act_quant_group_shape=GroupShape.PER_TOKEN,
-            force_fp8_e4m3fnuz=True)
+            act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index ecdcc573935c0..8f6b7f83d47f8 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -355,12 +355,10 @@ class Fp8LinearOp:
     def __init__(self,
                  act_quant_static: bool,
                  act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
-                 pad_output: Optional[bool] = None,
-                 force_fp8_e4m3fnuz: bool = False):
+                 pad_output: Optional[bool] = None):
         if current_platform.is_rocm():
             self.preferred_backend = "rocm"
-        elif current_platform.is_cuda(
-        ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
+        elif current_platform.is_cuda() and cutlass_fp8_supported():
             if has_flashinfer() and current_platform.has_device_capability(
                     100):
                 self.preferred_backend = "flashinfer"
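Reviewer note: below is a minimal usage sketch, not part of the diff, showing how the new override_cutlass_fp8_supported helper drives the cuda_force_torch parameter. It assumes the vLLM repo root is on sys.path (so tests.utils is importable) and a CUDA or ROCm build on which Fp8LinearOp can be constructed; all other names come from the changes above.

# Usage sketch (assumption: run from the vLLM repo root so "tests" is importable).
from tests.utils import override_cutlass_fp8_supported
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

cuda_force_torch = True  # the value pytest parametrizes in the tests above

# Fp8LinearOp picks its preferred backend in __init__, so the tests only need
# the patch to be active while the op is constructed. With the patched
# cutlass_fp8_supported() returning False, the CUDA branch skips cutlass and
# the op falls back to the torch path, which is what cuda_force_torch=True
# is meant to exercise.
with override_cutlass_fp8_supported(not cuda_force_torch):
    fp8_linear = Fp8LinearOp(
        act_quant_static=True,
        act_quant_group_shape=GroupShape.PER_TENSOR,
    )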