mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 17:57:07 +08:00
[Rocm][torch.compile] Adding layernorm + fp8 block quant and silu + fp8 block quant for Aiter (#25693)
Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Micah Williamson <micah.williamson@amd.com> Signed-off-by: Charlie Fu <Charlie.Fu@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com> Co-authored-by: wuhuikx <hattie.wu@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
This commit is contained in:
parent
fccd532587
commit
3c680f4a17
@ -1,10 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.plugins
|
import vllm.plugins
|
||||||
|
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||||
from vllm.compilation.fx_utils import find_op_nodes
|
from vllm.compilation.fx_utils import find_op_nodes
|
||||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||||
@ -152,13 +155,79 @@ GROUP_SHAPES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
    """Test model that alternates AITER RMSNorm ops with block-FP8 linears.

    The forward pass runs one plain ``rms_norm`` followed by three
    (block-FP8 linear, ``rms_norm2d_with_add``) stages.
    """

    def __init__(self, hidden_size: int, eps: float, **kwargs):
        """Create the linear op, random FP8 weights/scales and norm weights.

        Args:
            hidden_size: Size of both dimensions of every weight matrix.
            eps: Epsilon passed to every RMSNorm call.
            **kwargs: Accepted and ignored so that callers can pass the same
                keyword arguments they pass to other test models.
        """
        super().__init__()
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(128, 128),
            act_quant_group_shape=GroupShape(1, 128),
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )

        # One transposed FP8 weight per linear stage.
        fp8_weights = []
        for _ in range(3):
            dense = torch.rand(hidden_size, hidden_size)
            fp8_weights.append(dense.to(dtype=FP8_DTYPE).t())
        self.w = fp8_weights

        # One float32 scale per 128x128 weight block (ceil division).
        blocks_per_dim = (hidden_size + 127) // 128
        weight_scales = []
        for _ in range(3):
            weight_scales.append(
                torch.rand((blocks_per_dim, blocks_per_dim), dtype=torch.float32)
            )
        self.wscale = weight_scales

        # Four norm weights: one for the initial norm, one per add-norm stage.
        self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
        self.eps = eps

    def forward(self, x):
        """Run the norm -> (linear -> add-norm) x 3 chain and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        resid = torch.relu(x)
        normed = rocm_aiter_ops.rms_norm(resid, self.norm_weight[0], self.eps)

        stages = zip(self.w, self.wscale, self.norm_weight[1:])
        for weight, weight_scale, norm_weight in stages:
            projected = self.w8a8_block_fp8_linear.apply(normed, weight, weight_scale)
            # make sure resid is used for replacement to work
            normed, resid = rocm_aiter_ops.rms_norm2d_with_add(
                projected, resid, norm_weight, self.eps
            )

        return normed

    def ops_in_model_before(self):
        """Return the unfused RMSNorm and group-FP8-quant ops of this model."""
        vllm_ops = torch.ops.vllm
        return [
            vllm_ops.rocm_aiter_rms_norm,
            vllm_ops.rocm_aiter_group_fp8_quant,
        ]

    def ops_in_model_before_partial(self):
        """Return an empty list (no partially-present ops for this model)."""
        return []

    def ops_in_model_after(self):
        """Return the fused RMSNorm + group-FP8-quant ops (with and without add)."""
        vllm_ops = torch.ops.vllm
        return [
            vllm_ops.rocm_aiter_rmsnorm_fp8_group_quant,
            vllm_ops.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
        ]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("hidden_size", [256])
|
@pytest.mark.parametrize("hidden_size", [256])
|
||||||
@pytest.mark.parametrize("num_tokens", [257])
|
@pytest.mark.parametrize("num_tokens", [257])
|
||||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
||||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
@pytest.mark.parametrize(
|
||||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
|
||||||
|
list(itertools.product([TestModel], [True, False], [True, False]))
|
||||||
|
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
|
||||||
|
)
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -173,10 +242,14 @@ def test_fusion_rmsnorm_quant(
|
|||||||
num_tokens,
|
num_tokens,
|
||||||
eps,
|
eps,
|
||||||
group_shape,
|
group_shape,
|
||||||
|
model_class,
|
||||||
enable_rms_norm_custom_op,
|
enable_rms_norm_custom_op,
|
||||||
enable_quant_fp8_custom_op,
|
enable_quant_fp8_custom_op,
|
||||||
cuda_force_torch,
|
cuda_force_torch,
|
||||||
):
|
):
|
||||||
|
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||||
|
pytest.skip("AITER is not supported on this GPU.")
|
||||||
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
@ -209,12 +282,24 @@ def test_fusion_rmsnorm_quant(
|
|||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
if model_class is TestRmsnormGroupFp8QuantModel:
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
|
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
|
||||||
|
else:
|
||||||
|
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||||
model = TestModel(hidden_size, eps, group_shape, cuda_force_torch)
|
model = model_class(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
eps=eps,
|
||||||
|
group_shape=group_shape,
|
||||||
|
cuda_force_torch=cuda_force_torch,
|
||||||
|
)
|
||||||
# First dimension dynamic
|
# First dimension dynamic
|
||||||
x = torch.rand(num_tokens, hidden_size)
|
x = torch.rand(num_tokens, hidden_size)
|
||||||
torch._dynamo.mark_dynamic(x, 0)
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
@ -243,7 +328,10 @@ def test_fusion_rmsnorm_quant(
|
|||||||
# there's a risk that the fused add doesn't get included in the
|
# there's a risk that the fused add doesn't get included in the
|
||||||
# replacement and only the rms part gets fused with quant.
|
# replacement and only the rms part gets fused with quant.
|
||||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||||
if not enable_rms_norm_custom_op:
|
if (
|
||||||
|
not enable_rms_norm_custom_op
|
||||||
|
and model_class is not TestRmsnormGroupFp8QuantModel
|
||||||
|
):
|
||||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||||
|
from vllm._aiter_ops import IS_AITER_FOUND
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
from vllm.compilation.activation_quant_fusion import (
|
from vllm.compilation.activation_quant_fusion import (
|
||||||
FUSED_OPS,
|
FUSED_OPS,
|
||||||
@ -24,6 +25,7 @@ from vllm.config import (
|
|||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
return [FUSED_OPS[kNvfp4Quant]]
|
return [FUSED_OPS[kNvfp4Quant]]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
    """Test model: ``SiluAndMul`` followed by a block-FP8 linear layer."""

    def __init__(self, hidden_size: int, **kwargs):
        """Create the activation, the linear op and random FP8 weight/scales.

        Args:
            hidden_size: Size of both dimensions of the weight matrix.
            **kwargs: Accepted and ignored so that callers can pass the same
                keyword arguments they pass to other test models.
        """
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(128, 128),
            act_quant_group_shape=GroupShape(1, 128),
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )

        dense_weight = torch.rand(hidden_size, hidden_size)
        self.w = dense_weight.to(dtype=FP8_DTYPE).t()

        # One float32 scale per 128x128 weight block (ceil division).
        blocks_per_dim = (hidden_size + 127) // 128
        self.wscale = torch.rand(
            (blocks_per_dim, blocks_per_dim), dtype=torch.float32
        )

        # Remember whether SiluAndMul runs as a custom op; this decides which
        # op ops_in_model_before() reports.
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

    def forward(self, x):
        """Apply SiluAndMul to ``x`` and feed the result to the FP8 linear."""
        activated = self.silu_and_mul(x)
        return self.w8a8_block_fp8_linear.apply(activated, self.w, self.wscale)

    def ops_in_model_before(self):
        """Return the activation op expected in the graph before fusion."""
        if self.enable_silu_mul_custom_op:
            return [SILU_MUL_OP]
        return [torch.ops.aten.mul]

    def ops_in_model_after(self):
        """Return the fused SiLU-mul + group-FP8-quant op."""
        return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_tokens", [32, 64])
|
@pytest.mark.parametrize("num_tokens", [32, 64])
|
||||||
@pytest.mark.parametrize("hidden_size", [128, 256])
|
@pytest.mark.parametrize("hidden_size", [128, 256])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
|
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
|
||||||
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
|
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
|
||||||
+ [(TestSiluMulNvfp4QuantModel, False, False)],
|
+ [
|
||||||
|
(TestSiluMulNvfp4QuantModel, False, False),
|
||||||
|
(TestSiluMulGroupFp8QuantModel, False, False),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
|
model_class: type[
|
||||||
|
TestSiluMulFp8QuantModel
|
||||||
|
| TestSiluMulNvfp4QuantModel
|
||||||
|
| TestSiluMulGroupFp8QuantModel
|
||||||
|
],
|
||||||
enable_silu_mul_custom_op: bool,
|
enable_silu_mul_custom_op: bool,
|
||||||
enable_quant_fp8_custom_op: bool,
|
enable_quant_fp8_custom_op: bool,
|
||||||
cuda_force_torch: bool,
|
cuda_force_torch: bool,
|
||||||
):
|
):
|
||||||
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
||||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||||
|
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||||
|
pytest.skip("AITER is not supported on this GPU.")
|
||||||
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with set_current_vllm_config(config):
|
with set_current_vllm_config(config):
|
||||||
fusion_pass = ActivationQuantFusionPass(config)
|
fusion_passes = [ActivationQuantFusionPass(config)]
|
||||||
|
if IS_AITER_FOUND:
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
|
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||||
|
|
||||||
|
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||||
backend = TestBackend(*passes)
|
backend = TestBackend(*passes)
|
||||||
model = model_class(
|
model = model_class(
|
||||||
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
|
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
|
||||||
@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
atol, rtol = 1e-3, 1e-3
|
atol, rtol = 1e-3, 1e-3
|
||||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||||
atol, rtol = 1e-1, 1e-1
|
atol, rtol = 1e-1, 1e-1
|
||||||
|
elif model_class == TestSiluMulGroupFp8QuantModel:
|
||||||
|
atol, rtol = 5e-2, 5e-2
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||||
)
|
)
|
||||||
|
|
||||||
assert fusion_pass.matched_count == 1
|
assert sum([p.matched_count for p in fusion_passes]) == 1
|
||||||
|
|
||||||
# In pre-nodes, quant op should be present and fused kernels should not
|
# In pre-nodes, quant op should be present and fused kernels should not
|
||||||
backend.check_before_ops(model.ops_in_model_before())
|
backend.check_before_ops(model.ops_in_model_before())
|
||||||
|
|||||||
@ -24,6 +24,15 @@ def is_aiter_found() -> bool:
|
|||||||
# we keep this global outside to not cause torch compile breaks.
|
# we keep this global outside to not cause torch compile breaks.
|
||||||
IS_AITER_FOUND = is_aiter_found()
|
IS_AITER_FOUND = is_aiter_found()
|
||||||
|
|
||||||
|
# Can't use dtypes.fp8 directly inside an op
|
||||||
|
# because it returns wrong result on gfx942.
|
||||||
|
# This is a workaround to get the correct FP8 dtype.
|
||||||
|
# This might be because get_gfx() is wrapped as a custom op.
|
||||||
|
if IS_AITER_FOUND:
|
||||||
|
from aiter import dtypes
|
||||||
|
|
||||||
|
AITER_FP8_DTYPE = dtypes.fp8
|
||||||
|
|
||||||
|
|
||||||
def if_aiter_supported(func: Callable) -> Callable:
|
def if_aiter_supported(func: Callable) -> Callable:
|
||||||
"""Decorator that only executes the function if
|
"""Decorator that only executes the function if
|
||||||
@ -45,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_group_fp8_quant_impl(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
|
|
||||||
from aiter import QuantType, dtypes, get_hip_quant
|
|
||||||
|
|
||||||
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
|
|
||||||
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
|
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_group_fp8_quant_fake(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
from aiter import dtypes
|
|
||||||
|
|
||||||
M, N = x.shape
|
|
||||||
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
|
|
||||||
out_bs = torch.empty(
|
|
||||||
(
|
|
||||||
M,
|
|
||||||
(N + group_size - 1) // group_size,
|
|
||||||
),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=x.device,
|
|
||||||
)
|
|
||||||
return x_fp8, out_bs
|
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_fused_moe_impl(
|
def _rocm_aiter_fused_moe_impl(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
@ -522,6 +501,142 @@ def _rocm_aiter_per_token_quant_fake(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call AITER's fused RMSNorm + FP8 group quant kernel with a residual.

    ``residual`` is forwarded to the kernel as ``res1``.

    Returns:
        ``(x_quant, x_quant_scales, res)`` where the first two come from the
        kernel's first output and ``res`` is its fourth output.
    """
    # Imported lazily so the module can be loaded without aiter installed.
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    quant_pair, _, _, residual_out = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        # The next three positional inputs are left unset.
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
        res1=residual,
    )
    quantized, scales = quant_pair
    return (quantized, scales, residual_out)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake implementation: allocate outputs of the right shape and dtype.

    Returns uninitialised tensors: an FP8 tensor shaped like ``x``, float32
    scales of shape ``(M, ceil(N / group_size))`` and a tensor shaped like
    ``residual``. ``x`` must be 2-D.
    """
    num_rows, num_cols = x.shape
    groups_per_row = (num_cols + group_size - 1) // group_size

    quantized = torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device)
    scales = torch.empty(
        (num_rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    residual_out = torch.empty_like(residual, device=residual.device)
    return (quantized, scales, residual_out)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call AITER's fused RMSNorm + FP8 group quant kernel without a residual.

    Args:
        x: Input tensor to normalise and quantise.
        weight: RMSNorm weight.
        variance_epsilon: Epsilon forwarded to the kernel.
        group_size: Quantisation group size forwarded to the kernel.

    Returns:
        ``(x_quant, x_quant_scales)`` taken from the kernel's first output.
    """
    # Imported lazily so the module can be loaded without aiter installed.
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    # No residual is passed (res1=None), so the kernel's fourth output is not
    # needed here and is discarded along with the second and third.
    (x_quant, x_quant_scales), _, _, _ = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
        res1=None,
    )
    return (x_quant, x_quant_scales)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation: allocate outputs of the right shape and dtype.

    Returns an uninitialised FP8 tensor shaped like ``x`` and float32 scales
    of shape ``(M, ceil(N / group_size))``. ``x`` must be 2-D.
    """
    num_rows, num_cols = x.shape
    groups_per_row = (num_cols + group_size - 1) // group_size

    quantized = torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device)
    scales = torch.empty(
        (num_rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    return (quantized, scales)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_group_fp8_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantise ``x`` to FP8 with AITER's per-1x128 HIP quant kernel.

    The last dimension of ``x`` must be divisible by ``group_size``. The input
    is made contiguous before it is handed to the kernel.
    """
    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
    # Imported lazily so the module can be loaded without aiter installed.
    from aiter import QuantType, get_hip_quant

    # NOTE(review): the kernel is always the per_1x128 variant; group_size is
    # only used by the divisibility check above. Presumably callers always
    # pass 128 -- confirm.
    quant_kernel = get_hip_quant(QuantType.per_1x128)
    contiguous_x = x.contiguous()
    return quant_kernel(contiguous_x, quant_dtype=AITER_FP8_DTYPE)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_group_fp8_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation: allocate outputs of the right shape and dtype.

    Returns an uninitialised ``(M, N)`` FP8 tensor and float32 scales of shape
    ``(M, ceil(N / group_size))``. ``x`` must be 2-D.
    """
    num_rows, num_cols = x.shape
    groups_per_row = (num_cols + group_size - 1) // group_size

    quantized = torch.empty(
        (num_rows, num_cols), dtype=AITER_FP8_DTYPE, device=x.device
    )
    scales = torch.empty(
        (num_rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    return quantized, scales
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call AITER's fused activation-mul + FP8 group quant kernel with SiLU.

    Returns whatever the kernel returns for ``activation="silu"``, the given
    ``group_size`` and the AITER FP8 dtype.
    """
    # Imported lazily so the module can be loaded without aiter installed.
    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

    kernel_kwargs = {
        "activation": "silu",
        "group_size": group_size,
        "dtype_quant": AITER_FP8_DTYPE,
    }
    return act_mul_and_fp8_group_quant(x, **kernel_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation: allocate outputs of the right shape and dtype.

    The output has half as many columns as the 2-D input ``x``. Returns an
    uninitialised ``(M, N // 2)`` FP8 tensor and float32 scales of shape
    ``(M, ceil((N // 2) / group_size))``.
    """
    num_rows, num_cols = x.shape
    assert num_cols % 2 == 0
    out_cols = num_cols // 2
    groups_per_row = (out_cols + group_size - 1) // group_size

    quantized = torch.empty(
        (num_rows, out_cols), dtype=AITER_FP8_DTYPE, device=x.device
    )
    scales = torch.empty(
        (num_rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    return quantized, scales
|
||||||
|
|
||||||
|
|
||||||
# Global flag to ensure ops are registered only once
|
# Global flag to ensure ops are registered only once
|
||||||
_OPS_REGISTERED = False
|
_OPS_REGISTERED = False
|
||||||
|
|
||||||
@ -557,7 +672,7 @@ class rocm_aiter_ops:
|
|||||||
@if_aiter_supported
|
@if_aiter_supported
|
||||||
def is_linear_fp8_enaled(cls) -> bool:
|
def is_linear_fp8_enaled(cls) -> bool:
|
||||||
""" "Verifies device specs and availability of env variable."""
|
""" "Verifies device specs and availability of env variable."""
|
||||||
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
|
return cls.is_linear_enabled()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@if_aiter_supported
|
@if_aiter_supported
|
||||||
@ -632,14 +747,6 @@ class rocm_aiter_ops:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# register all the custom ops here
|
# register all the custom ops here
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="rocm_aiter_group_fp8_quant",
|
|
||||||
op_func=_rocm_aiter_group_fp8_quant_impl,
|
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_group_fp8_quant_fake,
|
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_asm_moe_tkw1",
|
op_name="rocm_aiter_asm_moe_tkw1",
|
||||||
op_func=_rocm_aiter_asm_moe_tkw1_impl,
|
op_func=_rocm_aiter_asm_moe_tkw1_impl,
|
||||||
@ -699,27 +806,46 @@ class rocm_aiter_ops:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_gemm_a8w8_blockscale",
|
op_name="rocm_aiter_gemm_a8w8_blockscale",
|
||||||
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
|
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
|
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_rms_norm",
|
op_name="rocm_aiter_rms_norm",
|
||||||
op_func=_rocm_aiter_rms_norm_impl,
|
op_func=_rocm_aiter_rms_norm_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_rms_norm_fake,
|
fake_impl=_rocm_aiter_rms_norm_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||||
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
|
||||||
|
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
|
||||||
|
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
|
||||||
|
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_group_fp8_quant",
|
||||||
|
op_func=_rocm_aiter_group_fp8_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_group_fp8_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_per_tensor_quant",
|
op_name="rocm_aiter_per_tensor_quant",
|
||||||
op_func=_rocm_aiter_per_tensor_quant_impl,
|
op_func=_rocm_aiter_per_tensor_quant_impl,
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import functools
|
|||||||
from torch import fx as fx
|
from torch import fx as fx
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
|
|||||||
from .post_cleanup import PostCleanupPass
|
from .post_cleanup import PostCleanupPass
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||||
|
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||||
from .fusion import RMSNormQuantFusionPass
|
from .fusion import RMSNormQuantFusionPass
|
||||||
@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
|
|||||||
|
|
||||||
if self.pass_config.fuse_norm_quant:
|
if self.pass_config.fuse_norm_quant:
|
||||||
self.passes += [RMSNormQuantFusionPass(config)]
|
self.passes += [RMSNormQuantFusionPass(config)]
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
|
||||||
if self.pass_config.fuse_act_quant:
|
if self.pass_config.fuse_act_quant:
|
||||||
self.passes += [ActivationQuantFusionPass(config)]
|
self.passes += [ActivationQuantFusionPass(config)]
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||||
|
|
||||||
if self.pass_config.fuse_attn_quant:
|
if self.pass_config.fuse_attn_quant:
|
||||||
self.passes += [AttnFusionPass(config)]
|
self.passes += [AttnFusionPass(config)]
|
||||||
|
|||||||
242
vllm/compilation/rocm_aiter_fusion.py
Normal file
242
vllm/compilation/rocm_aiter_fusion.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch._inductor.pattern_matcher as pm
|
||||||
|
from torch import fx
|
||||||
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||||
|
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .fusion import empty_bf16
|
||||||
|
from .inductor_pass import enable_fake_mode
|
||||||
|
from .matcher_utils import MatcherSiluAndMul
|
||||||
|
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()

# Unfused AITER RMSNorm ops (used on the pattern side of the fusions).
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default

# Group FP8 quant ops -- presumably the candidate `quant_op`s for the
# patterns; confirm at the pass construction site.
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

# Fused AITER ops (used on the replacement side of the fusions).
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
    torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
||||||
|
|
||||||
|
|
||||||
|
class AiterRMSFp8GroupQuantPattern:
    """
    Fuses the aiter rms_norm custom op followed by a group fp8 quant
    custom op into a single aiter rms_norm_group_fp8_quant op.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
        """Store the RMSNorm epsilon, the quant dtype and the quant op to match."""
        self.epsilon = epsilon
        self.quant_dtype = quant_dtype
        self.quant_op = quant_op

    def register(self, pm_pass: PatternMatcherPass):
        """Register the rms_norm + group quant replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            # Unfused form: RMSNorm, then group quant with group size 128.
            normed = AITER_RMS_OP(
                x=input, weight=weight, variance_epsilon=self.epsilon
            )
            quantized = self.quant_op(normed, 128)
            return quantized[0], quantized[1]

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # Fused form: one op producing the same (quantized, scales) pair.
            fused = AITER_RMS_GROUP_QUANT_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=128,
            )
            return fused[0], fused[1]

        # Example tensors handed to register_replacement: (input, weight).
        example_input = empty_bf16(5, 4)
        example_weight = empty_bf16(1, 5)
        example_inputs = [example_input, example_weight]

        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||||
|
|
||||||
|
|
||||||
|
class AiterFusedAddRMSFp8GroupQuantPattern:
|
||||||
|
"""
|
||||||
|
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
|
||||||
|
into a aiter rms_norm_with_add_group_fp8_quant op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.quant_dtype = quant_dtype
|
||||||
|
self.quant_op = quant_op
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
at1 = AITER_RMS_ADD_OP(
|
||||||
|
x=input,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
variance_epsilon=self.epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
at2 = self.quant_op(at1[0], 128)
|
||||||
|
|
||||||
|
# result, scale, residual
|
||||||
|
return at2[0], at2[1], at1[1]
|
||||||
|
|
||||||
|
def replacement(
|
||||||
|
input: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
at = AITER_RMS_ADD_GROUP_QUANT_OP(
|
||||||
|
x=input,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
variance_epsilon=self.epsilon,
|
||||||
|
group_size=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
# result, scale, residual
|
||||||
|
return at[0], at[1], at[2]
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
empty_bf16(5, 4), # input
|
||||||
|
empty_bf16(5, 4), # residual
|
||||||
|
empty_bf16(1, 5), # weight
|
||||||
|
]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses aiter rms_norm & group fp8 quant custom ops into a fused
    aiter rms_norm_group_fp8_quant op.
    It also supports the rms_norm_with_add (residual) variant.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
        )

        # One pattern pair per (epsilon, quant op) combination.
        for epsilon in [1e-5, 1e-6]:
            for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
                # Make sure fused add patterns are before simple rms norm,
                # as the latter is a subset of the former in torch ops.
                # Fuse rms_norm_with_add + dynamic group fp8 quant
                AiterFusedAddRMSFp8GroupQuantPattern(
                    epsilon, FP8_DTYPE, quant_op
                ).register(self.patterns)

                # Fuse rms_norm + dynamic group fp8 quant
                AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
                    self.patterns
                )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply every registered pattern to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return a hash over the source of this pass and its pattern classes."""
        fusion_patterns = [
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        ]
        return self.hash_source(self, *fusion_patterns)
|
||||||
|
|
||||||
|
|
||||||
|
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    Rewrites a silu_and_mul followed by a group fp8 quant custom op
    into a single aiter silu_and_mul_group_fp8_quant op.
    """

    def __init__(self, quant_op: OpOverload):
        # NOTE(review): ActivationQuantPattern.__init__ is not invoked here;
        # confirm that this is intended.
        self.quant_op = quant_op
        self.silu_and_mul_matcher = MatcherSiluAndMul()

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
        ):
            activated = self.silu_and_mul_matcher(input)
            # The group size is fixed to 128 for this pattern.
            quantized = self.quant_op(activated, 128)
            return quantized[0], quantized[1]

        def replacement(
            input: torch.Tensor,
        ):
            fused = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
            return fused[0], fused[1]

        # Reuse the matcher's first example input for tracing.
        example_inputs = [self.silu_and_mul_matcher.inputs()[0]]

        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||||
|
|
||||||
|
|
||||||
|
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses silu_and_mul & group fp8 quant custom ops into the
    fused aiter silu_and_mul_group_fp8_quant op.
    It uses the torch pattern matcher to find the patterns and replace them.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )

        # Register the same fusion once per supported group quant op.
        quant_ops = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
        for quant_op in quant_ops:
            silu_quant_pattern = AiterSiluMulFp8GroupQuantPattern(quant_op)
            silu_quant_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """Apply every registered pattern to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self):
        """Return a hash over the source of this pass and its pattern classes."""
        return VllmInductorPass.hash_source(
            self,
            ActivationQuantPattern,
            AiterSiluMulFp8GroupQuantPattern,
        )
|
||||||
@ -196,6 +196,39 @@ direct_register_custom_op(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _triton_per_token_group_quant_fp8_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Custom-op body: forward to ``per_token_group_quant_fp8`` with
    ``column_major_scales`` and ``use_ue8m0`` both disabled.
    """
    quant_result = per_token_group_quant_fp8(
        x,
        group_size,
        column_major_scales=False,
        use_ue8m0=False,
    )
    return quant_result
|
||||||
|
|
||||||
|
|
||||||
|
def _triton_per_token_group_quant_fp8_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fake (shape-only) implementation of the triton per-token group fp8 quant op.

    Returns an fp8 tensor with the same shape as ``x`` and a float32 scale
    tensor whose last dimension is the number of ``group_size``-wide groups
    along the last dimension of ``x``. Leading dimensions are carried over
    unchanged, so the input is not required to be 2-D.
    """
    # Ceil-divide so a partial trailing group still gets a scale slot.
    num_groups = (x.shape[-1] + group_size - 1) // group_size
    x_fp8 = torch.empty(x.shape, dtype=current_platform.fp8_dtype(), device=x.device)
    out_bs = torch.empty(
        (*x.shape[:-1], num_groups),
        dtype=torch.float32,
        device=x.device,
    )
    return x_fp8, out_bs
|
||||||
|
|
||||||
|
|
||||||
|
# Register the triton group quant under the name
# "triton_per_token_group_quant_fp8", with a fake implementation attached.
direct_register_custom_op(
    "triton_per_token_group_quant_fp8",
    _triton_per_token_group_quant_fp8_impl,
    fake_impl=_triton_per_token_group_quant_fp8_fake,
)
|
||||||
|
|
||||||
|
|
||||||
# TODO fix ROCm->Triton custom path:
|
# TODO fix ROCm->Triton custom path:
|
||||||
# https://github.com/vllm-project/vllm/issues/14397
|
# https://github.com/vllm-project/vllm/issues/14397
|
||||||
class W8A8BlockFp8LinearOp:
|
class W8A8BlockFp8LinearOp:
|
||||||
@ -341,17 +374,15 @@ class W8A8BlockFp8LinearOp:
|
|||||||
|
|
||||||
if input_scale is not None:
|
if input_scale is not None:
|
||||||
q_input = input_2d
|
q_input = input_2d
|
||||||
# MI350 case uses triton kernel
|
|
||||||
elif use_triton:
|
elif use_triton:
|
||||||
q_input, input_scale = per_token_group_quant_fp8(
|
q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
|
||||||
input_2d,
|
input_2d,
|
||||||
self.act_quant_group_shape.col,
|
self.act_quant_group_shape.col,
|
||||||
column_major_scales=False,
|
|
||||||
use_ue8m0=False,
|
|
||||||
)
|
)
|
||||||
# MI300 uses tuned AITER ASM/C++ kernel
|
|
||||||
else:
|
else:
|
||||||
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)
|
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
|
||||||
|
input_2d, self.act_quant_group_shape.col
|
||||||
|
)
|
||||||
|
|
||||||
return gemm_a8w8_blockscale_op(
|
return gemm_a8w8_blockscale_op(
|
||||||
q_input,
|
q_input,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user