mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:15:20 +08:00
[Bugfix] fix benchmark moe (#14653)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
bd44b812cb
commit
a73122de96
@ -365,6 +365,7 @@ class BenchmarkWorker:
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
|
block_quant_shape: List[int] = None,
|
||||||
) -> tuple[dict[str, int], float]:
|
) -> tuple[dict[str, int], float]:
|
||||||
current_platform.seed_everything(self.seed)
|
current_platform.seed_everything(self.seed)
|
||||||
dtype_str = get_config_dtype_str(dtype,
|
dtype_str = get_config_dtype_str(dtype,
|
||||||
@ -385,10 +386,17 @@ class BenchmarkWorker:
|
|||||||
else:
|
else:
|
||||||
config = op_config[min(op_config.keys(),
|
config = op_config[min(op_config.keys(),
|
||||||
key=lambda x: abs(x - num_tokens))]
|
key=lambda x: abs(x - num_tokens))]
|
||||||
kernel_time = benchmark_config(config, num_tokens, num_experts,
|
kernel_time = benchmark_config(config,
|
||||||
shard_intermediate_size, hidden_size,
|
num_tokens,
|
||||||
topk, dtype, use_fp8_w8a8,
|
num_experts,
|
||||||
use_int8_w8a16)
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
num_iters=100,
|
||||||
|
block_quant_shape=block_quant_shape)
|
||||||
return config, kernel_time
|
return config, kernel_time
|
||||||
|
|
||||||
def tune(
|
def tune(
|
||||||
@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
|||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_block_size_safety(config, default_value=None):
    """Return the ``weight_block_size`` quantization setting, if present.

    Looks up ``config.quantization_config['weight_block_size']`` without
    raising when the model config carries no quantization settings.

    Args:
        config: Object that may expose a ``quantization_config`` attribute.
        default_value: Value returned when the setting cannot be found.

    Returns:
        The value stored under ``'weight_block_size'`` when
        ``quantization_config`` is a dict containing that key; otherwise
        ``default_value``.
    """
    # A missing attribute falls back to an empty dict, so the lookup below
    # yields ``default_value`` instead of raising AttributeError.
    quantization_config = getattr(config, 'quantization_config', {})

    # Only dict-shaped configs are inspected; anything else (e.g. ``None``)
    # is treated as "no block-quantization information available".
    if isinstance(quantization_config, dict):
        return quantization_config.get('weight_block_size', default_value)
    return default_value
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
print(args)
|
print(args)
|
||||||
block_quant_shape = None
|
block_quant_shape = None
|
||||||
@ -508,7 +524,7 @@ def main(args: argparse.Namespace):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
block_quant_shape = config.quantization_config['weight_block_size']
|
block_quant_shape = get_weight_block_size_safety(config)
|
||||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||||
E = config.num_experts
|
E = config.num_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user