diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 02c2db674d4b0..d3040e9738f7b 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -579,10 +579,12 @@ def main(args: argparse.Namespace): E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size + hidden_size = config.hidden_size elif config.architectures[0] == "JambaForCausalLM": E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + hidden_size = config.hidden_size elif config.architectures[0] in ( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", @@ -592,6 +594,7 @@ def main(args: argparse.Namespace): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size elif config.architectures[0] in ( "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM", @@ -600,10 +603,18 @@ def main(args: argparse.Namespace): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size + elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration": + text_config = config.get_text_config() + E = text_config.num_experts + topk = text_config.num_experts_per_tok + intermediate_size = text_config.moe_intermediate_size + hidden_size = text_config.hidden_size elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): E = config.num_experts topk = config.moe_topk[0] intermediate_size = config.moe_intermediate_size[0] + hidden_size = config.hidden_size else: # Support for llama4 config = config.get_text_config() @@ -611,6 +622,7 @@ def main(args: argparse.Namespace): E = config.num_local_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + hidden_size = config.hidden_size enable_ep = bool(args.enable_expert_parallel) if enable_ep: ensure_divisibility(E, args.tp_size, "Number of experts") @@ -619,7 +631,6 @@ def main(args: argparse.Namespace): else: ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") shard_intermediate_size = 2 * intermediate_size // args.tp_size - hidden_size = config.hidden_size dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16"