# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the marlin kernel.

Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
"""
import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
    MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
    marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
    query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
    marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import (  # noqa: E501
    marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]


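# Helpers: compute_max_diff returns the mean absolute error of `output`
# relative to the mean absolute magnitude of `output_ref`; the GEMM tests
# below assert this relative error stays under 0.04. rand_data creates
# random CUDA tensors in the requested dtype.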
def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


def rand_data(shape, dtype=torch.float16):
    return torch.randn(shape, dtype=dtype, device="cuda")


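# Checks that the GPU repack kernel (ops.gptq_marlin_repack) converts
# GPTQ-packed weights into the Marlin layout identically to the reference
# packing path (marlin_weights), including the act_order case where weights
# are sorted by group id first.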
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            act_order, mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order)

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    opcheck(torch.ops._C.gptq_marlin_repack,
            (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


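# Same repack comparison as above, but starting from AWQ-packed weights
# (zero-point quantization) and the ops.awq_marlin_repack kernel.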
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize
    w_ref, q_w, s, zp = quantize_weights(b_weight,
                                         quant_type,
                                         group_size,
                                         zero_points=True)

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    opcheck(torch.ops._C.awq_marlin_repack,
            (q_w_awq, size_k, size_n, quant_type.size_bits))

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


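# End-to-end check of ops.gptq_marlin_gemm across the supported weight
# formats (GPTQ-style without zero points, AWQ-style uint4/uint8 with zero
# points, FP8, and NVFP4/MXFP4), comparing the kernel output against a dense
# matmul with the dequantized reference weights.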
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
    "group_size",
    set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
                          mnk_factors, act_order, is_k_full, use_atomic_add,
                          use_fp32_reduce, dtype):
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return

    if size_k % group_size != 0:
        return

    a_input = rand_data((size_m, size_k), dtype)
    b_weight = rand_data((size_k, size_n), dtype)

    if quant_type == scalar_types.float4_e2m1f:
        if group_size not in [16, 32] or act_order:
            return
        if group_size == 32 and dtype == torch.float16:
            return

        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = \
                rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
        else:
            w_ref, marlin_q_w, marlin_s = \
                rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif quant_type == scalar_types.float8_e4m3fn:
        if group_size not in [-1, 128]:
            return
        if act_order:
            return
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size)
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size)
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order)
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    opcheck(torch.ops._C.gptq_marlin_gemm,
            (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
             g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
             use_fp32_reduce, False),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


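# The wrapper below runs the Marlin 24 GEMM inside
# torch.compile(fullgraph=True), which fails if tracing the custom op causes
# a graph break, so calling it doubles as a compilation check.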
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
                          marlin_24_s, scratch, quant_type, size_m, size_n,
                          size_k):
    return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
                                   marlin_24_s, scratch, quant_type, size_m,
                                   size_n, size_k)


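# Runs the 2:4-sparse ("24") Marlin GEMM through the compiled wrapper above
# and compares against a dense matmul with the reference weights returned by
# marlin_24_quantize.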
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
                             mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)

    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                   GPTQ_MARLIN_24_MAX_PARALLEL)

    output_ref = torch.matmul(a_input, w_24_ref)

    opcheck(torch.ops._C.gptq_marlin_24_gemm,
            (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
             workspace_24.scratch, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1]),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


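# HQQ-style weights use floating-point zero points, exercised here via
# is_zp_float=True; the reference output dequantizes as (w - zp) * scale
# before the matmul.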
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    b_weight = torch.randint(0,
                             10, (size_n, size_k),
                             dtype=torch.uint8,
                             device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
                                        4).to(dev)
    marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
                                     group_size).to(dev)
    marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
                                      group_size).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(b_weight.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input,
                              dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


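# QQQ additionally quantizes the activations to int8 with per-token scales;
# the reference matmul uses the dequantized activations (q_a * s_a) against
# the reference weights.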
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Quantize activations
    s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
        torch.float)
    q_a = (a_input / s_a).round().clamp(int8_traits.min,
                                        int8_traits.max).to(torch.int8)

    # Quantize weights
    w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \
        marlin_qqq_quantize(b_weight, num_bits, group_size)

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    opcheck(torch.ops._C.marlin_qqq_gemm,
            (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
             marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1]))

    output = ops.marlin_qqq_gemm(
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace.scratch,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


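# Feeds the kernel a strided view sliced out of a larger tensor to make sure
# non-contiguous activation inputs are handled correctly.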
def test_marlin_gemm_subset_input():
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


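# Checks the optional bias input: the Marlin GEMM output with a permuted
# bias should match matmul(a, w_ref) + bias.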
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    b_bias = rand_data((size_n, )) * 10

    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


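# Calls the legacy torch.ops._C.marlin_gemm entry point twice with identical
# inputs to check that it is deterministic, then runs opcheck on it.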
def test_marlin_gemm_opcheck():
    size_m = 2048
    size_n = 4096
    size_k = 4096
    a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16)
    w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
    s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16)
    wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                         GPTQ_MARLIN_MAX_PARALLEL).scratch
    x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    torch.testing.assert_close(x, y)
    opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))