[ROCm][MoE] moe tuning support for rocm (#12049)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
Divakar Verma 2025-01-17 00:49:16 -06:00 committed by GitHub
parent d75ab55f10
commit 8027a72461
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
import argparse import argparse
import time import time
from datetime import datetime from datetime import datetime
from itertools import product
from typing import Any, Dict, List, Tuple, TypedDict from typing import Any, Dict, List, Tuple, TypedDict
import ray import ray
@ -11,7 +12,10 @@ from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser, is_navi
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
) and not is_navi() else torch.float8_e4m3fn
class BenchmarkConfig(TypedDict): class BenchmarkConfig(TypedDict):
@ -80,8 +84,8 @@ def benchmark_config(
a1_scale = torch.randn(1, dtype=torch.float32) a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32)
w1 = w1.to(torch.float8_e4m3fn) w1 = w1.to(FP8_DTYPE)
w2 = w2.to(torch.float8_e4m3fn) w2 = w2.to(FP8_DTYPE)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
@ -141,28 +145,172 @@ def benchmark_config(
return avg return avg
def get_configs_compute_bound() -> List[Dict[str, int]]: def get_rocm_tuning_space(use_fp16):
# Reduced search space for faster tuning. block_mn_range = [16, 32, 64, 128, 256]
# TODO(woosuk): Increase the search space and use a performance model to block_k_range = [16, 32, 64, 128, 256]
# prune the search space. if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else []
param_ranges = {
"BLOCK_SIZE_M": block_mn_range,
"BLOCK_SIZE_N": block_mn_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range,
}
if use_fp16:
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
param_ranges["kpack"] = kpack_range
return param_ranges
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = [] configs: List[BenchmarkConfig] = []
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]: if current_platform.is_rocm():
for block_k in [64, 128, 256]: param_ranges = get_rocm_tuning_space(use_fp16)
for block_n in [32, 64, 128, 256]: else:
for num_warps in [4, 8]: # Reduced search space for faster tuning.
for group_size in [1, 16, 32, 64]: # TODO(woosuk): Increase the search space and use a performance model to
configs.append({ # prune the search space.
"BLOCK_SIZE_M": block_m, block_m_range = [16, 32, 64, 128, 256]
"BLOCK_SIZE_N": block_n, block_n_range = [32, 64, 128, 256]
"BLOCK_SIZE_K": block_k, block_k_range = [64, 128, 256]
"GROUP_SIZE_M": group_size, num_warps_range = [4, 8]
"num_warps": num_warps, group_m_range = [1, 16, 32, 64]
"num_stages": num_stages, num_stage_range = [2, 3, 4, 5]
})
param_ranges = {
"BLOCK_SIZE_M": block_m_range,
"BLOCK_SIZE_N": block_n_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
}
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
config = dict(zip(keys, config_values))
configs.append(config)
return configs return configs
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
search_space, is_fp16):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
is_fp16)
pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
is_fp16)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
pruned_configs = []
elemBytes_a = 2 if is_fp16 else 1
elemBytes_b = 2 if is_fp16 else 1
mfma = 16 if M < 32 or N < 32 else 32
# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True
for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
if is_fp16:
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16:
if (matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N):
continue
if (matrix_instr_nonkdim >= M
and matrix_instr_nonkdim != BLOCK_SIZE_M):
continue
if (matrix_instr_nonkdim >= N
and matrix_instr_nonkdim != BLOCK_SIZE_N):
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue
pruned_configs.append(config)
return pruned_configs
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
def merge_unique_dicts(list1, list2):
result = []
combined_list = list1.copy()
combined_list.extend(list2)
for dictionary in combined_list:
if dictionary not in result:
result.append(dictionary)
return result
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class BenchmarkWorker: class BenchmarkWorker:
@ -170,6 +318,10 @@ class BenchmarkWorker:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
self.seed = seed self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark( def benchmark(
self, self,
@ -217,25 +369,33 @@ class BenchmarkWorker:
) -> Dict[str, int]: ) -> Dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
for config in tqdm(search_space): if current_platform.is_rocm():
try: is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
kernel_time = benchmark_config(config, search_space = prune_rocm_search_space(num_tokens,
num_tokens, shard_intermediate_size,
num_experts, hidden_size, search_space,
shard_intermediate_size, is_fp16)
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=10)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time: with torch.cuda.device(self.device_id):
best_time = kernel_time for config in tqdm(search_space):
best_config = config try:
kernel_time = benchmark_config(config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now() now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None assert best_config is not None
@ -244,12 +404,27 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return { return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"], "BLOCK_SIZE_M":
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"], config["BLOCK_SIZE_M"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"], "BLOCK_SIZE_N":
"GROUP_SIZE_M": config["GROUP_SIZE_M"], config["BLOCK_SIZE_N"],
"num_warps": config["num_warps"], "BLOCK_SIZE_K":
"num_stages": config["num_stages"], config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
**({
"waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
} }
@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = config.torch_dtype dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
@ -322,7 +497,8 @@ def main(args: argparse.Namespace):
return ray.get(outputs) return ray.get(outputs)
if args.tune: if args.tune:
search_space = get_configs_compute_bound() is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16)
print(f"Start tuning over {len(search_space)} configurations...") print(f"Start tuning over {len(search_space)} configurations...")
start = time.time() start = time.time()