mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 05:43:06 +08:00
Upstream triton fp4 weight preshuffle (#28888)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
parent
30b44a1598
commit
b7f1f490a6
@ -948,6 +948,31 @@ class rocm_aiter_ops:
|
||||
(8192, 32768),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
|
||||
return (n, k) in [
|
||||
(8192, 4096),
|
||||
(1280, 8192),
|
||||
(16384, 53248),
|
||||
(106496, 16384),
|
||||
(57344, 8192),
|
||||
(8192, 2048),
|
||||
(2560, 8192),
|
||||
(10240, 8192),
|
||||
(16384, 16384),
|
||||
(8192, 28672),
|
||||
(28672, 8192),
|
||||
(18432, 16384),
|
||||
(8192, 1024),
|
||||
(7168, 8192),
|
||||
(5120, 8192),
|
||||
(8192, 8192),
|
||||
(8192, 7168),
|
||||
(14336, 8192),
|
||||
(8192, 14336),
|
||||
(8192, 3584),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def shuffle_weight(
|
||||
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
|
||||
|
||||
@ -10,6 +10,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
dequant_mxfp4,
|
||||
@ -49,7 +50,10 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
|
||||
|
||||
try:
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import (
|
||||
gemm_afp4wfp4,
|
||||
gemm_afp4wfp4_preshuffled_weight_scales,
|
||||
)
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
@ -66,23 +70,56 @@ try:
|
||||
x_scales: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
M = x.shape[0]
|
||||
N = weight.shape[0]
|
||||
K = weight.shape[1]
|
||||
if rocm_use_aiter_fp4_asm_gemm:
|
||||
if x_scales is None:
|
||||
# use hip quant kernel for performance
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
|
||||
if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
|
||||
if x_scales is None:
|
||||
# use hip quant kernel for performance
|
||||
if M >= 32:
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
|
||||
else:
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
|
||||
if M >= 32:
|
||||
x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
|
||||
else:
|
||||
x_s = x_s[:M, ...].view(torch.uint8)
|
||||
|
||||
y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
|
||||
gemm_afp4wfp4_preshuffled_weight_scales(
|
||||
x_q.view(torch.uint8),
|
||||
weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
|
||||
x_s,
|
||||
weight_scale.view(torch.uint8).view(
|
||||
weight_scale.shape[0] // 32, -1
|
||||
),
|
||||
out_dtype,
|
||||
y,
|
||||
)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
if x_scales is None:
|
||||
# use hip quant kernel for performance
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
|
||||
# 32 alignment is enough for dim0 padding of output for
|
||||
# gemm_a4w4 kernel
|
||||
y = torch.empty(
|
||||
(M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype
|
||||
)
|
||||
# 32 alignment is enough for dim0 padding of output for
|
||||
# gemm_a4w4 kernel
|
||||
y = torch.empty(
|
||||
(M + 31) // 32 * 32,
|
||||
weight.shape[0],
|
||||
device=x_q.device,
|
||||
dtype=out_dtype,
|
||||
)
|
||||
|
||||
gemm_a4w4(
|
||||
x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
|
||||
)
|
||||
gemm_a4w4(
|
||||
x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
|
||||
)
|
||||
return y[:M]
|
||||
else:
|
||||
if x_scales is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user