From b7f1f490a61c99d0b371e39aefbe5546cba231a9 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Date: Fri, 21 Nov 2025 08:34:46 -0800
Subject: [PATCH] Upstream triton fp4 weight preshuffle (#28888)

Signed-off-by: Aleksandr Malyshev
Co-authored-by: Aleksandr Malyshev
---
 vllm/_aiter_ops.py                            | 25 +++++++
 .../quark/schemes/quark_ocp_mx.py             | 65 +++++++++++++++----
 2 files changed, 76 insertions(+), 14 deletions(-)

diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index e53e4ae6e5296..db79b3f5e8bcb 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -948,6 +948,31 @@ class rocm_aiter_ops:
             (8192, 32768),
         ]
 
+    @staticmethod
+    def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
+        return (n, k) in [
+            (8192, 4096),
+            (1280, 8192),
+            (16384, 53248),
+            (106496, 16384),
+            (57344, 8192),
+            (8192, 2048),
+            (2560, 8192),
+            (10240, 8192),
+            (16384, 16384),
+            (8192, 28672),
+            (28672, 8192),
+            (18432, 16384),
+            (8192, 1024),
+            (7168, 8192),
+            (5120, 8192),
+            (8192, 8192),
+            (8192, 7168),
+            (14336, 8192),
+            (8192, 14336),
+            (8192, 3584),
+        ]
+
     @staticmethod
     def shuffle_weight(
         self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
index 007e78e68d5cd..33e9f9806b27e 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4,
@@ -49,7 +50,10 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
 
 try:
     from aiter.ops.shuffle import shuffle_weight
-    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.gemm_afp4wfp4 import (
+        gemm_afp4wfp4,
+        gemm_afp4wfp4_preshuffled_weight_scales,
+    )
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
 
     from vllm.utils.torch_utils import direct_register_custom_op
@@ -66,23 +70,56 @@ try:
         x_scales: torch.Tensor | None = None,
     ) -> torch.Tensor:
         M = x.shape[0]
+        N = weight.shape[0]
+        K = weight.shape[1]
         if rocm_use_aiter_fp4_asm_gemm:
-            if x_scales is None:
-                # use hip quant kernel for performance
-                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+            if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
+                if x_scales is None:
+                    # use hip quant kernel for performance
+                    if M >= 32:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                    else:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
+                else:
+                    x_q = x
+                    x_s = x_scales
+
+                if M >= 32:
+                    x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
+                else:
+                    x_s = x_s[:M, ...].view(torch.uint8)
+
+                y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
+                gemm_afp4wfp4_preshuffled_weight_scales(
+                    x_q.view(torch.uint8),
+                    weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
+                    x_s,
+                    weight_scale.view(torch.uint8).view(
+                        weight_scale.shape[0] // 32, -1
+                    ),
+                    out_dtype,
+                    y,
+                )
             else:
-                x_q = x
-                x_s = x_scales
+                if x_scales is None:
+                    # use hip quant kernel for performance
+                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                else:
+                    x_q = x
+                    x_s = x_scales
 
-            # 32 alignment is enough for dim0 padding of output for
-            # gemm_a4w4 kernel
-            y = torch.empty(
-                (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype
-            )
+                # 32 alignment is enough for dim0 padding of output for
+                # gemm_a4w4 kernel
+                y = torch.empty(
+                    (M + 31) // 32 * 32,
+                    weight.shape[0],
+                    device=x_q.device,
+                    dtype=out_dtype,
+                )
 
-            gemm_a4w4(
-                x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
-            )
+                gemm_a4w4(
+                    x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
+                )
             return y[:M]
         else:
             if x_scales is None:
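
Not part of the patch itself: a minimal, standalone sketch of the dispatch rule the quark_ocp_mx.py hunk introduces, for readers skimming the diff. The module-level TUNED_NK set and the use_presh_ws() helper below are illustrative stand-ins for rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(); only the (N, K) pairs and the M <= 64 cutoff are taken from the patch.

# Sketch only; names TUNED_NK and use_presh_ws are hypothetical helpers,
# not vLLM API. The (N, K) pairs mirror the list added to _aiter_ops.py.
TUNED_NK = {
    (8192, 4096), (1280, 8192), (16384, 53248), (106496, 16384),
    (57344, 8192), (8192, 2048), (2560, 8192), (10240, 8192),
    (16384, 16384), (8192, 28672), (28672, 8192), (18432, 16384),
    (8192, 1024), (7168, 8192), (5120, 8192), (8192, 8192),
    (8192, 7168), (14336, 8192), (8192, 14336), (8192, 3584),
}


def use_presh_ws(m: int, n: int, k: int) -> bool:
    # Small-M (decode-like) GEMMs whose (N, K) has a tuned Triton config take
    # the gemm_afp4wfp4_preshuffled_weight_scales path; everything else stays
    # on the existing asm gemm_a4w4 path.
    return m <= 64 and (n, k) in TUNED_NK


if __name__ == "__main__":
    # M=1 decode against an 8192x8192 layer -> preshuffled-weight-scales kernel;
    # M=512 prefill on the same layer -> previous gemm_a4w4 path.
    assert use_presh_ws(1, 8192, 8192)
    assert not use_presh_ws(512, 8192, 8192)

Within the new branch the patch also adapts activation quantization to batch size: shuffle=True and a 32-row scale reshape for M >= 32, shuffle=False with the scales truncated to the first M rows otherwise.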