[Bugfix] Fix benchmark_moe.py (#19016)

Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn>
Tianyu Guo 2025-06-10 09:04:36 +08:00 committed by GitHub
parent cc867be19c
commit 4589b94032

benchmark_moe.py

@@ -7,7 +7,6 @@ import time
 from contextlib import nullcontext
 from datetime import datetime
 from itertools import product
-from types import SimpleNamespace
 from typing import Any, TypedDict
 
 import ray
@ -43,7 +42,7 @@ def benchmark_config(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
block_quant_shape: List[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> float: ) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
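
Note on the List[int] -> list[int] change (this hunk and the two below): the file imports only Any and TypedDict from typing, so the name List would not resolve when the annotation is evaluated (at function-definition time, unless a future-annotations import is in effect). The builtin generic list[int] from PEP 585 (Python 3.9+) needs no import. A minimal sketch with a hypothetical function f, not the real benchmark_config signature:

# Sketch only; f stands in for the real function.
from typing import Any, TypedDict  # List is deliberately not imported, mirroring this file

# def f(block_quant_shape: List[int] = None) -> float: ...
#     -> NameError: name 'List' is not defined

def f(block_quant_shape: list[int] | None = None) -> float:
    # PEP 585 builtin generic; no typing import required
    return 0.0
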
@ -400,7 +399,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
block_quant_shape: List[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
@ -532,7 +531,7 @@ def save_configs(
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
block_quant_shape: List[int], block_quant_shape: list[int],
) -> None: ) -> None:
dtype_str = get_config_dtype_str( dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
if args.model_prefix: if args.model_prefix:
config = getattr(config, args.model_prefix) config = getattr(config, args.model_prefix)
config = SimpleNamespace(**config)
if config.architectures[0] == "DbrxForCausalLM": if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
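
Note on dropping the SimpleNamespace wrapping: the object returned by get_config already supports attribute access (in vLLM it is a transformers PretrainedConfig), so config.architectures, config.ffn_config and the rest work directly, and SimpleNamespace(**config) additionally assumes the object is a mapping. A hedged sketch of the failure mode, using a hypothetical stand-in class rather than the real config type:

# Sketch only; FakeConfig is a stand-in for the object get_config returns.
from types import SimpleNamespace

class FakeConfig:
    architectures = ["DbrxForCausalLM"]

config = FakeConfig()
try:
    config = SimpleNamespace(**config)  # the removed line
except TypeError as exc:
    print(exc)  # argument after ** must be a mapping, not FakeConfig
print(config.architectures[0])  # attribute access already works without the wrapper
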
@ -595,11 +593,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = ( dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
torch.float16
if current_platform.is_rocm()
else getattr(torch, config.torch_dtype)
)
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
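
Note on the dtype line: on the loaded config, torch_dtype is typically already a torch.dtype rather than a string (transformers converts the config.json value when the config is loaded), so getattr(torch, config.torch_dtype) raises TypeError because the attribute name is not a string; the fix uses the value directly. A minimal sketch of that failure mode, assuming a torch.dtype value:

# Sketch only; torch_dtype stands in for config.torch_dtype.
import torch

torch_dtype = torch.bfloat16
try:
    getattr(torch, torch_dtype)  # what the removed multi-line expression did
except TypeError as exc:
    print(exc)  # attribute name must be string
dtype = torch_dtype  # fixed form: use the dtype value as-is (the ROCm branch still picks torch.float16)
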