[ROCm][MoE] moe tuning support for rocm (#12049)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
parent d75ab55f10
commit 8027a72461
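For context, a minimal sketch of how the new helpers in this diff compose on a ROCm host. It assumes the benchmark script can be imported as `benchmark_moe` (an assumption, not something the commit provides), and the layer shapes are illustrative Mixtral-8x7B-like values at TP=8:

# Minimal sketch (not part of the commit): generate the tuning grid, then
# prune it for one MoE layer shape, as BenchmarkWorker.tune now does on ROCm.
# Assumes a ROCm GPU is visible and `benchmark_moe` is importable.
from benchmark_moe import get_configs_compute_bound, prune_rocm_search_space

num_tokens = 1024                          # tokens routed through the MoE layer
hidden_size = 4096                         # K of the first grouped GEMM
shard_intermediate_size = 2 * 14336 // 8   # fused gate+up width per TP shard

is_fp16 = True                             # i.e. not fp8_w8a8 / int8_w8a16
search_space = get_configs_compute_bound(is_fp16)   # ROCm grid on ROCm hosts
pruned = prune_rocm_search_space(num_tokens, shard_intermediate_size,
                                 hidden_size, search_space, is_fp16)
print(f"{len(search_space)} candidate configs -> {len(pruned)} after pruning")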
@@ -1,6 +1,7 @@
 import argparse
 import time
 from datetime import datetime
+from itertools import product
 from typing import Any, Dict, List, Tuple, TypedDict

 import ray
@@ -11,7 +12,10 @@ from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, is_navi
+
+FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
+) and not is_navi() else torch.float8_e4m3fn


 class BenchmarkConfig(TypedDict):
@@ -80,8 +84,8 @@ def benchmark_config(
         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)

-        w1 = w1.to(torch.float8_e4m3fn)
-        w2 = w2.to(torch.float8_e4m3fn)
+        w1 = w1.to(FP8_DTYPE)
+        w2 = w2.to(FP8_DTYPE)

     input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

@@ -141,28 +145,172 @@ def benchmark_config(
     return avg


-def get_configs_compute_bound() -> List[Dict[str, int]]:
-    # Reduced search space for faster tuning.
-    # TODO(woosuk): Increase the search space and use a performance model to
-    # prune the search space.
+def get_rocm_tuning_space(use_fp16):
+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
+    if not use_fp16:
+        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [1, 4, 8, 16, 32]
+    num_stage_range = [2]
+    waves_per_eu_range = [0]
+    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
+    kpack_range = [1, 2] if use_fp16 else []
+
+    param_ranges = {
+        "BLOCK_SIZE_M": block_mn_range,
+        "BLOCK_SIZE_N": block_mn_range,
+        "BLOCK_SIZE_K": block_k_range,
+        "GROUP_SIZE_M": group_m_range,
+        "num_warps": num_warps_range,
+        "num_stages": num_stage_range,
+        "waves_per_eu": waves_per_eu_range,
+    }
+    if use_fp16:
+        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
+        param_ranges["kpack"] = kpack_range
+
+    return param_ranges
+
+
+def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
     configs: List[BenchmarkConfig] = []
-    for num_stages in [2, 3, 4, 5]:
-        for block_m in [16, 32, 64, 128, 256]:
-            for block_k in [64, 128, 256]:
-                for block_n in [32, 64, 128, 256]:
-                    for num_warps in [4, 8]:
-                        for group_size in [1, 16, 32, 64]:
-                            configs.append({
-                                "BLOCK_SIZE_M": block_m,
-                                "BLOCK_SIZE_N": block_n,
-                                "BLOCK_SIZE_K": block_k,
-                                "GROUP_SIZE_M": group_size,
-                                "num_warps": num_warps,
-                                "num_stages": num_stages,
-                            })
+
+    if current_platform.is_rocm():
+        param_ranges = get_rocm_tuning_space(use_fp16)
+    else:
+        # Reduced search space for faster tuning.
+        # TODO(woosuk): Increase the search space and use a performance model to
+        # prune the search space.
+        block_m_range = [16, 32, 64, 128, 256]
+        block_n_range = [32, 64, 128, 256]
+        block_k_range = [64, 128, 256]
+        num_warps_range = [4, 8]
+        group_m_range = [1, 16, 32, 64]
+        num_stage_range = [2, 3, 4, 5]
+
+        param_ranges = {
+            "BLOCK_SIZE_M": block_m_range,
+            "BLOCK_SIZE_N": block_n_range,
+            "BLOCK_SIZE_K": block_k_range,
+            "GROUP_SIZE_M": group_m_range,
+            "num_warps": num_warps_range,
+            "num_stages": num_stage_range,
+        }
+
+    keys, values = zip(*param_ranges.items())
+    for config_values in product(*values):
+        config = dict(zip(keys, config_values))
+        configs.append(config)
     return configs


+def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
+                            search_space, is_fp16):
+    N1, K1 = shard_intermediate_size, hidden_size
+    N2, K2 = hidden_size, shard_intermediate_size // 2
+    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
+                                        is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
+                                        is_fp16)
+    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
+    return search_space
+
+
+# The following code is inspired by ROCm/Triton GEMM tuning script:
+# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
+def prune_rocm_configs(M, N, K, configs, is_fp16=True):
+    pruned_configs = []
+    elemBytes_a = 2 if is_fp16 else 1
+    elemBytes_b = 2 if is_fp16 else 1
+
+    mfma = 16 if M < 32 or N < 32 else 32
+
+    # TODO (zhanglx): figure out the boundary between large and small gemms
+    large_gemm = False
+    if M >= 2048 and N >= 2048:
+        large_gemm = True
+
+    for config in configs:
+        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
+        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
+        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
+        num_warps = config.get("num_warps")
+
+        if is_fp16:
+            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
+            if matrix_instr_nonkdim > mfma:
+                continue
+        if mfma == 4 and BLOCK_SIZE_K < 64:
+            continue
+        # some layouts could not work properly in case
+        # number elements per thread is less 1
+        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
+            continue
+        SPLIT_K = config.get("SPLIT_K", 1)
+        GROUP_M = config.get("GROUP_SIZE_M")
+        if is_fp16:
+            if (matrix_instr_nonkdim > BLOCK_SIZE_M
+                    or matrix_instr_nonkdim > BLOCK_SIZE_N):
+                continue
+            if (matrix_instr_nonkdim >= M
+                    and matrix_instr_nonkdim != BLOCK_SIZE_M):
+                continue
+            if (matrix_instr_nonkdim >= N
+                    and matrix_instr_nonkdim != BLOCK_SIZE_N):
+                continue
+        # Skip BLOCK_SIZE that is too large compare to M/N
+        # unless BLOCK_SIZE is already small enough
+        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
+            continue
+        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
+            continue
+        # skip large split_k when not necessary
+        if SPLIT_K != 1 and not need_split_k(M, N, K):
+            continue
+        # skip split_k that leads to EVEN_K = false
+        leap = SPLIT_K * BLOCK_SIZE_K
+        modv = K % leap
+        if modv != 0:
+            continue
+        # skip large GROUP_M
+        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
+            continue
+        # out of shared memory resource
+        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
+        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
+               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
+        if LDS > 65536:
+            continue
+        # Skip small block sizes and num_warps for large gemm
+        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
+        if large_gemm:
+            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
+                continue
+            if BLOCK_SIZE_K < 64:
+                continue
+            if num_warps < 4:
+                continue
+
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
+def need_split_k(SIZE_M, SIZE_N, SIZE_K):
+    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
+
+
+def merge_unique_dicts(list1, list2):
+    result = []
+    combined_list = list1.copy()
+    combined_list.extend(list2)
+    for dictionary in combined_list:
+        if dictionary not in result:
+            result.append(dictionary)
+    return result
+
+
 @ray.remote(num_gpus=1)
 class BenchmarkWorker:

@@ -170,6 +318,10 @@ class BenchmarkWorker:
         torch.set_default_device("cuda")
         current_platform.seed_everything(seed)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU. This is required for Ray to work
+        # correctly with multi-GPU tuning on the ROCm platform.
+        self.device_id = int(ray.get_gpu_ids()[0])

     def benchmark(
         self,
@@ -217,25 +369,33 @@
     ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
-        for config in tqdm(search_space):
-            try:
-                kernel_time = benchmark_config(config,
-                                               num_tokens,
-                                               num_experts,
-                                               shard_intermediate_size,
-                                               hidden_size,
-                                               topk,
-                                               dtype,
-                                               use_fp8_w8a8,
-                                               use_int8_w8a16,
-                                               num_iters=10)
-            except triton.runtime.autotuner.OutOfResources:
-                # Some configurations may be invalid and fail to compile.
-                continue
+        if current_platform.is_rocm():
+            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+            search_space = prune_rocm_search_space(num_tokens,
+                                                   shard_intermediate_size,
+                                                   hidden_size, search_space,
+                                                   is_fp16)

-            if kernel_time < best_time:
-                best_time = kernel_time
-                best_config = config
+        with torch.cuda.device(self.device_id):
+            for config in tqdm(search_space):
+                try:
+                    kernel_time = benchmark_config(config,
+                                                   num_tokens,
+                                                   num_experts,
+                                                   shard_intermediate_size,
+                                                   hidden_size,
+                                                   topk,
+                                                   dtype,
+                                                   use_fp8_w8a8,
+                                                   use_int8_w8a16,
+                                                   num_iters=20)
+                except triton.runtime.autotuner.OutOfResources:
+                    # Some configurations may be invalid and fail to compile.
+                    continue
+
+                if kernel_time < best_time:
+                    best_time = kernel_time
+                    best_config = config
         now = datetime.now()
         print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
         assert best_config is not None
@@ -244,12 +404,27 @@

 def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
     return {
-        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
-        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
-        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
-        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
-        "num_warps": config["num_warps"],
-        "num_stages": config["num_stages"],
+        "BLOCK_SIZE_M":
+        config["BLOCK_SIZE_M"],
+        "BLOCK_SIZE_N":
+        config["BLOCK_SIZE_N"],
+        "BLOCK_SIZE_K":
+        config["BLOCK_SIZE_K"],
+        "GROUP_SIZE_M":
+        config["GROUP_SIZE_M"],
+        "num_warps":
+        config["num_warps"],
+        "num_stages":
+        config["num_stages"],
+        **({
+            "waves_per_eu": config["waves_per_eu"]
+        } if "waves_per_eu" in config else {}),
+        **({
+            "matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
+        } if "matrix_instr_nonkdim" in config else {}),
+        **({
+            "kpack": config["kpack"]
+        } if "kpack" in config else {}),
     }

@@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
         shard_intermediate_size = 2 * intermediate_size // args.tp_size

     hidden_size = config.hidden_size
-    dtype = config.torch_dtype
+    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"

@@ -322,7 +497,8 @@ def main(args: argparse.Namespace):
         return ray.get(outputs)

     if args.tune:
-        search_space = get_configs_compute_bound()
+        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+        search_space = get_configs_compute_bound(is_fp16)
         print(f"Start tuning over {len(search_space)} configurations...")

         start = time.time()
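Among the pruning rules added above is a shared-memory bound: a tile is dropped when its A and B operands together would exceed the 64 KiB LDS budget per workgroup. A small worked check with illustrative fp16 block sizes (values not taken from an actual tuning run):

# Worked instance of the LDS check in prune_rocm_configs (illustrative values).
elem_bytes = 2                                     # fp16/bf16 operand width
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 256, 64
lds = (BLOCK_SIZE_K * BLOCK_SIZE_M * elem_bytes +
       BLOCK_SIZE_K * BLOCK_SIZE_N * elem_bytes)
print(lds)  # 65536 -> exactly at the 64 KiB budget, so this config is kept
# With BLOCK_SIZE_K = 128 the total doubles to 131072 > 65536, so that
# config would be pruned.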