[Bugfix] fix benchmark moe (#14653)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-03-13 16:12:42 +08:00 committed by GitHub
parent bd44b812cb
commit a73122de96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -365,6 +365,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
block_quant_shape: List[int] = None,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype, dtype_str = get_config_dtype_str(dtype,
@ -385,10 +386,17 @@ class BenchmarkWorker:
else: else:
config = op_config[min(op_config.keys(), config = op_config[min(op_config.keys(),
key=lambda x: abs(x - num_tokens))] key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts, kernel_time = benchmark_config(config,
shard_intermediate_size, hidden_size, num_tokens,
topk, dtype, use_fp8_w8a8, num_experts,
use_int8_w8a16) shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
block_quant_shape=block_quant_shape)
return config, kernel_time return config, kernel_time
def tune( def tune(
@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
f.write("\n") f.write("\n")
def get_weight_block_size_safety(config, default_value=None):
quantization_config = getattr(config, 'quantization_config', {})
if isinstance(quantization_config, dict):
return quantization_config.get('weight_block_size', default_value)
return default_value
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
block_quant_shape = None block_quant_shape = None
@ -508,7 +524,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
block_quant_shape = config.quantization_config['weight_block_size'] block_quant_shape = get_weight_block_size_safety(config)
elif config.architectures[0] == "Qwen2MoeForCausalLM": elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok