diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 6cb55b35993ef..cef53b183cef3 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -7,7 +7,6 @@ import time from contextlib import nullcontext from datetime import datetime from itertools import product -from types import SimpleNamespace from typing import Any, TypedDict import ray @@ -43,7 +42,7 @@ def benchmark_config( use_fp8_w8a8: bool, use_int8_w8a16: bool, num_iters: int = 100, - block_quant_shape: List[int] = None, + block_quant_shape: list[int] = None, use_deep_gemm: bool = False, ) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype @@ -400,7 +399,7 @@ class BenchmarkWorker: dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - block_quant_shape: List[int] = None, + block_quant_shape: list[int] = None, use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) @@ -532,7 +531,7 @@ def save_configs( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - block_quant_shape: List[int], + block_quant_shape: list[int], ) -> None: dtype_str = get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 @@ -563,7 +562,6 @@ def main(args: argparse.Namespace): config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) if args.model_prefix: config = getattr(config, args.model_prefix) - config = SimpleNamespace(**config) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts @@ -595,11 +593,7 @@ def main(args: argparse.Namespace): shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = ( - torch.float16 - if current_platform.is_rocm() - else getattr(torch, config.torch_dtype) - ) + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" block_quant_shape = get_weight_block_size_safety(config)