[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)
Signed-off-by: charlifu <charlifu@amd.com>
parent 2d912fb66f
commit 7b2f28deba
@@ -309,6 +309,7 @@ steps:
   commands:
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
+  - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
 
 - label: PyTorch Fullgraph Smoke Test # 9min
@@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
 void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
                         torch::Tensor& input,  // [..., 2 * d]
                         torch::Tensor& scale) {
-  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
+              out.dtype() == torch::kFloat8_e4m3fnuz);
   TORCH_CHECK(input.dtype() == torch::kFloat16 ||
               input.dtype() == torch::kBFloat16);
   TORCH_CHECK(input.size(-1) % 2 == 0);
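
Note: the two fp8 formats accepted above are not interchangeable. NVIDIA GPUs use the OCP e4m3fn encoding, while MI300-class ROCm GPUs use the e4m3fnuz variant, which drops negative zero and has a smaller dynamic range. A quick way to see the difference in plain PyTorch (no vLLM assumptions):

import torch

# Same bit width, different numeric range:
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (CUDA)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (ROCm / MI300)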
@@ -27,8 +27,8 @@ class TestModel(torch.nn.Module):
 
 @pytest.mark.parametrize("num_tokens", [256])
 @pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
-                    reason="Only test on CUDA")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
+                    reason="Only test on CUDA and ROCm")
 def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
     torch.set_default_device("cuda")
     torch.set_default_dtype(torch.float16)
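
The test body keeps the "cuda" device string unchanged because ROCm builds of PyTorch expose HIP devices under the "cuda" device type; only the VLLM_TARGET_DEVICE guard needed widening. A minimal illustration:

import torch

torch.set_default_device("cuda")  # resolves to a HIP device on ROCm builds
print(torch.cuda.is_available())  # True on both CUDA and ROCm builds
print(torch.version.hip)          # a version string on ROCm, None on CUDA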
@@ -36,7 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
     # Reshape pass is needed for the fusion pass to work
     config = VllmConfig()
     config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
+        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
     fusion_pass = ActivationQuantFusionPass(config)
 
     backend = TestBackend(fusion_pass)
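
For context, the pattern this pass rewrites is SiluAndMul followed by a static fp8 quant, replaced by the single fused custom op whose checks were widened in the kernel hunk above. A rough sketch of the unfused and fused forms, using only ops that appear elsewhere in this diff (shapes are illustrative):

import torch
import vllm._custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform

x = torch.randn(256, 128, dtype=torch.float16, device="cuda")
scale = torch.tensor(0.5, dtype=torch.float32, device="cuda")

# Unfused graph the pattern matcher looks for:
y = SiluAndMul()(x)                        # [256, 64]
q_ref, _ = ops.scaled_fp8_quant(y, scale)  # static per-tensor quant

# Fused replacement (out-parameter convention, as in the kernel above):
q = torch.empty(256, 64, dtype=current_platform.fp8_dtype(), device="cuda")
torch.ops._C.silu_and_mul_quant(q, x, scale)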
@@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 @pytest.mark.parametrize("m", M + [28672])  # m >= 16
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.skipif(not current_platform.is_rocm(),
-                    reason="only test for rocm")
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
 def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
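
If more ROCm-fp8-only tests accumulate, the compound guard could be lifted into a reusable marker. A small sketch with a hypothetical helper name (requires_rocm_fp8 is not part of this diff), built from the same names the test imports:

import pytest
from vllm.platforms import current_platform

# One marker instead of repeating the compound skipif on every test.
requires_rocm_fp8 = pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8")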
@@ -5,9 +5,10 @@ import torch
 import vllm._custom_ops as ops
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
-QUANT_DTYPES = [torch.float8_e4m3fn]
+QUANT_DTYPES = [current_platform.fp8_dtype()]
 NUM_TOKENS = [1, 17, 86, 1234, 3045]  # Arbitrary values for testing
 HIDDEN_SIZES = [16, 48, 128, 1562, 4096]  # Arbitrary values for testing
 SEEDS = [0]
@@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
 def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     out_shape = (x.shape[0], x.shape[1] // 2)
     out = torch.empty(out_shape,
-                      dtype=torch.torch.float8_e4m3fn,
+                      dtype=current_platform.fp8_dtype(),
                       device=x.device)
     torch.ops._C.silu_and_mul_quant(out, x, scale)
     return out
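
As a sanity reference for what ops_impl computes: silu over the first half of the hidden dimension, multiplied elementwise by the second half, then statically quantized to fp8. A pure-PyTorch sketch, assuming the divide-by-scale convention vLLM's static fp8 quant uses:

import torch
import torch.nn.functional as F
from vllm.platforms import current_platform

def silu_mul_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    y = F.silu(x[..., :d]) * x[..., d:]                      # SiluAndMul
    fp8 = current_platform.fp8_dtype()
    finfo = torch.finfo(fp8)
    return (y / scale).clamp(finfo.min, finfo.max).to(fp8)   # static quant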
@@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 from .vllm_inductor_pass import VllmInductorPass
 
@@ -41,7 +42,7 @@ def empty_bf16(*args, **kwargs):
 
 
 def empty_fp8(*args, **kwargs):
-    fp8 = torch.float8_e4m3fn
+    fp8 = current_platform.fp8_dtype()
     return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
 
 
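
This helper builds the sample tensors the inductor pattern matcher traces the pattern with; hard-coding torch.float8_e4m3fn would register a pattern that never matches on ROCm, where graphs carry e4m3fnuz tensors instead. A small check of the new behavior (the shape is illustrative, not from this file):

import torch
from vllm.platforms import current_platform

# What empty_fp8(5, 4) now yields:
t = torch.empty(5, 4, dtype=current_platform.fp8_dtype(), device="cuda")
print(t.dtype)  # torch.float8_e4m3fn on CUDA, torch.float8_e4m3fnuz on ROCm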