From a73122de96431a7d5f86b1dfd2c028834de2722e Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 13 Mar 2025 16:12:42 +0800
Subject: [PATCH] [Bugfix] fix benchmark moe (#14653)

Signed-off-by: Jee Jee Li
---
 benchmarks/kernels/benchmark_moe.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 233fc35d2cf5..491f8c3962f7 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -365,6 +365,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        block_quant_shape: List[int] = None,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
         dtype_str = get_config_dtype_str(dtype,
@@ -385,10 +386,17 @@ class BenchmarkWorker:
         else:
             config = op_config[min(op_config.keys(),
                                    key=lambda x: abs(x - num_tokens))]
-        kernel_time = benchmark_config(config, num_tokens, num_experts,
-                                       shard_intermediate_size, hidden_size,
-                                       topk, dtype, use_fp8_w8a8,
-                                       use_int8_w8a16)
+        kernel_time = benchmark_config(config,
+                                       num_tokens,
+                                       num_experts,
+                                       shard_intermediate_size,
+                                       hidden_size,
+                                       topk,
+                                       dtype,
+                                       use_fp8_w8a8,
+                                       use_int8_w8a16,
+                                       num_iters=100,
+                                       block_quant_shape=block_quant_shape)
         return config, kernel_time
 
     def tune(
@@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
         f.write("\n")
 
 
+def get_weight_block_size_safety(config, default_value=None):
+
+    quantization_config = getattr(config, 'quantization_config', {})
+    if isinstance(quantization_config, dict):
+        return quantization_config.get('weight_block_size', default_value)
+    return default_value
+
+
 def main(args: argparse.Namespace):
     print(args)
     block_quant_shape = None
@@ -508,7 +524,7 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-        block_quant_shape = config.quantization_config['weight_block_size']
+        block_quant_shape = get_weight_block_size_safety(config)
     elif config.architectures[0] == "Qwen2MoeForCausalLM":
         E = config.num_experts
         topk = config.num_experts_per_tok