From 8d705996dffbb2299750b7b2b50bbcd5ccb4a5ad Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 2 Aug 2025 01:35:30 +0800 Subject: [PATCH] [Misc] Minor enhancement of benchmark_moe (#22068) Signed-off-by: Jee Jee Li --- benchmarks/kernels/benchmark_moe.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index c350aaf5d3ad2..72250e2fb6d2b 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -22,6 +22,13 @@ from vllm.utils import FlexibleArgumentParser FP8_DTYPE = current_platform.fp8_dtype() +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, ( + "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator) + ) + + class BenchmarkConfig(TypedDict): BLOCK_SIZE_M: int BLOCK_SIZE_N: int @@ -603,7 +610,7 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - + ensure_divisibility(intermediate_size, args.tp_size) hidden_size = config.hidden_size dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8"