[Bugfix] Fix benchmark_moe.py (#19016)

Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn>

parent cc867be19c
commit 4589b94032
@@ -7,7 +7,6 @@ import time
 from contextlib import nullcontext
 from datetime import datetime
 from itertools import product
-from types import SimpleNamespace
 from typing import Any, TypedDict
 
 import ray
@@ -43,7 +42,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
-    block_quant_shape: List[int] = None,
+    block_quant_shape: list[int] = None,
     use_deep_gemm: bool = False,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
@@ -400,7 +399,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
-        block_quant_shape: List[int] = None,
+        block_quant_shape: list[int] = None,
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
@@ -532,7 +531,7 @@ def save_configs(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
-    block_quant_shape: List[int],
+    block_quant_shape: list[int],
 ) -> None:
     dtype_str = get_config_dtype_str(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
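The three List[int] -> list[int] changes matter because, per the import hunk above, the file pulls in only Any and TypedDict from typing: absent a "from __future__ import annotations", signature annotations are evaluated when the function is defined, so a bare List raises NameError the moment the module is imported, while the PEP 585 builtin list[int] (Python 3.9+) needs no import at all. A minimal sketch, with a hypothetical function standing in for the real signatures:

from typing import Any, TypedDict  # mirrors the imports benchmark_moe.py keeps


def configure(block_quant_shape: list[int] = None) -> None:
    # Builtin generics (PEP 585) resolve without importing typing.List.
    # Annotating this parameter as List[int] instead would raise
    # NameError at definition time, since List is never imported here.
    pass


configure([128, 128])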
@@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
     config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
     if args.model_prefix:
         config = getattr(config, args.model_prefix)
-    config = SimpleNamespace(**config)
 
     if config.architectures[0] == "DbrxForCausalLM":
         E = config.ffn_config.moe_num_experts
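Dropping config = SimpleNamespace(**config) fits the rest of the hunk: the object get_config returns already supports attribute access (config.architectures and config.hidden_size below), and **-unpacking anything that is not a mapping raises TypeError, so the wrapper was both unnecessary and broken. A short sketch using a made-up stand-in class for the config object:

from types import SimpleNamespace


class FakeConfig:
    # Hypothetical stand-in for the config object get_config returns.
    architectures = ["DbrxForCausalLM"]
    hidden_size = 6144


config = FakeConfig()
print(config.architectures[0])  # attribute access works without any wrapper

try:
    SimpleNamespace(**config)  # ** requires a mapping, which this is not
except TypeError as exc:
    print(exc)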
@@ -595,11 +593,7 @@
     shard_intermediate_size = 2 * intermediate_size // args.tp_size
 
     hidden_size = config.hidden_size
-    dtype = (
-        torch.float16
-        if current_platform.is_rocm()
-        else getattr(torch, config.torch_dtype)
-    )
+    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
     block_quant_shape = get_weight_block_size_safety(config)
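The collapsed dtype line follows the same logic: on a Hugging Face PretrainedConfig, torch_dtype is typically already a torch.dtype object rather than a string, so the old getattr(torch, config.torch_dtype) would fail. A sketch under that assumption, with is_rocm standing in for current_platform.is_rocm():

import torch

torch_dtype = torch.bfloat16  # assumption: config.torch_dtype is a dtype object

# The old form expected a string attribute name; passing a torch.dtype
# raises "TypeError: attribute name must be string".
try:
    getattr(torch, torch_dtype)
except TypeError as exc:
    print(exc)

# The new form uses the value directly, keeping the float16 fallback on ROCm.
is_rocm = False  # stand-in for current_platform.is_rocm()
dtype = torch.float16 if is_rocm else torch_dtype
print(dtype)  # torch.bfloat16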