mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-07 06:43:05 +08:00
[Kernel][Triton][AMD] Use block size heuristic for avg 2.8x speedup for int8 models (#11698)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
56fe4c297c
commit
526de822d5
@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32) -> torch.Tensor:
|
||||
block_size_k: int = 32,
|
||||
use_heuristic=True) -> torch.Tensor:
|
||||
M, K = input.shape
|
||||
N = weight.shape[1]
|
||||
|
||||
@ -152,6 +153,20 @@ def triton_scaled_mm(input: torch.Tensor,
|
||||
|
||||
has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
|
||||
|
||||
if use_heuristic:
|
||||
is_small_N = N < 8192
|
||||
next_power_of_2_M = max(32, triton.next_power_of_2(M))
|
||||
if next_power_of_2_M <= 32:
|
||||
tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
|
||||
elif next_power_of_2_M <= 64:
|
||||
tile_shape = (64, 64, 256)
|
||||
elif next_power_of_2_M <= 128:
|
||||
tile_shape = (64, 128, 128)
|
||||
else:
|
||||
tile_shape = (128, 128, 128)
|
||||
|
||||
block_size_m, block_size_n, block_size_k = tile_shape
|
||||
|
||||
block_size_sa = 1 if has_scalar(scale_a) else block_size_m
|
||||
block_size_sb = 1 if has_scalar(scale_b) else block_size_n
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user