mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 12:09:11 +08:00
[ROCm][FEAT] Support AITER RMSNorm quantization fusion pass (#26575)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
6b16fff01b
commit
f32cfd7d97
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import itertools
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -53,37 +52,61 @@ class TestModel(torch.nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
eps: float,
|
eps: float,
|
||||||
group_shape: GroupShape,
|
group_shape: GroupShape,
|
||||||
cuda_force_torch: bool,
|
use_aiter: bool = False,
|
||||||
|
cuda_force_torch: bool = False,
|
||||||
|
use_aiter_quant_op: bool = True,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.use_aiter = use_aiter
|
||||||
|
self.use_aiter_quant_op = use_aiter_quant_op
|
||||||
self.cuda_force_torch = cuda_force_torch
|
self.cuda_force_torch = cuda_force_torch
|
||||||
|
self.group_shape = group_shape
|
||||||
|
self.enable_quant_fp8_custom_op = None # Will be set later if applicable
|
||||||
|
|
||||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||||
if group_shape.is_per_group():
|
|
||||||
self.wscale = [
|
# Setup quantization scale descriptor
|
||||||
torch.rand(
|
static = group_shape == GroupShape.PER_TENSOR and not use_aiter
|
||||||
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
for _ in range(3)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
|
||||||
static = group_shape == GroupShape.PER_TENSOR
|
|
||||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||||
|
|
||||||
|
# Setup scales
|
||||||
if static:
|
if static:
|
||||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||||
else:
|
else:
|
||||||
self.scale = [None for _ in range(3)]
|
self.scale = [None for _ in range(3)]
|
||||||
|
|
||||||
|
# Setup weights
|
||||||
self.w = [
|
self.w = [
|
||||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
|
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
|
||||||
]
|
]
|
||||||
if not group_shape.is_per_group():
|
if not group_shape.is_per_group() or use_aiter:
|
||||||
self.w = [self.w[0].t() for _ in range(3)]
|
self.w = [self.w[0].t() for _ in range(3)]
|
||||||
|
|
||||||
|
# Setup weight scales
|
||||||
if group_shape.is_per_group():
|
if group_shape.is_per_group():
|
||||||
|
scale_size = (
|
||||||
|
(hidden_size + 128 - 1) // 128
|
||||||
|
if use_aiter
|
||||||
|
else hidden_size // group_shape[1]
|
||||||
|
)
|
||||||
|
wscale_shape: tuple[int, ...] = (scale_size, scale_size)
|
||||||
|
else:
|
||||||
|
wscale_shape = (1,)
|
||||||
|
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
|
||||||
|
|
||||||
|
# Setup FP8 linear operation
|
||||||
|
is_per_group = group_shape.is_per_group()
|
||||||
|
if is_per_group and use_aiter:
|
||||||
|
self.fp8_linear = W8A8BlockFp8LinearOp(
|
||||||
|
weight_group_shape=GroupShape(128, 128),
|
||||||
|
act_quant_group_shape=group_shape,
|
||||||
|
use_aiter_and_is_supported=use_aiter_quant_op,
|
||||||
|
)
|
||||||
|
# AITER blockwise doesn't use enable_quant_fp8_custom_op
|
||||||
|
elif is_per_group:
|
||||||
self.fp8_linear = W8A8BlockFp8LinearOp(
|
self.fp8_linear = W8A8BlockFp8LinearOp(
|
||||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||||
act_quant_group_shape=group_shape,
|
act_quant_group_shape=group_shape,
|
||||||
@ -91,6 +114,13 @@ class TestModel(torch.nn.Module):
|
|||||||
use_aiter_and_is_supported=False,
|
use_aiter_and_is_supported=False,
|
||||||
)
|
)
|
||||||
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
|
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
|
||||||
|
elif use_aiter:
|
||||||
|
self.fp8_linear = Fp8LinearOp(
|
||||||
|
act_quant_static=False,
|
||||||
|
act_quant_group_shape=group_shape,
|
||||||
|
)
|
||||||
|
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
|
||||||
|
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||||
else:
|
else:
|
||||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||||
self.fp8_linear = Fp8LinearOp(
|
self.fp8_linear = Fp8LinearOp(
|
||||||
@ -100,7 +130,6 @@ class TestModel(torch.nn.Module):
|
|||||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||||
|
|
||||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||||
self.group_shape = group_shape
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# avoid having graph input be an arg to a pattern directly
|
# avoid having graph input be an arg to a pattern directly
|
||||||
@ -126,19 +155,49 @@ class TestModel(torch.nn.Module):
|
|||||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||||
return y4
|
return y4
|
||||||
|
|
||||||
|
def ops_in_model_before(self):
|
||||||
|
if (
|
||||||
|
self.use_aiter
|
||||||
|
and self.group_shape.is_per_group()
|
||||||
|
and current_platform.is_fp8_fnuz()
|
||||||
|
):
|
||||||
|
return [rocm_aiter_ops.get_group_quant_op()]
|
||||||
|
if self.use_aiter and self.group_shape.is_per_group():
|
||||||
|
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
|
||||||
|
if self.use_aiter and self.use_aiter_quant_op:
|
||||||
|
return [rocm_aiter_ops.get_per_token_quant_op()]
|
||||||
|
if self.use_aiter:
|
||||||
|
return [QUANT_OPS[self.quant_key]]
|
||||||
|
if self.enable_quant_fp8_custom_op:
|
||||||
|
return [QUANT_OPS[self.quant_key]]
|
||||||
|
return [torch.ops.aten.reciprocal]
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
|
if self.use_aiter and self.group_shape.is_per_group():
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||||
|
AiterRMSFp8GroupQuantPattern,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
|
||||||
|
AiterRMSFp8GroupQuantPattern.FUSED_OP,
|
||||||
|
]
|
||||||
|
if self.use_aiter:
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||||
|
AiterRMSNormDynamicQuantPattern,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
|
||||||
|
AiterRMSNormDynamicQuantPattern.FUSED_OP,
|
||||||
|
]
|
||||||
return [
|
return [
|
||||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
|
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
|
||||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
|
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
|
||||||
]
|
]
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
|
||||||
return (
|
|
||||||
[QUANT_OPS[self.quant_key]]
|
|
||||||
if self.enable_quant_fp8_custom_op
|
|
||||||
else [torch.ops.aten.reciprocal]
|
|
||||||
)
|
|
||||||
|
|
||||||
def ops_in_model_before_partial(self):
|
def ops_in_model_before_partial(self):
|
||||||
return (
|
return (
|
||||||
[RMS_OP, RMS_ADD_OP]
|
[RMS_OP, RMS_ADD_OP]
|
||||||
@ -155,67 +214,45 @@ GROUP_SHAPES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
|
def _run_fusion_test(
|
||||||
def __init__(self, hidden_size: int, eps: float, **kwargs):
|
model,
|
||||||
super().__init__()
|
fusion_pass,
|
||||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
vllm_config,
|
||||||
weight_group_shape=GroupShape(128, 128),
|
dtype,
|
||||||
act_quant_group_shape=GroupShape(1, 128),
|
hidden_size,
|
||||||
cutlass_block_fp8_supported=False,
|
num_tokens,
|
||||||
use_aiter_and_is_supported=True,
|
):
|
||||||
)
|
"""Helper function for common fusion test logic.
|
||||||
self.w = [
|
|
||||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
for _ in range(3)
|
|
||||||
]
|
|
||||||
|
|
||||||
scale_hidden_size = (hidden_size + 128 - 1) // 128
|
Must be called within vllm_config context.
|
||||||
self.wscale = [
|
"""
|
||||||
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
for _ in range(3)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
]
|
|
||||||
|
|
||||||
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
|
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||||
self.eps = eps
|
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||||
|
|
||||||
def forward(self, x):
|
x = torch.rand(num_tokens, hidden_size)
|
||||||
# avoid having graph input be an arg to a pattern directly
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
x = resid = torch.relu(x)
|
|
||||||
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
|
|
||||||
|
|
||||||
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
|
model_fused = torch.compile(model, backend=backend)
|
||||||
# make sure resid is used for replacement to work
|
result_fused = model_fused(x)
|
||||||
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
|
||||||
x2, resid, self.norm_weight[1], self.eps
|
|
||||||
)
|
|
||||||
|
|
||||||
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
|
model_unfused = torch.compile(model, backend=backend2)
|
||||||
|
result_unfused = model_unfused(x)
|
||||||
|
|
||||||
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
if dtype == torch.float16:
|
||||||
x3, resid, self.norm_weight[2], self.eps
|
ATOL, RTOL = (2e-3, 2e-3)
|
||||||
)
|
else:
|
||||||
|
ATOL, RTOL = (1e-2, 1e-2)
|
||||||
|
|
||||||
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
|
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
|
||||||
|
|
||||||
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
assert fusion_pass.matched_count == 3
|
||||||
x4, resid, self.norm_weight[3], self.eps
|
backend.check_before_ops(model.ops_in_model_before())
|
||||||
)
|
backend.check_after_ops(model.ops_in_model_after())
|
||||||
return y4
|
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
return backend, backend2
|
||||||
return [
|
|
||||||
torch.ops.vllm.rocm_aiter_rms_norm,
|
|
||||||
torch.ops.vllm.rocm_aiter_group_fp8_quant,
|
|
||||||
]
|
|
||||||
|
|
||||||
def ops_in_model_before_partial(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
|
||||||
return [
|
|
||||||
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
|
|
||||||
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("num_tokens", [257])
|
@pytest.mark.parametrize("num_tokens", [257])
|
||||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||||
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
|
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
||||||
list(itertools.product([TestModel], [True, False], [True, False]))
|
|
||||||
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
|
|
||||||
)
|
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant(
|
|||||||
num_tokens,
|
num_tokens,
|
||||||
eps,
|
eps,
|
||||||
group_shape,
|
group_shape,
|
||||||
model_class,
|
|
||||||
enable_rms_norm_custom_op,
|
enable_rms_norm_custom_op,
|
||||||
enable_quant_fp8_custom_op,
|
enable_quant_fp8_custom_op,
|
||||||
cuda_force_torch,
|
cuda_force_torch,
|
||||||
):
|
):
|
||||||
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
|
|
||||||
pytest.skip("AITER is not supported on this GPU.")
|
|
||||||
|
|
||||||
torch.set_default_device("cuda")
|
|
||||||
torch.set_default_dtype(dtype)
|
|
||||||
torch.manual_seed(1)
|
|
||||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
|
||||||
|
|
||||||
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
|
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
|
||||||
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
|
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
|
||||||
|
|
||||||
# Skip test for 64-bit group shape when running with cutlass or deepgemm
|
|
||||||
if group_shape == GroupShape(1, 64) and (
|
if group_shape == GroupShape(1, 64) and (
|
||||||
cutlass_block_fp8_supported() or is_deep_gemm_supported()
|
cutlass_block_fp8_supported() or is_deep_gemm_supported()
|
||||||
):
|
):
|
||||||
@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant(
|
|||||||
custom_ops.append("+rms_norm")
|
custom_ops.append("+rms_norm")
|
||||||
if enable_quant_fp8_custom_op:
|
if enable_quant_fp8_custom_op:
|
||||||
custom_ops.append("+quant_fp8")
|
custom_ops.append("+quant_fp8")
|
||||||
|
|
||||||
vllm_config = VllmConfig(
|
vllm_config = VllmConfig(
|
||||||
model_config=ModelConfig(dtype=dtype),
|
model_config=ModelConfig(dtype=dtype),
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant(
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Setup device before model creation
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
torch.set_default_device("cuda")
|
||||||
if model_class is TestRmsnormGroupFp8QuantModel:
|
torch.set_default_dtype(dtype)
|
||||||
from vllm.compilation.rocm_aiter_fusion import (
|
torch.manual_seed(1)
|
||||||
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
maybe_create_device_identity()
|
||||||
)
|
|
||||||
|
|
||||||
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
|
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||||
else:
|
model = TestModel(
|
||||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
|
||||||
|
|
||||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
|
||||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
|
||||||
model = model_class(
|
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
group_shape=group_shape,
|
group_shape=group_shape,
|
||||||
|
use_aiter=False,
|
||||||
cuda_force_torch=cuda_force_torch,
|
cuda_force_torch=cuda_force_torch,
|
||||||
)
|
)
|
||||||
# First dimension dynamic
|
|
||||||
x = torch.rand(num_tokens, hidden_size)
|
|
||||||
torch._dynamo.mark_dynamic(x, 0)
|
|
||||||
|
|
||||||
model_fused = torch.compile(model, backend=backend)
|
backend, _ = _run_fusion_test(
|
||||||
result_fused = model_fused(x)
|
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
|
||||||
|
)
|
||||||
model_unfused = torch.compile(model, backend=backend2)
|
|
||||||
result_unfused = model_unfused(x)
|
|
||||||
|
|
||||||
if dtype == torch.float16:
|
|
||||||
ATOL, RTOL = (2e-3, 2e-3)
|
|
||||||
else:
|
|
||||||
ATOL, RTOL = (1e-2, 1e-2)
|
|
||||||
|
|
||||||
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
|
|
||||||
|
|
||||||
assert fusion_pass.matched_count == 3
|
|
||||||
backend.check_before_ops(model.ops_in_model_before())
|
|
||||||
backend.check_before_ops(
|
backend.check_before_ops(
|
||||||
model.ops_in_model_before_partial(), fully_replaced=False
|
model.ops_in_model_before_partial(), fully_replaced=False
|
||||||
)
|
)
|
||||||
backend.check_after_ops(model.ops_in_model_after())
|
|
||||||
|
|
||||||
# If RMSNorm custom op is disabled (native/torch impl used),
|
# If RMSNorm custom op is disabled (native/torch impl used),
|
||||||
# there's a risk that the fused add doesn't get included in the
|
# there's a risk that the fused add doesn't get included in the
|
||||||
# replacement and only the rms part gets fused with quant.
|
# replacement and only the rms part gets fused with quant.
|
||||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||||
if (
|
if not enable_rms_norm_custom_op:
|
||||||
not enable_rms_norm_custom_op
|
|
||||||
and model_class is not TestRmsnormGroupFp8QuantModel
|
|
||||||
):
|
|
||||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||||
|
|
||||||
|
|
||||||
|
GROUP_SHAPE_QUANT_OPS_MATCHS = [
|
||||||
|
(GroupShape.PER_TOKEN, True),
|
||||||
|
(GroupShape.PER_TOKEN, False),
|
||||||
|
(GroupShape(1, 128), True),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize("hidden_size", [256])
|
||||||
|
@pytest.mark.parametrize("num_tokens", [257])
|
||||||
|
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
|
||||||
|
)
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
(not current_platform.is_rocm() or not IS_AITER_FOUND),
|
||||||
|
reason="Only test on ROCm with aiter package installed",
|
||||||
|
)
|
||||||
|
def test_aiter_fusion_rmsnorm_quant(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
hidden_size: int,
|
||||||
|
num_tokens: int,
|
||||||
|
eps: float,
|
||||||
|
group_shape: GroupShape,
|
||||||
|
use_aiter_quant_op: bool,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
):
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
model_config=ModelConfig(dtype=dtype),
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
|
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||||
|
pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
|
||||||
|
|
||||||
|
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||||
|
rocm_aiter_ops.refresh_env_variables()
|
||||||
|
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
torch.manual_seed(1)
|
||||||
|
maybe_create_device_identity()
|
||||||
|
|
||||||
|
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
|
||||||
|
model = TestModel(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
eps=eps,
|
||||||
|
group_shape=group_shape,
|
||||||
|
use_aiter=True,
|
||||||
|
use_aiter_quant_op=use_aiter_quant_op,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run_fusion_test(
|
||||||
|
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
|
||||||
|
)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import functools
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -433,16 +434,16 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
|
|||||||
from aiter import rmsnorm2d_fwd_with_add
|
from aiter import rmsnorm2d_fwd_with_add
|
||||||
|
|
||||||
residual_out = torch.empty_like(residual)
|
residual_out = torch.empty_like(residual)
|
||||||
output = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
rmsnorm2d_fwd_with_add(
|
rmsnorm2d_fwd_with_add(
|
||||||
output, # output
|
out, # output
|
||||||
x, # input
|
x, # input
|
||||||
residual, # residual input
|
residual, # residual input
|
||||||
residual_out, # residual output
|
residual_out, # residual output
|
||||||
weight,
|
weight,
|
||||||
variance_epsilon,
|
variance_epsilon,
|
||||||
)
|
)
|
||||||
return output, residual_out
|
return out, residual_out
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||||
@ -451,7 +452,84 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
|||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
variance_epsilon: float,
|
variance_epsilon: float,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
return torch.empty_like(x), torch.empty_like(residual)
|
residual_out = torch.empty_like(residual)
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
return out, residual_out
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
import aiter as rocm_aiter
|
||||||
|
|
||||||
|
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||||
|
|
||||||
|
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||||
|
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||||
|
residual_out = torch.empty_like(x)
|
||||||
|
|
||||||
|
rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant(
|
||||||
|
out,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
residual_out,
|
||||||
|
y_scale,
|
||||||
|
weight,
|
||||||
|
epsilon,
|
||||||
|
use_model_sensitive_rmsnorm=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out, residual_out, y_scale
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||||
|
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||||
|
residual_out = torch.empty_like(x)
|
||||||
|
|
||||||
|
return out, residual_out, y_scale
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
import aiter as rocm_aiter
|
||||||
|
|
||||||
|
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||||
|
|
||||||
|
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||||
|
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||||
|
|
||||||
|
rocm_aiter.rmsnorm2d_fwd_with_dynamicquant(
|
||||||
|
out, x, y_scale, weight, epsilon, use_model_sensitive_rmsnorm=0
|
||||||
|
)
|
||||||
|
|
||||||
|
return out, y_scale
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||||
|
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||||
|
|
||||||
|
return out, y_scale
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_per_tensor_quant_impl(
|
def _rocm_aiter_per_tensor_quant_impl(
|
||||||
@ -527,7 +605,11 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
|
|||||||
dtype_quant=AITER_FP8_DTYPE,
|
dtype_quant=AITER_FP8_DTYPE,
|
||||||
res1=residual,
|
res1=residual,
|
||||||
)
|
)
|
||||||
return (x_quant, x_quant_scales, res)
|
return (
|
||||||
|
x_quant,
|
||||||
|
res,
|
||||||
|
x_quant_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
|
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
|
||||||
@ -541,8 +623,8 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
|
|||||||
scale_shape = (M, (N + group_size - 1) // group_size)
|
scale_shape = (M, (N + group_size - 1) // group_size)
|
||||||
return (
|
return (
|
||||||
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
||||||
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
|
||||||
torch.empty_like(residual, device=residual.device),
|
torch.empty_like(residual, device=residual.device),
|
||||||
|
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -901,6 +983,20 @@ class rocm_aiter_ops:
|
|||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rmsnorm_fused_dynamic_quant",
|
||||||
|
op_func=_rocm_aiter_rmsnorm_fused_dynamic_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_rmsnorm_fused_dynamic_quant_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant",
|
||||||
|
op_func=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
|
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
|
||||||
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
|
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
|
||||||
@ -936,13 +1032,54 @@ class rocm_aiter_ops:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_per_token_quant",
|
op_name="rocm_aiter_per_token_quant",
|
||||||
op_func=_rocm_aiter_per_token_quant_impl,
|
op_func=_rocm_aiter_per_token_quant_impl,
|
||||||
mutates_args=["scale"],
|
|
||||||
fake_impl=_rocm_aiter_per_token_quant_fake,
|
fake_impl=_rocm_aiter_per_token_quant_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
_OPS_REGISTERED = True
|
_OPS_REGISTERED = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_rmsnorm_fused_add_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_rmsnorm_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rms_norm.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_rmsnorm_fused_add_dynamic_quant_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_rmsnorm_group_fused_quant_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_rmsnorm_group_add_fused_quant_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_per_token_quant_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_per_token_quant.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_group_quant_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_group_fp8_quant.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
|
||||||
|
return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def rms_norm(
|
||||||
|
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rms_norm2d_with_add(
|
def rms_norm2d_with_add(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -954,12 +1091,6 @@ class rocm_aiter_ops:
|
|||||||
x, residual, weight, variance_epsilon
|
x, residual, weight, variance_epsilon
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def rms_norm(
|
|
||||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gemm_a8w8(
|
def gemm_a8w8(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
|
|||||||
@ -6,11 +6,13 @@ import torch
|
|||||||
from torch._higher_order_ops import auto_functionalized
|
from torch._higher_order_ops import auto_functionalized
|
||||||
from torch._ops import OpOverload
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
QuantKey,
|
QuantKey,
|
||||||
_normalize_quant_group_shape,
|
_normalize_quant_group_shape,
|
||||||
kFp8Dynamic64Sym,
|
kFp8Dynamic64Sym,
|
||||||
@ -150,26 +152,50 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
|
|||||||
|
|
||||||
|
|
||||||
class MatcherRMSNorm(MatcherCustomOp):
|
class MatcherRMSNorm(MatcherCustomOp):
|
||||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
enabled: bool | None = None,
|
||||||
|
match_rocm_aiter: bool = False,
|
||||||
|
):
|
||||||
if enabled is None:
|
if enabled is None:
|
||||||
enabled = RMSNorm.enabled()
|
enabled = RMSNorm.enabled()
|
||||||
|
|
||||||
super().__init__(enabled)
|
super().__init__(enabled)
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
self._rmsnorm_op = RMS_OP
|
||||||
|
self.match_rocm_aiter = match_rocm_aiter
|
||||||
|
|
||||||
|
if match_rocm_aiter:
|
||||||
|
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
|
||||||
|
|
||||||
def inputs(self):
|
def inputs(self):
|
||||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||||
weight = self.empty(16)
|
weight = self.empty(16)
|
||||||
return [input, weight]
|
return [input, weight]
|
||||||
|
|
||||||
|
def forward_rocm_aiter(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self._rmsnorm_op(
|
||||||
|
x=input,
|
||||||
|
weight=weight,
|
||||||
|
variance_epsilon=self.epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_custom(
|
def forward_custom(
|
||||||
self,
|
self,
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if self.match_rocm_aiter:
|
||||||
|
return self.forward_rocm_aiter(input, weight)
|
||||||
|
|
||||||
result = torch.empty_like(input)
|
result = torch.empty_like(input)
|
||||||
_, result = auto_functionalized(
|
_, result = auto_functionalized(
|
||||||
RMS_OP,
|
self._rmsnorm_op,
|
||||||
result=result,
|
result=result,
|
||||||
input=input,
|
input=input,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
@ -189,12 +215,23 @@ class MatcherRMSNorm(MatcherCustomOp):
|
|||||||
|
|
||||||
|
|
||||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
enabled: bool | None = None,
|
||||||
|
match_rocm_aiter: bool = False,
|
||||||
|
):
|
||||||
if enabled is None:
|
if enabled is None:
|
||||||
enabled = RMSNorm.enabled()
|
enabled = RMSNorm.enabled()
|
||||||
|
|
||||||
super().__init__(enabled)
|
super().__init__(enabled)
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
self.match_rocm_aiter = match_rocm_aiter
|
||||||
|
|
||||||
|
self._rmsnorm_op = RMS_ADD_OP
|
||||||
|
|
||||||
|
if match_rocm_aiter:
|
||||||
|
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
|
||||||
|
|
||||||
def inputs(self):
|
def inputs(self):
|
||||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||||
@ -202,14 +239,27 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
|||||||
residual = self.empty(5, 16)
|
residual = self.empty(5, 16)
|
||||||
return [input, weight, residual]
|
return [input, weight, residual]
|
||||||
|
|
||||||
|
def forward_rocm_aiter(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return self._rmsnorm_op(
|
||||||
|
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
|
||||||
|
)
|
||||||
|
|
||||||
def forward_custom(
|
def forward_custom(
|
||||||
self,
|
self,
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.match_rocm_aiter:
|
||||||
|
return self.forward_rocm_aiter(input, weight, residual)
|
||||||
|
|
||||||
_, result, residual = auto_functionalized(
|
_, result, residual = auto_functionalized(
|
||||||
RMS_ADD_OP,
|
self._rmsnorm_op,
|
||||||
input=input,
|
input=input,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
@ -236,22 +286,46 @@ class MatcherQuantFP8(MatcherCustomOp):
|
|||||||
enabled: bool | None = None,
|
enabled: bool | None = None,
|
||||||
has_col_major_scales: bool = False,
|
has_col_major_scales: bool = False,
|
||||||
is_e8m0: bool = False,
|
is_e8m0: bool = False,
|
||||||
|
match_rocm_aiter: bool = False,
|
||||||
):
|
):
|
||||||
if enabled is None:
|
if enabled is None:
|
||||||
enabled = QuantFP8.enabled()
|
enabled = QuantFP8.enabled()
|
||||||
|
|
||||||
super().__init__(enabled)
|
super().__init__(enabled)
|
||||||
self.quant_key = quant_key
|
self.quant_key = quant_key
|
||||||
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
|
|
||||||
self.QUANT_OP = QUANT_OPS[quant_key]
|
|
||||||
|
|
||||||
self.has_col_major_scales = has_col_major_scales
|
self.has_col_major_scales = has_col_major_scales
|
||||||
self.is_e8m0 = is_e8m0
|
self.is_e8m0 = is_e8m0
|
||||||
|
self.match_rocm_aiter = match_rocm_aiter
|
||||||
|
|
||||||
|
if match_rocm_aiter:
|
||||||
|
assert not quant_key.scale.group_shape.is_per_tensor(), (
|
||||||
|
"ROCm aiter fusion pass does not support per tensor quantization"
|
||||||
|
)
|
||||||
|
if quant_key.scale.group_shape.is_per_token():
|
||||||
|
self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
|
||||||
|
else:
|
||||||
|
assert quant_key.scale.group_shape.col == 128, (
|
||||||
|
"ROCm aiter fusion pass currently supports "
|
||||||
|
"quantization operation with group_size 128"
|
||||||
|
)
|
||||||
|
if current_platform.is_fp8_fnuz():
|
||||||
|
self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||||
|
else:
|
||||||
|
self.QUANT_OP = (
|
||||||
|
torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert quant_key in QUANT_OPS, (
|
||||||
|
f"unsupported quantization scheme {quant_key}"
|
||||||
|
)
|
||||||
|
self.QUANT_OP = QUANT_OPS[quant_key]
|
||||||
|
|
||||||
|
assert quant_key.dtype == current_platform.fp8_dtype(), (
|
||||||
|
"Only QuantFP8 supported by"
|
||||||
|
)
|
||||||
|
assert quant_key.scale2 is None
|
||||||
|
|
||||||
assert quant_key.dtype == current_platform.fp8_dtype(), (
|
|
||||||
"Only QuantFP8 supported by"
|
|
||||||
)
|
|
||||||
assert quant_key.scale2 is None
|
|
||||||
self.quant_fp8 = QuantFP8(
|
self.quant_fp8 = QuantFP8(
|
||||||
quant_key.scale.static,
|
quant_key.scale.static,
|
||||||
quant_key.scale.group_shape,
|
quant_key.scale.group_shape,
|
||||||
@ -259,11 +333,29 @@ class MatcherQuantFP8(MatcherCustomOp):
|
|||||||
use_ue8m0=is_e8m0,
|
use_ue8m0=is_e8m0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward_rocm_aiter(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
scale: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
quant_key_group_shape = self.quant_key.scale.group_shape
|
||||||
|
if quant_key_group_shape == GroupShape.PER_TOKEN:
|
||||||
|
return self.QUANT_OP(
|
||||||
|
x=input,
|
||||||
|
quant_dtype=self.quant_key.dtype,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.QUANT_OP(input, quant_key_group_shape.col)
|
||||||
|
|
||||||
def forward_custom(
|
def forward_custom(
|
||||||
self,
|
self,
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
scale: torch.Tensor | None = None,
|
scale: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.match_rocm_aiter:
|
||||||
|
return self.forward_rocm_aiter(input, scale)
|
||||||
|
|
||||||
result = torch.empty(
|
result = torch.empty(
|
||||||
input.shape, device=input.device, dtype=self.quant_key.dtype
|
input.shape, device=input.device, dtype=self.quant_key.dtype
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from .vllm_inductor_pass import VllmInductorPass
|
|||||||
|
|
||||||
if rocm_aiter_ops.is_enabled():
|
if rocm_aiter_ops.is_enabled():
|
||||||
from vllm.compilation.rocm_aiter_fusion import (
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
RocmAiterRMSNormFusionPass,
|
||||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -117,7 +117,9 @@ class PostGradPassManager(CustomGraphPass):
|
|||||||
if self.pass_config.fuse_norm_quant:
|
if self.pass_config.fuse_norm_quant:
|
||||||
self.passes += [RMSNormQuantFusionPass(config)]
|
self.passes += [RMSNormQuantFusionPass(config)]
|
||||||
if rocm_aiter_ops.is_enabled():
|
if rocm_aiter_ops.is_enabled():
|
||||||
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
|
self.passes += [
|
||||||
|
RocmAiterRMSNormFusionPass(config),
|
||||||
|
]
|
||||||
if self.pass_config.fuse_act_quant:
|
if self.pass_config.fuse_act_quant:
|
||||||
self.passes += [ActivationQuantFusionPass(config)]
|
self.passes += [ActivationQuantFusionPass(config)]
|
||||||
if rocm_aiter_ops.is_enabled():
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
|||||||
@ -9,60 +9,195 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
|
|||||||
from torch._ops import OpOverload
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
QuantKey,
|
||||||
|
ScaleDesc,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .fusion import empty_bf16
|
from .fusion import (
|
||||||
|
FusedRMSQuantKey,
|
||||||
|
)
|
||||||
from .inductor_pass import enable_fake_mode
|
from .inductor_pass import enable_fake_mode
|
||||||
from .matcher_utils import MatcherSiluAndMul
|
from .matcher_utils import (
|
||||||
|
MatcherFusedAddRMSNorm,
|
||||||
|
MatcherQuantFP8,
|
||||||
|
MatcherRMSNorm,
|
||||||
|
MatcherSiluAndMul,
|
||||||
|
)
|
||||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
|
|
||||||
AITER_RMS_ADD_GROUP_QUANT_OP = (
|
|
||||||
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
|
|
||||||
)
|
|
||||||
|
|
||||||
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
|
class AiterRMSNormQuantPattern:
|
||||||
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
|
def __init__(
|
||||||
|
self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
|
||||||
|
):
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.quant_dtype = key.quant.dtype
|
||||||
|
|
||||||
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
|
self.rmsnorm_matcher = (
|
||||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
|
||||||
|
if not key.fused_add
|
||||||
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
|
||||||
|
)
|
||||||
|
self.quant_matcher = MatcherQuantFP8(
|
||||||
|
key.quant,
|
||||||
|
match_rocm_aiter=match_aiter_quant,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AiterRMSFp8GroupQuantPattern:
|
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||||
|
"""AITER RMSNorm + Dynamic Quantization pattern."""
|
||||||
|
|
||||||
|
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
match_aiter_quant: bool = True,
|
||||||
|
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||||
|
symmetric=True,
|
||||||
|
):
|
||||||
|
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||||
|
key = FusedRMSQuantKey(
|
||||||
|
fused_add=False,
|
||||||
|
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(epsilon, key, match_aiter_quant)
|
||||||
|
|
||||||
|
def register(self, pm_pass):
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
result_rms = self.rmsnorm_matcher(input, weight)
|
||||||
|
result, scale = self.quant_matcher(result_rms)
|
||||||
|
return result, scale
|
||||||
|
|
||||||
|
def replacement(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
result = self.FUSED_OP(
|
||||||
|
x=input,
|
||||||
|
weight=weight,
|
||||||
|
epsilon=self.epsilon,
|
||||||
|
quant_dtype=self.quant_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result[0], result[1]
|
||||||
|
|
||||||
|
pm.register_replacement(
|
||||||
|
pattern,
|
||||||
|
replacement,
|
||||||
|
self.rmsnorm_matcher.inputs(),
|
||||||
|
pm.fwd_only,
|
||||||
|
pm_pass,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||||
|
"""AITER RMSNorm Fused Add + Dynamic Quantization pattern."""
|
||||||
|
|
||||||
|
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
match_aiter_quant: bool = True,
|
||||||
|
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||||
|
symmetric=True,
|
||||||
|
):
|
||||||
|
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||||
|
key = FusedRMSQuantKey(
|
||||||
|
fused_add=True,
|
||||||
|
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(epsilon, key, match_aiter_quant)
|
||||||
|
|
||||||
|
def register(self, pm_pass):
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
):
|
||||||
|
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||||
|
result, scale = self.quant_matcher(result_rms)
|
||||||
|
|
||||||
|
return result, residual_out, scale
|
||||||
|
|
||||||
|
def replacement(
|
||||||
|
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||||
|
):
|
||||||
|
result = self.FUSED_OP(
|
||||||
|
x=input,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
epsilon=self.epsilon,
|
||||||
|
quant_dtype=self.quant_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result[0], result[1], result[2]
|
||||||
|
|
||||||
|
pm.register_replacement(
|
||||||
|
pattern,
|
||||||
|
replacement,
|
||||||
|
self.rmsnorm_matcher.inputs(),
|
||||||
|
pm.fwd_only,
|
||||||
|
pm_pass,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||||
"""
|
"""
|
||||||
This pattern fuses aiter rms_norm & group fp8 quant custom
|
This pattern fuses aiter rms_norm & group fp8 quant custom
|
||||||
ops into an aiter rms_norm_group_fp8_quant op.
|
ops into an aiter rms_norm_group_fp8_quant op.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()
|
||||||
self.epsilon = epsilon
|
|
||||||
self.quant_dtype = quant_dtype
|
def __init__(
|
||||||
self.quant_op = quant_op
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_shape: GroupShape,
|
||||||
|
match_aiter_quant: bool = True,
|
||||||
|
symmetric=True,
|
||||||
|
):
|
||||||
|
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||||
|
key = FusedRMSQuantKey(
|
||||||
|
fused_add=False,
|
||||||
|
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(epsilon, key, match_aiter_quant)
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
):
|
):
|
||||||
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
|
result_rms = self.rmsnorm_matcher(input, weight)
|
||||||
|
result, scale = self.quant_matcher(result_rms)
|
||||||
at2 = self.quant_op(at1, 128)
|
return result, scale
|
||||||
|
|
||||||
return at2[0], at2[1]
|
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
):
|
):
|
||||||
at = AITER_RMS_GROUP_QUANT_OP(
|
at = self.FUSED_OP(
|
||||||
x=input,
|
x=input,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
variance_epsilon=self.epsilon,
|
variance_epsilon=self.epsilon,
|
||||||
@ -71,49 +206,52 @@ class AiterRMSFp8GroupQuantPattern:
|
|||||||
|
|
||||||
return at[0], at[1]
|
return at[0], at[1]
|
||||||
|
|
||||||
inputs = [
|
pm.register_replacement(
|
||||||
empty_bf16(5, 4), # input
|
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
||||||
empty_bf16(1, 5), # weight
|
)
|
||||||
]
|
|
||||||
|
|
||||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
|
||||||
|
|
||||||
|
|
||||||
class AiterFusedAddRMSFp8GroupQuantPattern:
|
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||||
"""
|
"""
|
||||||
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
|
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
|
||||||
into a aiter rms_norm_with_add_group_fp8_quant op.
|
into a aiter rms_norm_with_add_group_fp8_quant op.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()
|
||||||
self.epsilon = epsilon
|
|
||||||
self.quant_dtype = quant_dtype
|
def __init__(
|
||||||
self.quant_op = quant_op
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_shape: GroupShape,
|
||||||
|
match_aiter_quant: bool = True,
|
||||||
|
symmetric=True,
|
||||||
|
):
|
||||||
|
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||||
|
key = FusedRMSQuantKey(
|
||||||
|
fused_add=True,
|
||||||
|
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(epsilon, key, match_aiter_quant)
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
):
|
):
|
||||||
at1 = AITER_RMS_ADD_OP(
|
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||||
x=input,
|
result, scale = self.quant_matcher(result_rms)
|
||||||
residual=residual,
|
|
||||||
weight=weight,
|
|
||||||
variance_epsilon=self.epsilon,
|
|
||||||
)
|
|
||||||
|
|
||||||
at2 = self.quant_op(at1[0], 128)
|
return result, residual_out, scale
|
||||||
|
|
||||||
# result, scale, residual
|
|
||||||
return at2[0], at2[1], at1[1]
|
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
):
|
):
|
||||||
at = AITER_RMS_ADD_GROUP_QUANT_OP(
|
at = self.FUSED_OP(
|
||||||
x=input,
|
x=input,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
@ -124,18 +262,15 @@ class AiterFusedAddRMSFp8GroupQuantPattern:
|
|||||||
# result, scale, residual
|
# result, scale, residual
|
||||||
return at[0], at[1], at[2]
|
return at[0], at[1], at[2]
|
||||||
|
|
||||||
inputs = [
|
pm.register_replacement(
|
||||||
empty_bf16(5, 4), # input
|
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
||||||
empty_bf16(5, 4), # residual
|
)
|
||||||
empty_bf16(1, 5), # weight
|
|
||||||
]
|
|
||||||
|
|
||||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
|
||||||
|
|
||||||
|
|
||||||
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
|
||||||
"""
|
"""
|
||||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
|
||||||
|
into a fused rms_norm_quant op.
|
||||||
It also supports fused_add_rms_norm.
|
It also supports fused_add_rms_norm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -144,20 +279,33 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||||
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
|
pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure fused add patterns are before simple rms norm,
|
# Make sure fused add patterns are before simple rms norm,
|
||||||
# as the latter is a subset of the former in torch ops
|
# as the latter is a subset of the former in torch ops
|
||||||
for epsilon in [1e-5, 1e-6]:
|
for epsilon in [1e-5, 1e-6]:
|
||||||
# Fuse rms_norm + dynamic group fp8 quant
|
# Fuse aiter rms_norm + aiter dynamic group fp8 quant
|
||||||
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
AiterRMSFp8GroupQuantPattern(
|
||||||
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
|
epsilon, FP8_DTYPE, GroupShape(1, 128)
|
||||||
self.patterns
|
).register(self.patterns)
|
||||||
)
|
|
||||||
|
|
||||||
AiterFusedAddRMSFp8GroupQuantPattern(
|
# Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
|
||||||
epsilon, FP8_DTYPE, quant_op
|
AiterFusedAddRMSFp8GroupQuantPattern(
|
||||||
|
epsilon, FP8_DTYPE, GroupShape(1, 128)
|
||||||
|
).register(self.patterns)
|
||||||
|
|
||||||
|
for match_aiter_quant in [True, False]:
|
||||||
|
# Fuse aiter rms_norm + (aiter / vllm built-in)
|
||||||
|
# dynamic per-token fp8 quant
|
||||||
|
AiterRMSNormDynamicQuantPattern(
|
||||||
|
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
|
||||||
|
).register(self.patterns)
|
||||||
|
|
||||||
|
# Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
|
||||||
|
# dynamic per-token fp8 quant
|
||||||
|
AiterFusedAddRMSNormDynamicQuantPattern(
|
||||||
|
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
|
||||||
).register(self.patterns)
|
).register(self.patterns)
|
||||||
|
|
||||||
self.dump_patterns(config, self.patterns)
|
self.dump_patterns(config, self.patterns)
|
||||||
@ -169,6 +317,8 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
|||||||
|
|
||||||
def uuid(self) -> Any:
|
def uuid(self) -> Any:
|
||||||
fusion_patterns = [
|
fusion_patterns = [
|
||||||
|
AiterRMSNormDynamicQuantPattern,
|
||||||
|
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||||
AiterRMSFp8GroupQuantPattern,
|
AiterRMSFp8GroupQuantPattern,
|
||||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||||
]
|
]
|
||||||
@ -181,6 +331,8 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
|||||||
ops into an aiter silu_and_mul_group_fp8_quant op.
|
ops into an aiter silu_and_mul_group_fp8_quant op.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
|
||||||
|
|
||||||
def __init__(self, quant_op: OpOverload):
|
def __init__(self, quant_op: OpOverload):
|
||||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||||
self.quant_op = quant_op
|
self.quant_op = quant_op
|
||||||
@ -196,7 +348,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
|||||||
def replacement(
|
def replacement(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
):
|
):
|
||||||
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
||||||
return at[0], at[1]
|
return at[0], at[1]
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
@ -216,6 +368,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
|||||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||||
|
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||||
|
|
||||||
|
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
|
||||||
|
|
||||||
@enable_fake_mode
|
@enable_fake_mode
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@ -224,7 +381,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
|||||||
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||||
)
|
)
|
||||||
|
|
||||||
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
for quant_op in self.QUANT_OPS:
|
||||||
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||||
|
|
||||||
self.dump_patterns(config, self.patterns)
|
self.dump_patterns(config, self.patterns)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user