diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index c350aaf5d3ad2..72250e2fb6d2b 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -22,6 +22,20 @@ from vllm.utils import FlexibleArgumentParser
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator.
+
+    Raises:
+        ValueError: If ``numerator % denominator`` is non-zero. An explicit
+            exception is raised instead of ``assert`` so that the check
+            still runs under ``python -O``.
+    """
+    if numerator % denominator != 0:
+        raise ValueError(
+            f"intermediate_size {numerator} is not divisible by tp {denominator}."
+        )
+
+
 class BenchmarkConfig(TypedDict):
     BLOCK_SIZE_M: int
     BLOCK_SIZE_N: int
@@ -603,7 +617,7 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-
+    ensure_divisibility(intermediate_size, args.tp_size)
     hidden_size = config.hidden_size
     dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"