mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 03:55:42 +08:00
[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
parent
2d912fb66f
commit
7b2f28deba
@ -309,6 +309,7 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s compile/test_pass_manager.py
|
- pytest -v -s compile/test_pass_manager.py
|
||||||
- pytest -v -s compile/test_fusion.py
|
- pytest -v -s compile/test_fusion.py
|
||||||
|
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
||||||
- pytest -v -s compile/test_sequence_parallelism.py
|
- pytest -v -s compile/test_sequence_parallelism.py
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||||
|
|||||||
@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
|
|||||||
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
||||||
torch::Tensor& input, // [..., 2 * d]
|
torch::Tensor& input, // [..., 2 * d]
|
||||||
torch::Tensor& scale) {
|
torch::Tensor& scale) {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
|
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
|
||||||
|
out.dtype() == torch::kFloat8_e4m3fnuz);
|
||||||
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
||||||
input.dtype() == torch::kBFloat16);
|
input.dtype() == torch::kBFloat16);
|
||||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||||
|
|||||||
@ -27,8 +27,8 @@ class TestModel(torch.nn.Module):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("num_tokens", [256])
|
@pytest.mark.parametrize("num_tokens", [256])
|
||||||
@pytest.mark.parametrize("hidden_size", [64])
|
@pytest.mark.parametrize("hidden_size", [64])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||||
reason="Only test on CUDA")
|
reason="Only test on CUDA and ROCm")
|
||||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
@ -36,7 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
|||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
config = VllmConfig()
|
config = VllmConfig()
|
||||||
config.compilation_config = CompilationConfig(
|
config.compilation_config = CompilationConfig(
|
||||||
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
|
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
||||||
fusion_pass = ActivationQuantFusionPass(config)
|
fusion_pass = ActivationQuantFusionPass(config)
|
||||||
|
|
||||||
backend = TestBackend(fusion_pass)
|
backend = TestBackend(fusion_pass)
|
||||||
|
|||||||
@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
|||||||
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
|
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
@pytest.mark.skipif(
|
||||||
reason="only test for rocm")
|
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||||
|
reason="only test for rocm fp8")
|
||||||
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
|||||||
@ -5,9 +5,10 @@ import torch
|
|||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
from tests.kernels.utils import opcheck
|
from tests.kernels.utils import opcheck
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
DTYPES = [torch.bfloat16, torch.float16]
|
DTYPES = [torch.bfloat16, torch.float16]
|
||||||
QUANT_DTYPES = [torch.float8_e4m3fn]
|
QUANT_DTYPES = [current_platform.fp8_dtype()]
|
||||||
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
|
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
|
||||||
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
|
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
|
|||||||
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||||
out_shape = (x.shape[0], x.shape[1] // 2)
|
out_shape = (x.shape[0], x.shape[1] // 2)
|
||||||
out = torch.empty(out_shape,
|
out = torch.empty(out_shape,
|
||||||
dtype=torch.torch.float8_e4m3fn,
|
dtype=current_platform.fp8_dtype(),
|
||||||
device=x.device)
|
device=x.device)
|
||||||
torch.ops._C.silu_and_mul_quant(out, x, scale)
|
torch.ops._C.silu_and_mul_quant(out, x, scale)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
@ -41,7 +42,7 @@ def empty_bf16(*args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def empty_fp8(*args, **kwargs):
|
def empty_fp8(*args, **kwargs):
|
||||||
fp8 = torch.float8_e4m3fn
|
fp8 = current_platform.fp8_dtype()
|
||||||
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
|
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user