Upstream triton fp4 weight preshuffle (#28888)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Author: Aleksandr Malyshev <maleksan@amd.com>
Date: 2025-11-21 08:34:46 -08:00 (committed via GitHub)
Parent: 30b44a1598
Commit: b7f1f490a6
2 changed files with 76 additions and 14 deletions

@@ -948,6 +948,31 @@ class rocm_aiter_ops:
             (8192, 32768),
         ]

+    @staticmethod
+    def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
+        return (n, k) in [
+            (8192, 4096),
+            (1280, 8192),
+            (16384, 53248),
+            (106496, 16384),
+            (57344, 8192),
+            (8192, 2048),
+            (2560, 8192),
+            (10240, 8192),
+            (16384, 16384),
+            (8192, 28672),
+            (28672, 8192),
+            (18432, 16384),
+            (8192, 1024),
+            (7168, 8192),
+            (5120, 8192),
+            (8192, 8192),
+            (8192, 7168),
+            (14336, 8192),
+            (8192, 14336),
+            (8192, 3584),
+        ]
+
     @staticmethod
     def shuffle_weight(
         self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
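
A minimal usage sketch of the new helper (assuming a ROCm build of vLLM; only the rocm_aiter_ops call is real, the wrapper function below is hypothetical). The method is a pure lookup of (N, K) weight shapes that have pre-tuned Triton configs, and the second file in this diff gates on it for small M:

    from vllm._aiter_ops import rocm_aiter_ops

    def use_preshuffled_triton_path(m: int, n: int, k: int) -> bool:
        # Small-M GEMMs whose (N, K) appears in the tuned table take the
        # preshuffled-weight-scale Triton kernel; everything else keeps gemm_a4w4.
        return m <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(n, k)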

@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F

 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4,
@@ -49,7 +50,10 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
 try:
     from aiter.ops.shuffle import shuffle_weight
-    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.gemm_afp4wfp4 import (
+        gemm_afp4wfp4,
+        gemm_afp4wfp4_preshuffled_weight_scales,
+    )
     from aiter.ops.triton.quant import dynamic_mxfp4_quant

     from vllm.utils.torch_utils import direct_register_custom_op
@@ -66,23 +70,56 @@
         x_scales: torch.Tensor | None = None,
     ) -> torch.Tensor:
         M = x.shape[0]
+        N = weight.shape[0]
+        K = weight.shape[1]
         if rocm_use_aiter_fp4_asm_gemm:
-            if x_scales is None:
-                # use hip quant kernel for performance
-                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+            if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
+                if x_scales is None:
+                    # use hip quant kernel for performance
+                    if M >= 32:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                    else:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
+                else:
+                    x_q = x
+                    x_s = x_scales
+                if M >= 32:
+                    x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
+                else:
+                    x_s = x_s[:M, ...].view(torch.uint8)
+                y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
+                gemm_afp4wfp4_preshuffled_weight_scales(
+                    x_q.view(torch.uint8),
+                    weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
+                    x_s,
+                    weight_scale.view(torch.uint8).view(
+                        weight_scale.shape[0] // 32, -1
+                    ),
+                    out_dtype,
+                    y,
+                )
             else:
-                x_q = x
-                x_s = x_scales
+                if x_scales is None:
+                    # use hip quant kernel for performance
+                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                else:
+                    x_q = x
+                    x_s = x_scales
-            # 32 alignment is enough for dim0 padding of output for
-            # gemm_a4w4 kernel
-            y = torch.empty(
-                (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype
-            )
+                # 32 alignment is enough for dim0 padding of output for
+                # gemm_a4w4 kernel
+                y = torch.empty(
+                    (M + 31) // 32 * 32,
+                    weight.shape[0],
+                    device=x_q.device,
+                    dtype=out_dtype,
+                )
-            gemm_a4w4(
-                x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
-            )
+                gemm_a4w4(
+                    x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
+                )
             return y[:M]
         else:
             if x_scales is None:
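
For reference, a shape-only sketch of the layout flattening the preshuffled path applies before calling gemm_afp4wfp4_preshuffled_weight_scales. It runs on CPU with dummy uint8 tensors; the concrete N, K and the packed-fp4 / one-scale-per-32-values layout are illustrative assumptions, not taken from this diff:

    import torch

    N, K = 8192, 4096  # (8192, 4096) is one of the tuned (N, K) pairs
    # Dummy packed weight (two fp4 values per byte) and per-group scales
    # (one scale byte per 32 values); layout assumed for illustration.
    weight = torch.zeros(N, K // 2, dtype=torch.uint8)
    weight_scale = torch.zeros(N, K // 32, dtype=torch.uint8)

    # Same reshapes as the new code path: group weight rows 16 at a time and
    # scale rows 32 at a time, flattening the remainder into one dimension.
    w_flat = weight.view(torch.uint8).view(weight.shape[0] // 16, -1)
    ws_flat = weight_scale.view(torch.uint8).view(weight_scale.shape[0] // 32, -1)
    print(w_flat.shape)   # torch.Size([512, 32768])
    print(ws_flat.shape)  # torch.Size([256, 4096])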