mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 17:55:01 +08:00
[Misc] Removed force_fp8_e4m3fnuz from FP8LinearOp (#23725)
Signed-off-by: Julien Lin <jullin@nvidia.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
c9f7081f9c
commit
37241077d5
@ -15,9 +15,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, QuantKey, ScaleDesc)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, maybe_create_device_identity)
|
||||
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -26,9 +27,9 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float, static: bool,
|
||||
force_fp8_e4m3fnuz: bool, *args, **kwargs):
|
||||
cuda_force_torch: bool, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
|
||||
self.cuda_force_torch = cuda_force_torch
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
@ -42,11 +43,12 @@ class TestModel(torch.nn.Module):
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
for _ in range(2)
|
||||
]
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
|
||||
act_quant_static=static,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=static,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
resid = torch.sqrt(x)
|
||||
@ -81,11 +83,14 @@ class TestModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("static", [True, False])
|
||||
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize("cuda_force_torch",
|
||||
[True, False] if cutlass_fp8_supported() else [True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
force_fp8_e4m3fnuz):
|
||||
cuda_force_torch):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
@ -102,7 +107,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass)
|
||||
model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
|
||||
model = TestModel(hidden_size, eps, static, cuda_force_torch)
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
|
||||
@ -17,9 +17,10 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
Fp8LinearOp, cutlass_fp8_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -32,7 +33,7 @@ def is_nvfp4_supported():
|
||||
|
||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
|
||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
@ -40,11 +41,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
@ -96,12 +97,15 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize(
|
||||
"model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
|
||||
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize("cuda_force_torch",
|
||||
[True, False] if cutlass_fp8_supported() else [True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
|
||||
force_fp8_e4m3fnuz):
|
||||
if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
|
||||
cuda_force_torch):
|
||||
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
||||
pytest.skip("Duplicate tests for NVFP4")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
@ -114,8 +118,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
|
||||
model = model_class(hidden_size=hidden_size,
|
||||
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)
|
||||
model = model_class(hidden_size, cuda_force_torch)
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
|
||||
@ -17,6 +17,7 @@ from contextlib import contextmanager, suppress
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import cloudpickle
|
||||
import httpx
|
||||
@ -1077,3 +1078,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
|
||||
return attn_backend_list
|
||||
else:
|
||||
raise ValueError("Unsupported platform")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_cutlass_fp8_supported(value: bool):
|
||||
with patch(
|
||||
"vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
|
||||
return_value=value):
|
||||
yield
|
||||
|
||||
@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PTPCFp8Config):
|
||||
assert current_platform.is_rocm(), \
|
||||
"PTPCFp8LinearMethod is only supported on ROCm."
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.quant_config.is_checkpoint_fp8_serialized = False
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False,
|
||||
act_quant_group_shape=GroupShape.PER_TOKEN,
|
||||
force_fp8_e4m3fnuz=True)
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
|
||||
@ -355,12 +355,10 @@ class Fp8LinearOp:
|
||||
def __init__(self,
|
||||
act_quant_static: bool,
|
||||
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
|
||||
pad_output: Optional[bool] = None,
|
||||
force_fp8_e4m3fnuz: bool = False):
|
||||
pad_output: Optional[bool] = None):
|
||||
if current_platform.is_rocm():
|
||||
self.preferred_backend = "rocm"
|
||||
elif current_platform.is_cuda(
|
||||
) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
|
||||
elif current_platform.is_cuda() and cutlass_fp8_supported():
|
||||
if has_flashinfer() and current_platform.has_device_capability(
|
||||
100):
|
||||
self.preferred_backend = "flashinfer"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user