[Rocm][torch.compile] Adding layernorm + fp8 block quant and silu + fp8 block quant for Aiter (#25693)

Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
Co-authored-by: wuhuikx <hattie.wu@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
This commit is contained in:
Charlie Fu 2025-12-09 16:39:26 -06:00 committed by GitHub
parent fccd532587
commit 3c680f4a17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 610 additions and 60 deletions

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
@ -152,13 +155,79 @@ GROUP_SHAPES = [
]
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, **kwargs):
super().__init__()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = [
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
for _ in range(3)
]
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
self.eps = eps
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
# make sure resid is used for replacement to work
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
x3, resid, self.norm_weight[2], self.eps
)
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
x4, resid, self.norm_weight[3], self.eps
)
return y4
def ops_in_model_before(self):
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
@ -173,10 +242,14 @@ def test_fusion_rmsnorm_quant(
num_tokens,
eps,
group_shape,
model_class,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
@ -209,12 +282,24 @@ def test_fusion_rmsnorm_quant(
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, group_shape, cuda_force_torch)
model = model_class(
hidden_size=hidden_size,
eps=eps,
group_shape=group_shape,
cuda_force_torch=cuda_force_torch,
)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
@ -243,7 +328,10 @@ def test_fusion_rmsnorm_quant(
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
if (
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7

View File

@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._aiter_ops import IS_AITER_FOUND
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import (
FUSED_OPS,
@ -24,6 +25,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return [FUSED_OPS[kNvfp4Quant]]
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = torch.rand(
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
]
def ops_in_model_after(self):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)],
+ [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
model_class: type[
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
)
with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config)
fusion_passes = [ActivationQuantFusionPass(config)]
if IS_AITER_FOUND:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)
assert fusion_pass.matched_count == 1
assert sum([p.matched_count for p in fusion_passes]) == 1
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

View File

@ -24,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
@ -45,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import dtypes
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@ -522,6 +501,142 @@ def _rocm_aiter_per_token_quant_fake(
)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=residual,
)
return (x_quant, x_quant_scales, res)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
torch.empty_like(residual, device=residual.device),
)
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=None,
)
return (x_quant, x_quant_scales)
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
)
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N_half + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@ -557,7 +672,7 @@ class rocm_aiter_ops:
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
return cls.is_linear_enabled()
@classmethod
@if_aiter_supported
@ -632,14 +747,6 @@ class rocm_aiter_ops:
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_group_fp8_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl,
@ -699,27 +806,46 @@ class rocm_aiter_ops:
direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8_blockscale",
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
fake_impl=_rocm_aiter_group_fp8_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_per_tensor_quant",
op_func=_rocm_aiter_per_tensor_quant_impl,

View File

@ -5,6 +5,7 @@ import functools
from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]

View File

@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fusion import empty_bf16
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
class AiterRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_GROUP_QUANT_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
return at[0], at[1]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class AiterFusedAddRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into a aiter rms_norm_with_add_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_ADD_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)
at2 = self.quant_op(at1[0], 128)
# result, scale, residual
return at2[0], at2[1], at1[1]
def replacement(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_ADD_GROUP_QUANT_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
# result, scale, residual
return at[0], at[1], at[2]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + dynamic group fp8 quant
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
self.patterns
)
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, quant_op
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
fusion_patterns = [
AiterRMSFp8GroupQuantPattern,
AiterFusedAddRMSFp8GroupQuantPattern,
]
return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
"""
def __init__(self, quant_op: OpOverload):
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
):
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
):
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
inputs = [
self.silu_and_mul_matcher.inputs()[0],
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)

View File

@ -196,6 +196,39 @@ direct_register_custom_op(
)
def _triton_per_token_group_quant_fp8_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
return per_token_group_quant_fp8(
x, group_size, column_major_scales=False, use_ue8m0=False
)
def _triton_per_token_group_quant_fp8_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=current_platform.fp8_dtype(), device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
direct_register_custom_op(
"triton_per_token_group_quant_fp8",
_triton_per_token_group_quant_fp8_impl,
fake_impl=_triton_per_token_group_quant_fp8_fake,
)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp:
@ -341,17 +374,15 @@ class W8A8BlockFp8LinearOp:
if input_scale is not None:
q_input = input_2d
# MI350 case uses triton kernel
elif use_triton:
q_input, input_scale = per_token_group_quant_fp8(
q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
)
# MI300 uses tuned AITER ASM/C++ kernel
else:
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
input_2d, self.act_quant_group_shape.col
)
return gemm_a8w8_blockscale_op(
q_input,