[Kernel][Triton][AMD] Use block size heuristic for avg 2.8x speedup for int8 models (#11698)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
rasmith 2025-01-08 14:23:15 -06:00 committed by GitHub
parent 56fe4c297c
commit 526de822d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
block_size_m: int = 32,
block_size_n: int = 32,
block_size_k: int = 32) -> torch.Tensor:
block_size_k: int = 32,
use_heuristic=True) -> torch.Tensor:
M, K = input.shape
N = weight.shape[1]
@ -152,6 +153,20 @@ def triton_scaled_mm(input: torch.Tensor,
has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
if use_heuristic:
is_small_N = N < 8192
next_power_of_2_M = max(32, triton.next_power_of_2(M))
if next_power_of_2_M <= 32:
tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
elif next_power_of_2_M <= 64:
tile_shape = (64, 64, 256)
elif next_power_of_2_M <= 128:
tile_shape = (64, 128, 128)
else:
tile_shape = (128, 128, 128)
block_size_m, block_size_n, block_size_k = tile_shape
block_size_sa = 1 if has_scalar(scale_a) else block_size_m
block_size_sb = 1 if has_scalar(scale_b) else block_size_n