mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 21:17:06 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
871 lines
34 KiB
Python
871 lines
34 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import importlib.metadata
|
|
from dataclasses import dataclass
|
|
from importlib.util import find_spec
|
|
from typing import Optional
|
|
|
|
import pytest
|
|
import torch
|
|
from packaging import version
|
|
|
|
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
|
QuarkLinearMethod, QuarkW4A4MXFP4)
|
|
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
|
QuarkW4A4MXFp4MoEMethod)
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.flashinfer import has_flashinfer
|
|
|
|
# True when AMD Quark is installed at version >= 0.8.99 (the 0.9 line) --
# the version required by the Quark MXFP4 tests below.
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')

# trtllm-gen MXFP4 kernels need an NVIDIA GPU with compute capability 10.0
# (SM100); see the skipif on the fused-moe test.
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)

# SM90 (Hopper) bf16 path additionally requires FlashInfer to be installed.
HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
                               and current_platform.is_device_capability(90)
                               and has_flashinfer())

# FlashInfer symbols are only imported when the trtllm-gen path is usable;
# tests that reference them are skipped otherwise, so the names may be
# undefined on other platforms.
if TRTLLM_GEN_MXFP4_AVAILABLE:
    from flashinfer import (fp4_quantize, mxfp8_quantize,
                            next_positive_power_of_2,
                            reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
    from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
    from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
|
|
|
|
|
|
@dataclass
class ModelCase:
    """A model checkpoint plus the tensor-parallel degree to run it with."""
    # Hugging Face Hub id of an MXFP4-quantized checkpoint.
    model_id: str
    # tensor_parallel_size; also the minimum number of GPUs the case needs.
    tp: int
|
|
|
|
|
|
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function."""
    # apply_model ships a callable to the workers, which needs pickle-based
    # serialization; scope="function" restores the env var after each test.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
|
|
|
|
|
@pytest.mark.parametrize('model_case', [
    ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
    ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
    ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
                    reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
    """Load a Quark MXFP4 checkpoint with dummy weights, check the quant
    method wiring (on one model), and run a short greedy generation."""
    if torch.cuda.device_count() < model_case.tp:
        pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
                    f"{torch.cuda.device_count()}")

    with vllm_runner(model_case.model_id,
                     tensor_parallel_size=model_case.tp,
                     load_format="dummy") as llm:

        def _assert_quark_wiring(model):
            # First decoder layer is representative of the whole stack.
            first_layer = model.model.layers[0]
            attn_qkv = first_layer.self_attn.qkv_proj
            # Linear layers must use Quark's linear method with the
            # W4A4 MXFP4 scheme, and the MoE experts the matching method.
            assert isinstance(attn_qkv.quant_method, QuarkLinearMethod)
            assert isinstance(attn_qkv.scheme, QuarkW4A4MXFP4)
            assert isinstance(first_layer.mlp.experts.quant_method,
                              QuarkW4A4MXFp4MoEMethod)

        # Only this model's structure matches the attribute paths above.
        if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
            llm.apply_model(_assert_quark_wiring)

        output = llm.generate_greedy("Today I am in the French Alps and",
                                     max_tokens=20)
        assert output
|
|
|
|
|
|
def swiglu(x,
           alpha: float = 1.702,
           beta: float = 1.0,
           limit: Optional[float] = None):
    """SwiGLU with an extra bias of `beta` on the linear half.

    Splits the last dimension of `x` into (gate, linear) halves, applies a
    sigmoid gate scaled by `alpha`, and optionally clamps both halves to
    `limit` (gate from above only, linear symmetrically).
    """
    gate, linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        gate = gate.clamp(max=limit)
        linear = linear.clamp(min=-limit, max=limit)
    gated = gate * torch.sigmoid(alpha * gate)
    # Note the extra bias of `beta` added to the linear branch.
    return gated * (linear + beta)
|
|
|
|
|
|
# FP4 (E2M1) decode table: entry i is the real value of the 4-bit code i.
# Codes 0-7 are the positive magnitudes 0..6; codes 8-15 are the same
# magnitudes negated (8 is -0).
fp4_lookup_table = [
    0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
]
|
|
|
|
|
|
def mxfp4_dequantize(x, scale):
    """Dequantize MXFP4-packed values to float32.

    Args:
        x: uint8 tensor with two FP4 (E2M1) codes packed per byte, low
            nibble first.
        scale: one byte per 32-element block, interpreted as the raw
            float32 exponent (UE8M0), i.e. a scale of 2**(byte - 127).

    Returns:
        float32 tensor whose last dimension is `x.shape[-1] * 2`.
    """
    assert x.dtype == torch.uint8
    x = x.view(torch.uint8).to(torch.int32)
    # Unpack nibbles: even output positions take the low nibble, odd the
    # high nibble.
    x_unpacked = torch.zeros(*x.shape[:-1],
                             x.shape[-1] * 2,
                             dtype=torch.int32,
                             device=x.device)
    x_unpacked[..., 0::2].copy_(x & 0xF)
    x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)

    # Decode the E2M1 codes with a single table gather instead of 16 masked
    # scans over the whole tensor (one per code value), which was O(16 * N).
    lut = torch.tensor(
        [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6],
        dtype=torch.float32,
        device=x.device)
    x_float = lut[x_unpacked.long()]

    # UE8M0 scale: shifting the byte into the float32 exponent field yields
    # the value 2**(byte - 127).
    scale = scale.view(torch.uint8).to(torch.int32)
    scale = (scale << 23).view(torch.float32)
    scale = scale.reshape(*x.shape[:-1], -1)
    # Broadcast each block scale across its 32 consecutive elements.
    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)

    return x_float * scale
|
|
|
|
|
|
def mxfp8_dequantize(x, scale):
|
|
assert x.dtype == torch.float8_e4m3fn
|
|
x_float = x.to(torch.float32)
|
|
|
|
scale = scale.view(torch.uint8).to(torch.int32)
|
|
scale = (scale << 23).view(torch.float32)
|
|
scale = scale.reshape(*x.shape[:-1], -1)
|
|
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
|
|
|
|
return x_float * scale
|
|
|
|
|
|
def reference_moe(
    roouting_logits,
    topk,
    num_experts,
    hidden_states,
    w13,
    bias13,
    w2,
    bias2,
    alpha,
    beta,
    limit,
    act_type,
):
    """Eager-mode reference MoE forward pass used to validate the fused
    kernel: top-k routing with renormalized weights, gated MLP (swiglu),
    optional MXFP8 activation round-trip, second MLP, weighted combine.
    Returns a bfloat16 tensor shaped like `hidden_states`."""
    # Top-k routing; softmax over the selected logits renormalizes the
    # routing weights.
    routing = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
    gate_weights = torch.nn.functional.softmax(routing.values, dim=1)
    chosen = routing.indices
    acts = hidden_states.clone()
    # MLP #1: gather each token's expert weights, matmul, add bias, swiglu.
    w_in = w13[chosen, ...]
    b_in = bias13[chosen, ...]
    acts = torch.einsum("beck,bk->bec", w_in, acts) + b_in
    acts = swiglu(acts, alpha=alpha, beta=beta, limit=limit)

    if act_type == 'mxfp8':
        # Mimic the quantized activation path via an MXFP8 round-trip.
        quantized, q_scale = mxfp8_quantize(acts.to(torch.bfloat16),
                                            is_sf_swizzled_layout=False)
        acts = mxfp8_dequantize(quantized, q_scale)
    # MLP #2.
    w_out = w2[chosen, ...]
    b_out = bias2[chosen, ...]
    acts = torch.einsum("beck,bek->bec", w_out, acts) + b_out
    # Weighted sum of the per-expert outputs.
    acts = torch.einsum("bec,be->bc", acts, gate_weights)
    assert acts.shape == hidden_states.shape
    return acts.to(torch.bfloat16)
|
|
|
|
|
|
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
    """Choose the tokens-per-CTA-tile size for the trtllm-gen MoE kernel.

    Estimates the per-expert token count from the batch in `x`, inflates it
    for routing imbalance, and rounds to a kernel-supported power of two.
    """
    num_tokens = x.shape[0]
    # Tokens each expert would receive under a perfectly uniform routing.
    expected_per_expert = (num_tokens * top_k) // num_experts
    # Inflate by 1.3x to tolerate imbalance: the factor approximates
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
    # (1.0 = perfect distribution; > 1.0 = some experts are hotter;
    # < 1.0 does not make sense).
    padded = int(expected_per_expert * 1.3)
    # Round up to the next power of two, then clamp to the 8..64 range the
    # kernel supports.
    tile = next_positive_power_of_2(padded)
    return min(max(tile, 8), 64)
|
|
|
|
|
|
def tg_mxfp4_moe(
    router_logits,
    topk,
    num_experts,
    intermediate_size,
    hidden_size,
    hidden_states,
    hidden_states_scale,
    w13_weight,
    w13_weight_scale,
    w13_bias,
    w2_weight,
    w2_weight_scale,
    w2_bias,
    act_type,
    alpha,
    beta,
    limit,
    transpose_optimized: bool = False,
) -> torch.Tensor:
    """Prepare MXFP4 weights/scales for trtllm-gen and run the fused MoE.

    Applies the layout transformations the kernel expects -- gate/up half
    swap, row interleaving for the gated-activation GEMM, and epilogue
    shuffling -- then calls flashinfer's trtllm_fp4_block_scale_moe with
    renormalize routing and returns its finalized output.

    NOTE(review): w13_weight, w13_weight_scale and w13_bias are mutated in
    place (their halves are swapped), so callers must not reuse the tensors
    they pass in. `act_type` is accepted but not read in this function.
    `transpose_optimized` selects the cached-permute-indices shuffling path
    instead of the shuffle_matrix_* helpers; both should produce the same
    layout -- TODO confirm against flashinfer.
    """
    sf_block_size = 32
    # Shape checks: FP4 packs two values per byte (hence // 2 on the packed
    # dim) and scales carry one byte per sf_block_size(=32) elements.
    assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
            and w13_weight.shape[1] == intermediate_size * 2
            and w13_weight.shape[2] == hidden_size // 2)
    assert (w13_weight_scale.dim() == 3
            and w13_weight_scale.shape[0] == num_experts
            and w13_weight_scale.shape[1] == intermediate_size * 2
            and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
    assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
            and w2_weight.shape[1] == hidden_size
            and w2_weight.shape[2] == intermediate_size // 2)
    assert (w2_weight_scale.dim() == 3
            and w2_weight_scale.shape[1] == hidden_size
            and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
    assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
            and w13_bias.shape[1] == intermediate_size * 2)
    assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
            and w2_bias.shape[1] == hidden_size)

    # Swap w1 and w3 as the definition of
    # swiglu is different in the trtllm-gen
    # (in-place: the clones below hold the originals during the swap).
    w13_weight_scale_ = w13_weight_scale.clone()
    w13_weight_ = w13_weight.clone()
    w13_bias_ = w13_bias.clone()
    w13_weight[:, :intermediate_size, :].copy_(
        w13_weight_[:, intermediate_size:, :])
    w13_weight[:, intermediate_size:, :].copy_(
        w13_weight_[:, :intermediate_size, :])
    w13_weight_scale[:, :intermediate_size, :].copy_(
        w13_weight_scale_[:, intermediate_size:, :])
    w13_weight_scale[:, intermediate_size:, :].copy_(
        w13_weight_scale_[:, :intermediate_size, :])
    w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
    w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])

    # Interleave the weights and scaling factors for activation
    w13_weight_interleaved = []
    w13_weight_scale_interleaved = []
    w13_bias_interleaved = []
    for i in range(num_experts):
        w13_weight_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
        w13_weight_scale_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
        # Bias rows are interleaved as a column vector.
        w13_bias_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
                                                                        1)))
    w13_weight = torch.stack(w13_weight_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2)
    w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 32)
    w13_bias = torch.stack(w13_bias_interleaved).reshape(
        num_experts, 2 * intermediate_size)

    # Shuffle weights and scaling factors for transposed mma output
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    gemm1_bias_shuffled = []
    gemm2_bias_shuffled = []
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
    # Cache keyed by tensor shape so identically-shaped tensors reuse the
    # same permute-index computation across experts.
    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    if transpose_optimized:
        # Cached-permutation path: look up (or build) row permute indices
        # per tensor shape and apply them by advanced indexing.
        for i in range(num_experts):
            # w13 weight shuffling
            permute_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(w13_weight[i].view(
                torch.uint8)[permute_indices.to(
                    w13_weight.device)].contiguous())
            # w13 scale shuffling
            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(w13_weight_scale[i].view(
                    torch.uint8)[permute_sf_indices.to(
                        w13_weight_scale.device)].contiguous()))
            # w13 bias shuffling
            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
                -1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
            # w2 weight shuffling
            permute_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(w2_weight[i].view(
                torch.uint8)[permute_indices.to(
                    w2_weight.device)].contiguous())
            # w2 scale shuffling
            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(w2_weight_scale[i].view(
                    torch.uint8)[permute_sf_indices.to(
                        w2_weight_scale.device)].contiguous()))
            # w2 bias shuffling
            permute_indices = _maybe_get_cached_w2_permute_indices(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
                -1, 1)[permute_indices.to(w2_bias.device)].contiguous())

    else:
        # Helper path: shuffle_matrix_a for data, shuffle_matrix_sf_a for
        # scale factors.
        for i in range(num_experts):
            gemm1_weights_shuffled.append(
                shuffle_matrix_a(w13_weight[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm1_scales_shuffled.append(
                shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
                                    epilogue_tile_m))

            gemm2_weights_shuffled.append(
                shuffle_matrix_a(w2_weight[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm2_scales_shuffled.append(
                shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
                                    epilogue_tile_m))
            gemm1_bias_shuffled.append(
                shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
            gemm2_bias_shuffled.append(
                shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))

    # Re-stack per-expert results; scales are reinterpreted as
    # float8_e4m3fn, which is the dtype the kernel consumes.
    w13_weight = torch.stack(gemm1_weights_shuffled)
    w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
        num_experts, 2 * intermediate_size,
        hidden_size // sf_block_size).view(torch.float8_e4m3fn)
    w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)

    w2_weight = torch.stack(gemm2_weights_shuffled)
    w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
        num_experts, hidden_size,
        intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
    w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

    tg_result = trtllm_fp4_block_scale_moe(
        routing_logits=router_logits.to(torch.bfloat16),
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale,
        gemm1_bias=w13_bias,
        gemm1_alpha=alpha,
        gemm1_beta=beta,
        gemm1_clamp_limit=limit,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale,
        gemm2_bias=w2_bias,
        output1_scale_scalar=None,
        output1_scale_gate_scalar=None,
        output2_scale_scalar=None,
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
        routing_method_type=1,  # renormalize
        do_finalize=True)[0]
    return tg_result
|
|
|
|
|
|
def check_accuracy(a, b, atol, rtol, percent):
    """Allow a mismatch percentage of 1 - percent.

    An element mismatches when |a - b| > atol + rtol * |b| (the same
    element-wise criterion as torch.allclose, but with a mismatch budget
    instead of requiring every element to pass). NaN or Inf in either
    tensor fails immediately.

    Args:
        a: reference tensor.
        b: actual tensor; must match `a`'s shape.
        atol: absolute tolerance.
        rtol: relative tolerance, scaled by |b|.
        percent: required fraction of matching elements, in [0, 1].

    Raises:
        ValueError: on NaN/Inf, or when the mismatch fraction exceeds
            1 - percent.
    """
    # Raise ValueError (not bare Exception) so callers and pytest.raises
    # can target the failure without catching unrelated errors.
    if torch.any(torch.isnan(a)):
        raise ValueError("NaN in reference output")
    if torch.any(torch.isnan(b)):
        raise ValueError("NaN in actual output")
    if torch.any(torch.isinf(a)):
        raise ValueError("Inf in reference output")
    if torch.any(torch.isinf(b)):
        raise ValueError("Inf in actual output")
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    left = torch.abs(a - b)
    right = atol + rtol * torch.abs(b)
    count = torch.sum(left > right)
    mismatch_percent = count / a.numel()
    if mismatch_percent > 1 - percent:
        raise ValueError(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {1-percent:.4f})")
|
|
|
|
|
|
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                              (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP4_AVAILABLE,
    reason="nvidia gpu and compute capability sm100 is required for this test")
def test_trtllm_gen_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: Optional[float],
    act_type: str,
    transpose_optimized: bool,
):
    """Compare the trtllm-gen MXFP4 fused-MoE kernel (via tg_mxfp4_moe)
    against the chunked reference_moe implementation.

    Weights are quantized to MXFP4 (32-element scale groups); activations
    are either MXFP8-quantized or left in BF16 depending on ``act_type``.
    """
    # Fixed seed so every parametrization draws the same random tensors.
    seed = 42
    torch.manual_seed(seed)
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device="cuda:0",
                                dtype=torch.bfloat16)
    # w13 packs [w1; w3] along dim 1, hence intermediate_size * 2 rows.
    w13 = (torch.randn(num_experts,
                       intermediate_size * 2,
                       hidden_size,
                       device="cuda:0",
                       dtype=torch.bfloat16))
    w2 = (torch.randn(num_experts,
                      hidden_size,
                      intermediate_size,
                      device="cuda:0",
                      dtype=torch.bfloat16))
    bias13 = torch.randn(num_experts, intermediate_size * 2,
                         device="cuda:0") * 10
    bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
    router_logits = torch.rand(num_tokens, num_experts,
                               dtype=torch.float32).cuda()

    # Quantize weights to MXFP4 with UE8M0 scale factors, one scale per
    # 32-element group along the hidden dimension; linear (non-swizzled)
    # scale layout so the scales can be reshaped below.
    w13, w13_scale = fp4_quantize(w13,
                                  torch.tensor(1.0, device="cuda:0"),
                                  32,
                                  sf_use_ue8m0=True,
                                  is_sf_swizzled_layout=False)
    w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, intermediate_size * 2, hidden_size // 32)
    w2, w2_scale = fp4_quantize(w2,
                                torch.tensor(1.0, device="cuda:0"),
                                32,
                                sf_use_ue8m0=True,
                                is_sf_swizzled_layout=False)
    w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 32)
    if act_type == 'mxfp8':
        # Activations also quantized; flattened scale vector for the kernel.
        hidden_states, hidden_states_scale = mxfp8_quantize(
            hidden_states, is_sf_swizzled_layout=False)
        hidden_states_scale = hidden_states_scale.view(
            torch.float8_e4m3fn).reshape(-1)
    else:
        hidden_states_scale = None

    # reference result: dequantize everything back to float and run the
    # straightforward reference MoE.
    ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
    w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
    bias13_ref = bias13
    bias2_ref = bias2
    if act_type == 'mxfp8':
        hidden_states_ref = mxfp8_dequantize(
            hidden_states, hidden_states_scale).to(torch.float32)
    else:
        hidden_states_ref = hidden_states.to(torch.float32)
    # Process tokens in chunks of 32 to reduce memory usage
    chunk_size = 32
    num_chunks = (num_tokens + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, num_tokens)
        chunk_result = reference_moe(
            router_logits[start_idx:end_idx].to(torch.float32),
            topk,
            num_experts,
            hidden_states_ref[start_idx:end_idx],
            w13_ref,
            bias13_ref,
            w2_ref,
            bias2_ref,
            alpha,
            beta,
            limit,
            act_type,
        )
        ref_result[start_idx:end_idx].copy_(chunk_result)

    # trtllm-gen result: the kernel takes per-expert alpha/beta/limit
    # tensors, so broadcast the scalar parameters to (num_experts,).
    if alpha is not None:
        alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts, ), limit, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
    tg_result = tg_mxfp4_moe(router_logits,
                             topk,
                             num_experts,
                             intermediate_size,
                             hidden_size,
                             hidden_states,
                             hidden_states_scale,
                             w13,
                             w13_scale,
                             bias13,
                             w2,
                             w2_scale,
                             bias2,
                             act_type,
                             alpha=alpha,
                             beta=beta,
                             limit=limit,
                             transpose_optimized=transpose_optimized)
    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
|
|
|
|
|
|
def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
|
|
"""Interleave scales on the last dimension by groups of 4, matching
|
|
the transformation in mxfp4.py's BF16 (Hopper) path."""
|
|
s = scales.to(torch.uint8)
|
|
s_shape = s.shape
|
|
assert s_shape[-1] % 4 == 0
|
|
s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
|
|
# Move the 4-group dimension before the row dimension
|
|
permuted = s.permute(0, 2, 1, 3)
|
|
# Merge the row dim with the 4-group dim
|
|
return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
|
|
|
|
|
|
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                              (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not HOPPER_MXFP4_BF16_AVAILABLE,
    reason="nvidia gpu sm90 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: Optional[float],
):
    """Compare flashinfer_cutlass_fused_moe (Hopper BF16 x MXFP4 path)
    against reference_moe on randomly generated MXFP4 weights.

    The weights here are drawn directly as random uint8 nibbles plus
    uint8 scale exponents rather than quantized from float tensors.
    """
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=torch.bfloat16)
    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
    # (two FP4 values packed per byte, hence hidden_size // 2 columns).
    w13_q = torch.randint(
        0,
        256, (num_experts, 2 * intermediate_size, hidden_size // 2),
        device=device,
        dtype=torch.uint8)
    # Scales in [118, 123) keep the dequantized magnitudes in a narrow,
    # numerically tame range; one scale per 32-element group.
    w13_scale = torch.randint(
        118,
        123, (num_experts, 2 * intermediate_size, hidden_size // 32),
        device=device,
        dtype=torch.uint8)

    w2_q = torch.randint(0,
                         256,
                         (num_experts, hidden_size, intermediate_size // 2),
                         device=device,
                         dtype=torch.uint8)
    w2_scale = torch.randint(
        118,
        123, (num_experts, hidden_size, intermediate_size // 32),
        device=device,
        dtype=torch.uint8)
    # Bias contiguous [b1; b3]
    bias13 = (torch.randn(num_experts,
                          2 * intermediate_size,
                          device=device,
                          dtype=torch.bfloat16) * 10)
    bias2 = (torch.randn(
        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
    router_logits = torch.rand(num_tokens,
                               num_experts,
                               dtype=torch.float32,
                               device=device)

    # Dequantize the random weights to float for the reference computation.
    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
        num_experts, 2 * intermediate_size, hidden_size)
    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
        num_experts, hidden_size, intermediate_size)
    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
                        hidden_states.to(torch.float32), w13_ref,
                        bias13.to(torch.float32), w2_ref,
                        bias2.to(torch.float32), alpha, beta, limit, 'bf16')

    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    # Biases must be swapped the same way as the weights.
    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    # Scales: swap halves, then interleave the last dim in groups of 4
    # to match the kernel's expected scale layout.
    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
    w13_s = torch.cat([w3_s, w1_s], dim=1)
    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)

    # Build routing inputs: softmax over experts, pick top-k, then
    # renormalize the selected weights to sum to 1 per token.
    routing_weights = torch.nn.functional.softmax(router_logits,
                                                  dim=1,
                                                  dtype=torch.float32)
    token_final_scales, token_selected_experts = torch.topk(routing_weights,
                                                            topk,
                                                            dim=-1)
    token_final_scales = (token_final_scales /
                          token_final_scales.sum(dim=-1, keepdim=True))
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    # Kernel takes per-expert alpha/beta/limit tensors; broadcast scalars.
    if alpha is not None:
        alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts, ), limit, device=hidden_states.device)

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped,
        fc2_expert_weights=w2_q,
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=[w13_s_inter.to(torch.uint8),
                      w2_s_inter.to(torch.uint8)],
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha,
        swiglu_beta=beta,
        swiglu_limit=limit,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_w4_group_scaling=True,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|
|
|
|
|
|
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                              (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not (current_platform.is_cuda()
         and current_platform.is_device_capability(100) and has_flashinfer()),
    reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: Optional[float],
    beta: Optional[float],
    limit: Optional[float],
):
    """Compare flashinfer_cutlass_fused_moe (SM100 MXFP8 activation x
    MXFP4 weight path) against reference_moe.

    Weights are quantized per expert with flashinfer's mxfp4_quantize;
    activations are quantized to MXFP8 before being fed to the kernel.
    """
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=torch.bfloat16)
    # Float weights in w13 format [w1; w3]; divided by 10 to keep values
    # small before quantization.
    w13 = (torch.randn(num_experts,
                       2 * intermediate_size,
                       hidden_size,
                       device=device,
                       dtype=torch.bfloat16) / 10)
    w2 = (torch.randn(num_experts,
                      hidden_size,
                      intermediate_size,
                      device=device,
                      dtype=torch.bfloat16) / 10)
    # Bias contiguous [b1; b3]
    bias13 = (torch.randn(num_experts,
                          2 * intermediate_size,
                          device=device,
                          dtype=torch.bfloat16) * 10)
    bias2 = (torch.randn(
        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
    router_logits = torch.rand(num_tokens,
                               num_experts,
                               dtype=torch.float32,
                               device=device)

    # Quantize weights to MXFP4 per expert (SM100 path)
    from flashinfer import mxfp4_quantize

    def quant_mxfp4_batches(a: torch.Tensor, e: int):
        # mxfp4_quantize works on a single 2-D matrix, so quantize each
        # expert separately and stack the results.
        qs, sfs = [], []
        for i in range(e):
            q, sf = mxfp4_quantize(a[i].cuda())
            qs.append(q)
            sfs.append(sf)
        return torch.stack(qs), torch.stack(sfs)

    def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
                              scale_tensor: torch.Tensor):
        # Per-expert dequantization mirror of quant_mxfp4_batches.
        num_batches = mat_fp4.size(0)
        scale_tensor = scale_tensor.view(num_batches, -1)
        from flashinfer import mxfp4_dequantize
        return torch.stack([
            mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
            for b in range(num_batches)
        ])

    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)

    # Reference result using dequantized tensors and reference_moe
    w13_ref = dequant_mxfp4_batches(
        w13_q.view(torch.uint8),
        w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
            num_experts, 2 * intermediate_size, hidden_size).to(device)
    w2_ref = dequant_mxfp4_batches(
        w2_q.view(torch.uint8),
        w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
            num_experts, hidden_size, intermediate_size).to(device)

    # Quantize activations for SM100 path and dequantize for reference
    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
                        hidden_states.to(torch.float32), w13_ref,
                        bias13.to(torch.float32), w2_ref,
                        bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')

    # Prepare inputs for FlashInfer CUTLASS fused MoE
    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    # Swap scales halves to match swapped weights
    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
    w13_scale_swapped = torch.cat([s3, s1], dim=1)

    # Biases must be swapped the same way as the weights.
    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    # Build routing for kernel
    routing_weights = torch.nn.functional.softmax(router_logits,
                                                  dim=1,
                                                  dtype=torch.float32)
    token_final_scales, token_selected_experts = torch.topk(routing_weights,
                                                            topk,
                                                            dim=-1)
    # Renormalize the selected top-k weights to sum to 1 per token.
    token_final_scales = (token_final_scales /
                          token_final_scales.sum(dim=-1, keepdim=True))
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    # Kernel takes per-expert alpha/beta/limit tensors; broadcast the
    # scalar parameters (keeping the originals intact, unlike the other
    # tests, since `alpha`/`beta`/`limit` were already consumed above).
    if alpha is not None:
        alpha_t = torch.full((num_experts, ),
                             alpha,
                             device=hidden_states.device)
    else:
        alpha_t = None
    if beta is not None:
        beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
    else:
        beta_t = None
    if limit is not None:
        limit_t = torch.full((num_experts, ),
                             limit,
                             device=hidden_states.device)
    else:
        limit_t = None

    # Quant scales for SM100 MXFP8+MXFP4 path
    # NOTE(review): the input-scale entries appear to be placeholders
    # (all-ones) — presumably the kernel requires the slots but the MXFP8
    # path carries real scales via input_sf; confirm against the kernel API.
    fake_input_scale = torch.ones(num_experts, device=device)
    quant_scales = [
        w13_scale_swapped.view(torch.int32),
        fake_input_scale,
        w2_scale.view(torch.int32),
        fake_input_scale,
    ]

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states_q,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
        fc2_expert_weights=w2_q.contiguous().view(torch.long),
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=quant_scales,
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha_t,
        swiglu_beta=beta_t,
        swiglu_limit=limit_t,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_mxfp8_act_scaling=True,
        input_sf=hidden_states_sf,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|