From 384a052971607f1561e734c87c9216f77f47e0fb Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 11 Aug 2025 15:13:27 +0800
Subject: [PATCH] [Misc] benchmark_moe supports expert parallel (#22251)

Signed-off-by: Jee Jee Li
---
 benchmarks/kernels/benchmark_moe.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 72250e2fb6d2..13bf1be836f6 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -22,10 +22,10 @@ from vllm.utils import FlexibleArgumentParser
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
-def ensure_divisibility(numerator, denominator):
+def ensure_divisibility(numerator, denominator, text):
     """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, (
-        "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
+    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
+        text, numerator, denominator
     )
 
 
@@ -577,12 +577,10 @@ def main(args: argparse.Namespace):
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] == "JambaForCausalLM":
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in (
         "DeepseekV3ForCausalLM",
         "DeepseekV2ForCausalLM",
@@ -591,17 +589,14 @@
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
         E = config.num_experts
         topk = config.moe_topk[0]
         intermediate_size = config.moe_intermediate_size[0]
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Support for llama4
         config = config.get_text_config()
@@ -609,8 +604,14 @@
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
+    enable_ep = bool(args.enable_expert_parallel)
+    if enable_ep:
+        ensure_divisibility(E, args.tp_size, "Number of experts")
+        E = E // args.tp_size
+        shard_intermediate_size = 2 * intermediate_size
+    else:
+        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    ensure_divisibility(intermediate_size, args.tp_size)
     hidden_size = config.hidden_size
     dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
@@ -742,6 +743,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
     )
+    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
     parser.add_argument(
         "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
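
For illustration, a minimal sketch of the per-rank sizing rule this patch introduces; the helper name and the example sizes below are assumptions, not values from the PR. With --enable-expert-parallel the expert count E is split across the tensor-parallel ranks and the intermediate size stays whole; without it the intermediate size is sharded as before.

    # Illustrative sketch only; mirrors the sizing logic added by the patch.
    def per_rank_moe_shapes(E: int, intermediate_size: int, tp_size: int, enable_ep: bool):
        """Return (experts_per_rank, shard_intermediate_size) seen by each rank."""
        if enable_ep:
            # Expert parallel: experts are divided across ranks, expert weights stay unsharded.
            assert E % tp_size == 0, "Number of experts {} is not divisible by tp {}.".format(E, tp_size)
            return E // tp_size, 2 * intermediate_size
        # Tensor parallel only: every rank holds all experts, each expert weight is sharded.
        assert intermediate_size % tp_size == 0, "intermediate_size {} is not divisible by tp {}.".format(
            intermediate_size, tp_size
        )
        return E, 2 * intermediate_size // tp_size

    # Hypothetical sizes: 128 experts, intermediate_size 2048, tp_size 8.
    print(per_rank_moe_shapes(128, 2048, 8, enable_ep=True))   # -> (16, 4096)
    print(per_rank_moe_shapes(128, 2048, 8, enable_ep=False))  # -> (128, 512)

A typical invocation with the new flag, assuming the script's existing --model option (the model name here is a placeholder), might look like:

    python benchmarks/kernels/benchmark_moe.py --model <model-name> -tp 8 --enable-expert-parallel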