diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py index 8e976fbcb302..5280b214144c 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -1,3 +1,4 @@ +import argparse import json import os import sys @@ -5,6 +6,7 @@ import sys import torch import torch.nn.functional as F import triton +from tqdm import tqdm from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name) @@ -12,16 +14,16 @@ from vllm.model_executor.layers.fused_moe import (fused_moe, os.environ['CUDA_VISIBLE_DEVICES'] = '0' -def main(): +def main(dtype: str): method = fused_moe for bs in [ 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096 ]: - run_grid(bs, method=method) + run_grid(bs, method=method, dtype=dtype) -def run_grid(bs, method): +def run_grid(bs, method, dtype: str): d_model = 4096 num_total_experts = 8 top_k = 2 @@ -34,39 +36,29 @@ def run_grid(bs, method): num_trials = 1 configs = [] - if bs <= 16: - BLOCK_SIZES_M = [16] - elif bs <= 32: - BLOCK_SIZES_M = [16, 32] - elif bs <= 64: - BLOCK_SIZES_M = [16, 32, 64] - elif bs <= 128: - BLOCK_SIZES_M = [16, 32, 64, 128] - else: - BLOCK_SIZES_M = [16, 32, 64, 128, 256] for block_size_n in [32, 64, 128, 256]: - for block_size_m in BLOCK_SIZES_M: + for block_size_m in [16, 32, 64, 128, 256]: for block_size_k in [64, 128, 256]: for group_size_m in [1, 16, 32, 64]: for num_warps in [4, 8]: - configs.append({ - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_warps": num_warps, - "num_stages": 4, - }) + for num_stages in [2, 3, 4, 5]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) best_config = None best_time_us = 1e20 - for config in configs: - 
print(f'{tp_size=} {bs=}') - print(f'{config}') + print(f'{tp_size=} {bs=}') + + for config in tqdm(configs): # warmup - print('warming up') try: for _ in range(num_warmup_trials): run_timing( @@ -79,12 +71,12 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) except triton.runtime.autotuner.OutOfResources: continue # trial - print('benchmarking') for _ in range(num_trials): kernel_dur_ms = run_timing( num_calls=num_calls, @@ -96,6 +88,7 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) kernel_dur_us = 1000 * kernel_dur_ms @@ -105,16 +98,18 @@ def run_grid(bs, method): best_config = config best_time_us = kernel_dur_us - print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - f'{d_model=} {model_intermediate_size=} {num_layers=}') + tqdm.write( + f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + f'{d_model=} {model_intermediate_size=} {num_layers=}') print("best_time_us", best_time_us) print("best_config", best_config) # holds Dict[str, Dict[str, int]] filename = get_config_file_name(num_total_experts, - model_intermediate_size // tp_size) + model_intermediate_size // tp_size, + "float8" if dtype == "float8" else None) print(f"writing config to file {filename}") existing_content = {} if os.path.exists(filename): @@ -128,27 +123,48 @@ def run_grid(bs, method): def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, top_k: int, tp_size: int, model_intermediate_size: int, method, - config) -> float: + config, dtype: str) -> float: shard_intermediate_size = model_intermediate_size // tp_size hidden_states = torch.rand( (bs, d_model), device="cuda:0", - dtype=torch.bfloat16, + dtype=torch.float16, ) - ws = torch.rand( + w1 = torch.rand( (num_total_experts, 2 * shard_intermediate_size, d_model), 
device=hidden_states.device, dtype=hidden_states.dtype, ) - w2s = torch.rand( + w2 = torch.rand( (num_total_experts, d_model, shard_intermediate_size), device=hidden_states.device, dtype=hidden_states.dtype, ) + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + w2_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + a1_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + a2_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + gating_output = F.softmax(torch.rand( (num_calls, bs, num_total_experts), device=hidden_states.device, @@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, for i in range(num_calls): hidden_states = method( hidden_states=hidden_states, - w1=ws, - w2=w2s, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, gating_output=gating_output[i], topk=2, renormalize=True, inplace=True, override_config=config, + use_fp8=dtype == "float8", ) end_event.record() end_event.synchronize() @@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, if __name__ == "__main__": - sys.exit(main()) + parser = argparse.ArgumentParser( + prog='benchmark_mixtral_moe', + description='Benchmark and tune the fused_moe kernel', + ) + parser.add_argument( + '--dtype', + type=str, + default='float16', + choices=['float8', 'float16'], + help='Data type used for fused_moe kernel computations', + ) + args = parser.parse_args() + sys.exit(main(args.dtype)) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json 
b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 000000000000..9287808a94d0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,140 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +}