[core] moe fp8 block quant tuning support (#14068)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>

parent c060b71408
commit bb5b640359
@@ -40,6 +40,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
+    block_quant_shape: List[int] = None,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -81,8 +82,24 @@ def benchmark_config(
                            dtype=torch.float32)
     w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8_w8a8:
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
+        if block_quant_shape:
+            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+            E = num_experts
+            N = shard_intermediate_size // 2
+            K = hidden_size
+            factor_for_scale = 1e-2
+            n_tiles_w1 = (2 * N + block_n - 1) // block_n
+            n_tiles_w2 = (K + block_n - 1) // block_n
+            k_tiles_w1 = (K + block_k - 1) // block_k
+            k_tiles_w2 = (N + block_k - 1) // block_k
+            w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
+                                  dtype=torch.float32) * factor_for_scale
+            w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
+                                  dtype=torch.float32) * factor_for_scale
+        else:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)

         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)
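Note on the hunk above: the new branch sizes the per-block scale tensors by ceiling division of each GEMM dimension by the block shape, one scale per (block_n, block_k) tile. A minimal sketch of that arithmetic, assuming a DeepSeek-style block shape of [128, 128] and example layer sizes (none of the concrete numbers below come from the commit):

# Illustrative arithmetic only: how the per-block scale tensors are sized.
block_n, block_k = 128, 128   # block_quant_shape (assumed)
N, K = 2048, 7168             # example shard_intermediate_size // 2, hidden_size

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# w1 stacks the gate and up projections, so it has 2*N output rows.
n_tiles_w1 = ceil_div(2 * N, block_n)   # 32 scale tiles along the output dim
k_tiles_w1 = ceil_div(K, block_k)       # 56 scale tiles along the reduction dim
# w2 projects back down: output dim K, reduction dim N.
n_tiles_w2 = ceil_div(K, block_n)       # 56
k_tiles_w2 = ceil_div(N, block_k)       # 16
print(n_tiles_w1, k_tiles_w1, n_tiles_w2, k_tiles_w2)   # 32 56 56 16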
@@ -111,6 +128,7 @@ def benchmark_config(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_quant_shape,
     )

     # JIT compilation & warmup
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
     return param_ranges


-def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
+def get_configs_compute_bound(use_fp16,
+                              block_quant_shape) -> list[dict[str, int]]:
     configs: list[BenchmarkConfig] = []

     if current_platform.is_rocm():
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
         for config_values in product(*values):
             config = dict(zip(keys, config_values))
             configs.append(config)

+    # Remove configs that are not compatible with fp8 block quantization
+    # BLOCK_SIZE_K must be a multiple of block_k
+    # BLOCK_SIZE_N must be a multiple of block_n
+    if block_quant_shape is not None and not use_fp16:
+        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+        for config in configs[:]:
+            if config["BLOCK_SIZE_K"] % block_k != 0 or config[
+                    "BLOCK_SIZE_N"] % block_n != 0:
+                configs.remove(config)
     return configs


 def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
-                            search_space, is_fp16):
+                            search_space, is_fp16, topk):
     N1, K1 = shard_intermediate_size, hidden_size
     N2, K2 = hidden_size, shard_intermediate_size // 2
-    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
-                                        is_fp16)
+    pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
+                                        search_space, is_fp16)
-    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
-                                        is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
+                                        search_space, is_fp16)
     search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
     return search_space
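The divisibility filter above exists because block-quantized weights are dequantized one (block_n, block_k) scale tile at a time, so a Triton tile must cover whole quantization blocks. A self-contained sketch of the same filtering, with a made-up miniature search space (the real one comes from get_rocm_tuning_space()):

from itertools import product

# Hypothetical miniature search space standing in for the tuner's.
search_space = [
    {"BLOCK_SIZE_N": n, "BLOCK_SIZE_K": k}
    for n, k in product((16, 32, 64, 128, 256), (64, 128, 256))
]

block_n, block_k = 128, 128   # weight_block_size of the quantized checkpoint

# Keep only Triton tiles that cover whole quantization blocks.
compatible = [
    cfg for cfg in search_space
    if cfg["BLOCK_SIZE_N"] % block_n == 0 and cfg["BLOCK_SIZE_K"] % block_k == 0
]
print(len(search_space), "->", len(compatible))   # 15 -> 4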
@@ -372,6 +401,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
         search_space: list[dict[str, int]],
+        block_quant_shape: list[int],
     ) -> dict[str, int]:
         best_config = None
         best_time = float("inf")
@@ -380,12 +410,13 @@ class BenchmarkWorker:
             search_space = prune_rocm_search_space(num_tokens,
                                                    shard_intermediate_size,
                                                    hidden_size, search_space,
-                                                   is_fp16)
+                                                   is_fp16, topk)

         with torch.cuda.device(self.device_id):
             for config in tqdm(search_space):
                 try:
-                    kernel_time = benchmark_config(config,
+                    kernel_time = benchmark_config(
+                        config,
                         num_tokens,
                         num_experts,
                         shard_intermediate_size,
@@ -394,7 +425,8 @@ class BenchmarkWorker:
                         dtype,
                         use_fp8_w8a8,
                         use_int8_w8a16,
-                        num_iters=20)
+                        num_iters=20,
+                        block_quant_shape=block_quant_shape)
                 except triton.runtime.autotuner.OutOfResources:
                     # Some configurations may be invalid and fail to compile.
                     continue
@@ -436,8 +468,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:

 def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
                  shard_intermediate_size: int, hidden_size: int, topk: int,
-                 dtype: torch.dtype, use_fp8_w8a8: bool,
-                 use_int8_w8a16: bool) -> None:
+                 dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
+                 block_quant_shape: List[int]) -> None:
     dtype_str = get_config_dtype_str(dtype,
                                      use_int8_w8a16=use_int8_w8a16,
                                      use_fp8_w8a8=use_fp8_w8a8)
@@ -445,7 +477,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
-                                    dtype_str)
+                                    dtype_str, block_quant_shape)

     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
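Passing the block shape into get_config_file_name keeps tuned results for block-quantized models in their own config file, separate from per-tensor fp8 configs for the same expert shape. A hypothetical illustration of the resulting name; the authoritative format lives in get_config_file_name(), and the device name, sizes, and the ",block_shape=[...]" suffix below are assumptions:

# Hypothetical illustration of the filename the block shape ends up in.
E, N = 256, 256
device_name = "AMD_Instinct_MI300X"   # assumed for this example
dtype_str = "fp8_w8a8"
block_quant_shape = [128, 128]
suffix = f",block_shape={block_quant_shape}".replace(" ", "")
print(f"E={E},N={N},device_name={device_name},dtype={dtype_str}{suffix}.json")
# E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json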
@@ -455,7 +487,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,

 def main(args: argparse.Namespace):
     print(args)
+    block_quant_shape = None
     config = AutoConfig.from_pretrained(
         args.model, trust_remote_code=args.trust_remote_code)
     if config.architectures[0] == "DbrxForCausalLM":
@@ -474,6 +506,7 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        block_quant_shape = config.quantization_config['weight_block_size']
     else:
         # Default: Mixtral.
         E = config.num_local_experts
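This branch reads the block shape straight from the model's HF quantization config. For reference, a block-quantized fp8 checkpoint exposes fields shaped roughly like the following (values illustrative; DeepSeek-style checkpoints publish weight_block_size = [128, 128]):

# Shape of the HF config fields this branch reads (illustrative values).
quantization_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],   # [block_n, block_k]
}
block_quant_shape = quantization_config["weight_block_size"]
print(block_quant_shape)   # [128, 128]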
@@ -511,26 +544,29 @@ def main(args: argparse.Namespace):

     if args.tune:
         is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
-        search_space = get_configs_compute_bound(is_fp16)
+        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
         print(f"Start tuning over {len(search_space)} configurations...")

         start = time.time()
         configs = _distribute(
-            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
+            "tune",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
             for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16,
+                     block_quant_shape)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
         outputs = _distribute(
-            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
-                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+            "benchmark",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
             for batch_size in batch_sizes])

         for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
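End to end, the tuning loop produces one best kernel config per batch size M, and save_configs() serializes that {M: config} mapping as JSON keyed by M, as in the config-file diff below. A schematic of that output shape; the batch-size sweep and helper names here are assumptions, and timing is faked (the real measurement happens in benchmark_config() on the GPU workers):

import random

# Schematic only: one best config per batch size M.
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
search_space = [{"BLOCK_SIZE_M": m, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}
                for m in (16, 32, 64, 128)]

def time_config(M: int, cfg: dict) -> float:
    return random.random()   # stand-in for a measured kernel latency

best_configs = {
    M: min(search_space, key=lambda cfg: time_config(M, cfg))
    for M in batch_sizes
}
print(best_configs[1024])

The second file in the commit is one such tuned config, re-tuned with the block-quant-aware search space: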
@@ -1,28 +1,28 @@
 {
     "1": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "2": {
-        "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "4": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
@@ -31,15 +31,15 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "16": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 1,
         "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
@@ -49,13 +49,13 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "32": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -64,7 +64,7 @@
     },
     "48": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -73,7 +73,7 @@
     },
     "64": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
         "num_warps": 2,
@@ -82,46 +82,82 @@
     },
     "96": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
-    "128": {
-        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 1,
-        "num_warps": 2,
-        "num_stages": 2,
-        "waves_per_eu": 0
-    },
     "256": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "512": {
         "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "1024": {
         "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     }
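For context, each top-level key in this file is a batch size M, mapping to the Triton tile sizes and launch parameters that timed fastest at that M. A sketch of how such a file is consumed: load the JSON and pick the entry whose key is closest to the live token count. The nearest-M rule mirrors vLLM's fused_moe config lookup as an assumption, and the path is illustrative:

import json

def pick_config(path: str, M: int) -> dict:
    # Load {M: config} and return the entry tuned for the nearest batch size.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda k: abs(k - M))]

# e.g. pick_config("E=256,N=256,...json", 1000) would return the "1024" entry.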