[core] moe fp8 block quant tuning support (#14068)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
Divakar Verma 2025-03-03 19:30:23 -06:00 committed by GitHub
parent c060b71408
commit bb5b640359
2 changed files with 132 additions and 60 deletions


@@ -40,6 +40,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
+    block_quant_shape: List[int] = None,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -81,8 +82,24 @@ def benchmark_config(
                               dtype=torch.float32)
         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8_w8a8:
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
+        if block_quant_shape:
+            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+            E = num_experts
+            N = shard_intermediate_size // 2
+            K = hidden_size
+            factor_for_scale = 1e-2
+            n_tiles_w1 = (2 * N + block_n - 1) // block_n
+            n_tiles_w2 = (K + block_n - 1) // block_n
+            k_tiles_w1 = (K + block_k - 1) // block_k
+            k_tiles_w2 = (N + block_k - 1) // block_k
+            w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
+                                  dtype=torch.float32) * factor_for_scale
+            w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
+                                  dtype=torch.float32) * factor_for_scale
+        else:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)
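As a sanity check on the tile math (illustrative numbers only, not taken from this commit: a DeepSeek-V3-like shard with hidden_size K = 7168, N = shard_intermediate_size // 2 = 2048, and weight block size [128, 128]):

    block_n, block_k = 128, 128
    K, N = 7168, 2048
    n_tiles_w1 = (2 * N + block_n - 1) // block_n  # ceil(4096 / 128) = 32
    k_tiles_w1 = (K + block_k - 1) // block_k      # ceil(7168 / 128) = 56
    n_tiles_w2 = (K + block_n - 1) // block_n      # ceil(7168 / 128) = 56
    k_tiles_w2 = (N + block_k - 1) // block_k      # ceil(2048 / 128) = 16
    # w1_scale gets shape (E, 32, 56): one fp32 scale per 128x128 tile of w1;
    # w2_scale gets shape (E, 56, 16): one per 128x128 tile of w2.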
@@ -111,6 +128,7 @@ def benchmark_config(
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_quant_shape,
         )

     # JIT compilation & warmup
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
     return param_ranges


-def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
+def get_configs_compute_bound(use_fp16,
+                              block_quant_shape) -> list[dict[str, int]]:
     configs: list[BenchmarkConfig] = []

     if current_platform.is_rocm():
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
         for config_values in product(*values):
             config = dict(zip(keys, config_values))
             configs.append(config)
+
+    # Remove configs that are not compatible with fp8 block quantization
+    # BLOCK_SIZE_K must be a multiple of block_k
+    # BLOCK_SIZE_N must be a multiple of block_n
+    if block_quant_shape is not None and not use_fp16:
+        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+        for config in configs[:]:
+            if config["BLOCK_SIZE_K"] % block_k != 0 or config[
+                    "BLOCK_SIZE_N"] % block_n != 0:
+                configs.remove(config)
     return configs


 def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
-                            search_space, is_fp16):
+                            search_space, is_fp16, topk):
     N1, K1 = shard_intermediate_size, hidden_size
     N2, K2 = hidden_size, shard_intermediate_size // 2
-    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
-                                        is_fp16)
-    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
-                                        is_fp16)
+    pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
+                                        search_space, is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
+                                        search_space, is_fp16)
     search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
     return search_space
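The filter added above exists because a Triton tile must cover whole quantization blocks; a tile that straddled a block boundary would need more than one scale. A quick sketch of its effect (hypothetical candidate configs, assuming block_quant_shape = [128, 128]):

    candidates = [
        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},   # dropped: 64 % 128 != 0
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},  # kept
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256},  # kept: 256 % 128 == 0
    ]
    block_n, block_k = 128, 128
    kept = [c for c in candidates
            if c["BLOCK_SIZE_K"] % block_k == 0
            and c["BLOCK_SIZE_N"] % block_n == 0]
    assert kept == candidates[1:]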
@@ -372,6 +401,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
         search_space: list[dict[str, int]],
+        block_quant_shape: list[int],
     ) -> dict[str, int]:
         best_config = None
         best_time = float("inf")
@@ -380,21 +410,23 @@
             search_space = prune_rocm_search_space(num_tokens,
                                                    shard_intermediate_size,
                                                    hidden_size, search_space,
-                                                   is_fp16)
+                                                   is_fp16, topk)

         with torch.cuda.device(self.device_id):
             for config in tqdm(search_space):
                 try:
-                    kernel_time = benchmark_config(config,
-                                                   num_tokens,
-                                                   num_experts,
-                                                   shard_intermediate_size,
-                                                   hidden_size,
-                                                   topk,
-                                                   dtype,
-                                                   use_fp8_w8a8,
-                                                   use_int8_w8a16,
-                                                   num_iters=20)
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a16,
+                        num_iters=20,
+                        block_quant_shape=block_quant_shape)
                 except triton.runtime.autotuner.OutOfResources:
                     # Some configurations may be invalid and fail to compile.
                     continue
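For readers new to block quantization: block_shape = [block_n, block_k] means every [block_n, block_k] tile of an fp8 weight matrix shares a single fp32 scale. A reference dequantization sketch (hand-written illustration, not the Triton kernel's actual code):

    import torch

    def dequant_block_fp8(w_q: torch.Tensor, scale: torch.Tensor,
                          block_n: int, block_k: int) -> torch.Tensor:
        # w_q: (n, k) quantized weight;
        # scale: (ceil(n / block_n), ceil(k / block_k)) per-tile scales.
        n, k = w_q.shape
        # Broadcast each per-tile scale across its tile, then multiply.
        s = scale.repeat_interleave(block_n, dim=0)[:n]
        s = s.repeat_interleave(block_k, dim=1)[:, :k]
        return w_q.to(torch.float32) * s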
@@ -436,8 +468,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
 def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
                  shard_intermediate_size: int, hidden_size: int, topk: int,
-                 dtype: torch.dtype, use_fp8_w8a8: bool,
-                 use_int8_w8a16: bool) -> None:
+                 dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
+                 block_quant_shape: List[int]) -> None:
     dtype_str = get_config_dtype_str(dtype,
                                      use_int8_w8a16=use_int8_w8a16,
                                      use_fp8_w8a8=use_fp8_w8a8)
@@ -445,7 +477,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
-                                    dtype_str)
+                                    dtype_str, block_quant_shape)

     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
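Passing block_quant_shape into get_config_file_name folds the block shape into the config filename, so block-quantized tunings do not collide with existing fp8 configs. A representative name (illustrative values; the device string depends on the GPU) would look like:

    E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json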
@@ -455,7 +487,7 @@
 def main(args: argparse.Namespace):
     print(args)
+    block_quant_shape = None

     config = AutoConfig.from_pretrained(
         args.model, trust_remote_code=args.trust_remote_code)
     if config.architectures[0] == "DbrxForCausalLM":
@@ -474,6 +506,7 @@
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        block_quant_shape = config.quantization_config['weight_block_size']
     else:
         # Default: Mixtral.
         E = config.num_local_experts
@@ -511,27 +544,30 @@
     if args.tune:
         is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
-        search_space = get_configs_compute_bound(is_fp16)
+        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
         print(f"Start tuning over {len(search_space)} configurations...")

         start = time.time()
         configs = _distribute(
-            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
-                     for batch_size in batch_sizes])
+            "tune",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
+             for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16,
+                     block_quant_shape)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
         outputs = _distribute(
-            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
-                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
-                          for batch_size in batch_sizes])
+            "benchmark",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
+             for batch_size in batch_sizes])

         for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
             print(f"Batch size: {batch_size}, config: {config}")


@@ -1,28 +1,28 @@
 {
     "1": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "2": {
-        "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "4": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
@@ -31,15 +31,15 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "16": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 1,
         "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
@@ -49,13 +49,13 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "32": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -64,7 +64,7 @@
     },
     "48": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -73,7 +73,7 @@
     },
     "64": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
         "num_warps": 2,
@@ -82,46 +82,82 @@
     },
     "96": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
         "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
-    "128": {
-        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 1,
-        "num_warps": 2,
-        "num_stages": 2,
-        "waves_per_eu": 0
-    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
     "256": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "512": {
         "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
     "2048": {
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
 }
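For reference, a sketch of how a tuned table like this is consumed (hand-written nearest-key lookup, not vLLM's exact code): at runtime the kernel picks the entry whose batch-size key is closest to the current token count M.

    import json

    def pick_config(path: str, M: int) -> dict:
        with open(path) as f:
            configs = {int(k): v for k, v in json.load(f).items()}
        # Choose the tuned entry whose key is nearest to M.
        return configs[min(configs, key=lambda k: abs(k - M))]

    # e.g. M = 100 picks the "96" entry; M = 3000 picks "3072".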