From 77aec83b8c6e6c15a9b5c333a531c29eff0b61fc Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Sun, 7 Sep 2025 11:12:05 +0800 Subject: [PATCH] [Benchmark] add benchmark for custom activation op (#23908) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zjy0516 Signed-off-by: Jiangyun Zhu Co-authored-by: Luka Govedič --- benchmarks/kernels/benchmark_activation.py | 104 +++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 benchmarks/kernels/benchmark_activation.py diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py new file mode 100644 index 0000000000000..93edbcc9391fc --- /dev/null +++ b/benchmarks/kernels/benchmark_activation.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# benchmark custom activation op performance +import itertools + +import torch + +import vllm.model_executor.layers.activation # noqa F401 +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + +batch_size_range = [1, 16, 32, 64, 128] +seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] +intermediate_size = [3072, 9728, 12288] +configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) + + +def benchmark_activation( + batch_size: int, + seq_len: int, + intermediate_size: int, + provider: str, + func_name: str, + dtype: torch.dtype, +): + device = "cuda" + num_tokens = batch_size * seq_len + dim = intermediate_size + current_platform.seed_everything(42) + torch.set_default_device(device) + + if func_name == "gelu_and_mul": + layer = CustomOp.op_registry[func_name](approximate="none") + elif func_name == "gelu_and_mul_tanh": + layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh") + elif func_name == "fatrelu_and_mul": + threshold = 0.5 + layer = CustomOp.op_registry[func_name](threshold) + else: + layer = CustomOp.op_registry[func_name]() + + x = torch.randn(num_tokens, dim, dtype=dtype, device=device) + compiled_layer = torch.compile(layer.forward_native) + + if provider == "custom": + fn = lambda: layer(x) + elif provider == "compiled": + fn = lambda: compiled_layer(x) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) + return ms, max_ms, min_ms + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the custom activation op.") + parser.add_argument( + "--func-name", + type=str, + choices=[ + "mul_and_silu", + "silu_and_mul", + "gelu_and_mul", + "gelu_and_mul_tanh", + "fatrelu_and_mul", + "swigluoai_and_mul", + "gelu_new", + "gelu_fast", + "quick_gelu", + ], + default="silu_and_mul", + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + args = parser.parse_args() + assert args + + func_name = args.func_name + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + perf_report = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "intermediate_size"], + x_vals=configs, + line_arg="provider", + line_vals=["custom", "compiled"], + line_names=["Custom OP", "Compiled"], + styles=[("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"{func_name}-op-performance", + args={}, + ) + ) + + perf_report( + lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation( + batch_size, seq_len, intermediate_size, provider, func_name, dtype + ) + ).run(print_data=True)