diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py
index 5280b214144c..196ec8cfce88 100644
--- a/benchmarks/kernels/benchmark_mixtral_moe.py
+++ b/benchmarks/kernels/benchmark_mixtral_moe.py
@@ -11,25 +11,36 @@ from tqdm import tqdm
 from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                   get_config_file_name)
 
-os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-
 
-def main(dtype: str):
+def main(model, tp_size, gpu, dtype: str):
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
     method = fused_moe
     for bs in [
             1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
             2048, 3072, 4096
     ]:
-        run_grid(bs, method=method, dtype=dtype)
+        run_grid(bs,
+                 model=model,
+                 method=method,
+                 gpu=gpu,
+                 tp_size=tp_size,
+                 dtype=dtype)
 
 
-def run_grid(bs, method, dtype: str):
-    d_model = 4096
+def run_grid(bs, model, method, gpu, tp_size, dtype: str):
+    if model == '8x7B':
+        d_model = 4096
+        model_intermediate_size = 14336
+        num_layers = 32
+    elif model == '8x22B':
+        d_model = 6144
+        model_intermediate_size = 16384
+        num_layers = 56
+    else:
+        raise ValueError(f'Unsupported Mixtral model {model}')
     num_total_experts = 8
     top_k = 2
-    tp_size = 2
-    model_intermediate_size = 14336
-    num_layers = 32
+    # tp_size = 2
     num_calls = 100
 
     num_warmup_trials = 1
@@ -211,5 +222,18 @@ if __name__ == "__main__":
         choices=['float8', 'float16'],
         help='Data type used for fused_moe kernel computations',
     )
+    parser.add_argument('--model',
+                        type=str,
+                        default='8x7B',
+                        choices=['8x7B', '8x22B'],
+                        help='The Mixtral model to benchmark')
+    parser.add_argument('--tp-size',
+                        type=int,
+                        default=2,
+                        help='Tensor parallel size')
+    parser.add_argument('--gpu',
+                        type=int,
+                        default=0,
+                        help="GPU ID for benchmarking")
     args = parser.parse_args()
-    sys.exit(main(args.dtype))
+    sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
new file mode 100644
index 000000000000..3f3ccdafa88f
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -0,0 +1,138 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
new file mode 100644
index 000000000000..0c495e7e290c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
index 9287808a94d0..5b78c30f08b6 100644
--- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -3,61 +3,59 @@
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 64,
-        "GROUP_SIZE_M": 1
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
     },
     "2": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 64,
-        "GROUP_SIZE_M": 1
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
     },
     "4": {
-        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 32,
-        "BLOCK_SIZE_K": 64,
-        "GROUP_SIZE_M": 1
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
     },
     "8": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 1,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
         "num_warps": 8,
-        "num_stages": 5
+        "num_stages": 2
     },
     "16": {
-        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_M": 256,
         "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
-        "GROUP_SIZE_M": 16,
-        "num_warps": 4,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
         "num_stages": 5
     },
     "24": {
-        "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
-        "num_warps": 8,
-        "num_stages": 5
-    },
-    "32": {
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
-        "num_warps": 8,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
         "num_stages": 4
     },
     "48": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 256,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 64,
-        "num_warps": 8,
-        "num_stages": 3
-    },
-    "64": {
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 64,
@@ -65,37 +63,45 @@
         "num_warps": 4,
         "num_stages": 4
     },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
     "96": {
         "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 64,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 32,
         "num_warps": 8,
-        "num_stages": 2
-    },
-    "128": {
-        "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 32,
-        "num_warps": 4,
-        "num_stages": 3
-    },
-    "256": {
-        "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 32,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
-        "num_warps": 4,
         "num_stages": 5
     },
-    "512": {
+    "256": {
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 64,
-        "num_warps": 4,
-        "num_stages": 2
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
     },
     "1024": {
         "BLOCK_SIZE_M": 128,
@@ -109,7 +115,7 @@
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 32,
+        "GROUP_SIZE_M": 64,
         "num_warps": 8,
         "num_stages": 4
     },
@@ -125,7 +131,7 @@
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 32,
+        "GROUP_SIZE_M": 16,
         "num_warps": 8,
         "num_stages": 4
     },
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
new file mode 100644
index 000000000000..60a65724d68b
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
index 2ad07bf79a25..75f8b0017b9c 100644
--- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -2,104 +2,104 @@
     "1": {
         "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 32,
-        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_K": 64,
         "GROUP_SIZE_M": 1,
         "num_warps": 4,
-        "num_stages": 4
+        "num_stages": 5
     },
     "2": {
-        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 64,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
         "num_stages": 4
     },
     "4": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 64,
-        "GROUP_SIZE_M": 64,
-        "num_warps": 4,
-        "num_stages": 4
-    },
-    "8": {
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 64,
-        "num_warps": 8,
-        "num_stages": 4
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
     },
     "16": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 256,
-        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 8,
+        "num_warps": 4,
         "num_stages": 4
     },
     "24": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 8,
-        "num_stages": 4
+        "num_warps": 4,
+        "num_stages": 5
     },
     "32": {
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
-        "num_warps": 8,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
         "num_stages": 4
     },
     "48": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 128,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
         "num_warps": 4,
-        "num_stages": 4
+        "num_stages": 3
     },
     "64": {
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
-        "num_warps": 8,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
         "num_stages": 4
     },
     "96": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 32,
+        "GROUP_SIZE_M": 1,
         "num_warps": 4,
         "num_stages": 4
     },
     "128": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 128,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
-        "num_warps": 8,
-        "num_stages": 4
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
     },
     "256": {
         "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
         "num_warps": 8,
-        "num_stages": 4
+        "num_stages": 5
     },
     "512": {
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
+        "GROUP_SIZE_M": 64,
         "num_warps": 8,
         "num_stages": 4
     },
@@ -115,7 +115,7 @@
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 32,
+        "GROUP_SIZE_M": 64,
         "num_warps": 8,
         "num_stages": 4
     },
@@ -139,7 +139,7 @@
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 64,
+        "GROUP_SIZE_M": 16,
         "num_warps": 8,
         "num_stages": 4
     }
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
new file mode 100644
index 000000000000..34b916e574f8
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index d6dd7fa1fe9e..2f4237339486 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -278,15 +278,6 @@ class MixtralAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
 
-        if isinstance(
-                quant_config,
-                Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
-            print_warning_once(
-                "For Mixtral FP8 quantization, we currently do not quantize "
-                "the attention layers until their FP8 performance is improved."
-            )
-            quant_config = None
-
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,