[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu 2025-05-14 00:13:56 -05:00 committed by GitHub
parent 2d912fb66f
commit 7b2f28deba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 14 additions and 9 deletions

View File

@ -309,6 +309,7 @@ steps:
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- label: PyTorch Fullgraph Smoke Test # 9min

View File

@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
out.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);

View File

@ -27,8 +27,8 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
@ -36,7 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass)

View File

@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)

View File

@ -5,9 +5,10 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [torch.float8_e4m3fn]
QUANT_DTYPES = [current_platform.fp8_dtype()]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0]
@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
out_shape = (x.shape[0], x.shape[1] // 2)
out = torch.empty(out_shape,
dtype=torch.torch.float8_e4m3fn,
dtype=current_platform.fp8_dtype(),
device=x.device)
torch.ops._C.silu_and_mul_quant(out, x, scale)
return out

View File

@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .vllm_inductor_pass import VllmInductorPass
@ -41,7 +42,7 @@ def empty_bf16(*args, **kwargs):
def empty_fp8(*args, **kwargs):
fp8 = torch.float8_e4m3fn
fp8 = current_platform.fp8_dtype()
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")