Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2026-04-08 04:37:02 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
495 lines · 17 KiB · Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
# Fused experts and PrepareFinalize imports
|
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
|
BatchedDeepGemmExperts)
|
|
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
|
BatchedTritonOrDeepGemmExperts)
|
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
|
FusedMoEQuantConfig)
|
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|
BatchedTritonExperts, NaiveBatchedExperts)
|
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
|
|
TritonExperts)
|
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
MoEPrepareAndFinalizeNoEP)
|
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
|
TritonOrDeepGemmExperts)
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
cutlass_fp4_supported)
|
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
cutlass_fp8_supported)
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
|
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
|
|
|
|
|
@dataclass
|
|
class TestMoEQuantConfig:
|
|
quant_dtype: Union[torch.dtype, str, None]
|
|
per_out_ch_quant: bool
|
|
per_act_token_quant: bool
|
|
block_shape: Optional[list[int]]
|
|
|
|
|
|
@dataclass
class PrepareFinalizeInfo:
    """Capabilities recorded for a registered prepare/finalize implementation."""

    # Activation layout (standard or batched) the implementation works with.
    activation_format: mk.FusedMoEActivationFormat
    # Quantization dtypes (or format strings such as "nvfp4") it accepts.
    supported_dtypes: list[Union[torch.dtype, str]]
    # True if block-quantized inputs are supported.
    blocked_quantization_support: bool
    # Communication backend name (e.g. "pplx"), or None when none is used.
    backend: Union[str, None]
    # False for implementations that cannot apply router weights on input.
    supports_apply_weight_on_input: bool = True
|
|
|
|
|
|
@dataclass
class ExpertInfo:
    """Capabilities recorded for a registered fused-experts implementation."""

    # Activation layout (standard or batched) the experts operate on.
    activation_format: mk.FusedMoEActivationFormat
    # Quantization dtypes (or format strings such as "nvfp4") accepted.
    supported_dtypes: list[Union[torch.dtype, str]]
    # True if block-quantized weights/activations are supported.
    blocked_quantization_support: bool
    # Whether the implementation supports chunked execution.
    supports_chunking: bool
    # Whether an expert map may be passed to the implementation.
    supports_expert_map: bool
    # NOTE(review): presumably requires the activation quantization to match
    # the weight quantization — confirm against the test harness.
    needs_matching_quant: bool = False
    # True for implementations that depend on DeepGEMM.
    needs_deep_gemm: bool = False
|
|
|
|
|
|
# Registries filled in by register_prepare_and_finalize / register_experts.
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
                            PrepareFinalizeInfo] = dict()
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = dict()
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = list()
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[
    mk.FusedMoEPrepareAndFinalize] = list()
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[
    mk.FusedMoEPrepareAndFinalize] = list()
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = list()

# Shorthands for the two activation layouts.
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts

# dtype groups used when declaring what each implementation supports.
common_float_types: list[Union[torch.dtype, str]] = [
    torch.float8_e4m3fn,
    torch.bfloat16,
    torch.float16,
    torch.float32,
]
common_float_and_int_types = [*common_float_types, torch.int8]
nvfp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
|
|
|
|
|
|
def register_prepare_and_finalize(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[Union[torch.dtype, str]],
    blocked_quantization_support: bool,
    backend: Optional[str],
    force_multigpu: bool = False,
    supports_apply_weight_on_input: bool = True,
):
    """Register ``kind`` as a prepare/finalize implementation under test.

    Its capabilities are stored in ``PREPARE_FINALIZE_INFO`` and it is
    appended to ``MK_ALL_PREPARE_FINALIZE_TYPES`` plus exactly one of
    ``MK_MULTI_GPU_PREPARE_FINALIZE_TYPES`` (when ``backend`` is given or
    ``force_multigpu`` is set) or ``MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES``.

    Registering the same ``kind`` twice trips an assertion.
    """
    # The registries are only mutated in place (never rebound), so no
    # ``global`` declarations are needed.
    assert kind not in PREPARE_FINALIZE_INFO

    PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
        activation_format=activation_format,
        supported_dtypes=supported_dtypes,
        blocked_quantization_support=blocked_quantization_support,
        backend=backend,
        supports_apply_weight_on_input=supports_apply_weight_on_input,
    )
    MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)

    is_multigpu = backend is not None or force_multigpu
    target = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
              if is_multigpu else MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
    target.append(kind)
|
|
|
|
|
|
def register_experts(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[Union[torch.dtype, str]],
    blocked_quantization_support: bool,
    supports_chunking: bool,
    supports_expert_map: bool,
    needs_matching_quant: bool = False,
    needs_deep_gemm: bool = False,
):
    """Register ``kind`` as a fused-experts implementation under test.

    Its capabilities are stored in ``EXPERT_INFO`` and it is appended to
    ``MK_FUSED_EXPERT_TYPES``. Registering the same ``kind`` twice trips an
    assertion.
    """
    # Registries are mutated in place only, so ``global`` is unnecessary.
    assert kind not in EXPERT_INFO

    info = ExpertInfo(
        activation_format=activation_format,
        supported_dtypes=supported_dtypes,
        blocked_quantization_support=blocked_quantization_support,
        supports_chunking=supports_chunking,
        supports_expert_map=supports_expert_map,
        needs_matching_quant=needs_matching_quant,
        needs_deep_gemm=needs_deep_gemm,
    )
    EXPERT_INFO[kind] = info
    MK_FUSED_EXPERT_TYPES.append(kind)
|
|
|
|
|
|
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
    """Look up the ``PrepareFinalizeInfo`` registered for ``kind``.

    Asserts that ``kind`` was previously passed to
    ``register_prepare_and_finalize``.
    """
    registered = PREPARE_FINALIZE_INFO.get(kind, None)
    assert registered is not None
    return registered
|
|
|
|
|
|
def expert_info(kind) -> ExpertInfo:
    """Look up the ``ExpertInfo`` registered for ``kind``.

    Asserts that ``kind`` was previously passed to ``register_experts``.
    """
    registered = EXPERT_INFO.get(kind, None)
    assert registered is not None
    return registered
|
|
|
|
|
|
# Implementations registered unconditionally: the no-EP prepare/finalize and
# the Triton / naive batched experts.
register_prepare_and_finalize(
    MoEPrepareAndFinalizeNoEP,
    activation_format=standard_format,
    supported_dtypes=common_float_types,
    blocked_quantization_support=True,
    backend=None,
)

register_experts(
    BatchedTritonExperts,
    activation_format=batched_format,
    supported_dtypes=common_float_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=False,
    needs_matching_quant=True,
)

register_experts(
    TritonExperts,
    activation_format=standard_format,
    supported_dtypes=common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=True,
    supports_expert_map=True,
    needs_matching_quant=True,
)

register_experts(
    NaiveBatchedExperts,
    activation_format=batched_format,
    supported_dtypes=common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=True,
)
|
|
|
|
# DeepEP prepare/finalize variants; disabled on Blackwell
# (device capability 100) for now.
if has_deep_ep() and not current_platform.has_device_capability(100):
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

    # High-throughput variant: standard activation layout.
    register_prepare_and_finalize(
        DeepEPHTPrepareAndFinalize,
        activation_format=standard_format,
        supported_dtypes=common_float_types,
        blocked_quantization_support=True,
        backend="deepep_high_throughput",
    )

    # Low-latency variant: batched activation layout.
    register_prepare_and_finalize(
        DeepEPLLPrepareAndFinalize,
        activation_format=batched_format,
        supported_dtypes=common_float_types,
        blocked_quantization_support=True,
        backend="deepep_low_latency",
    )
|
|
|
|
# pplx prepare/finalize, only when the pplx package is available.
if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)

    register_prepare_and_finalize(
        PplxPrepareAndFinalize,
        activation_format=batched_format,
        supported_dtypes=common_float_and_int_types,
        blocked_quantization_support=True,
        backend="pplx",
    )
|
|
|
|
# FlashInfer CUTLASS MoE: requires the FlashInfer kernels and device
# capability 100.
if (has_flashinfer_cutlass_fused_moe()
        and current_platform.has_device_capability(100)):
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
        FlashInferExperts)
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        FlashInferCutlassMoEPrepareAndFinalize,
        create_flashinfer_prepare_finalize)

    register_prepare_and_finalize(
        FlashInferCutlassMoEPrepareAndFinalize,
        standard_format,
        nvfp4_types,
        blocked_quantization_support=True,
        backend=None,
        # No backend name, but still treated as a multi-GPU implementation.
        force_multigpu=True,
        supports_apply_weight_on_input=False,
    )

    register_experts(
        FlashInferExperts,
        standard_format,
        nvfp4_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        # Note: this is a hack to get it to run for now
        supports_expert_map=True,
    )
else:
    # Bind placeholders so the type comparisons in make_prepare_finalize and
    # make_fused_experts do not raise NameError when FlashInfer is
    # unavailable.
    FlashInferCutlassMoEPrepareAndFinalize = None
    FlashInferExperts = None
|
|
|
|
# DeepGEMM-backed experts are registered only when DeepGEMM is installed and
# is_deep_gemm_supported() reports support on this platform.
if has_deep_gemm() and is_deep_gemm_supported():
    register_experts(
        BatchedDeepGemmExperts,
        batched_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=False,
        supports_expert_map=False,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    )
    # (A stray trailing comma used to turn this call into a discarded tuple.)
    register_experts(
        DeepGemmExperts,
        standard_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    )
    register_experts(
        BatchedTritonOrDeepGemmExperts,
        batched_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        supports_chunking=False,
        supports_expert_map=False,
        needs_matching_quant=True,
        needs_deep_gemm=True,
    )
    register_experts(
        TritonOrDeepGemmExperts,
        standard_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=True,
        needs_deep_gemm=True,
    )
|
|
|
|
# CUTLASS FP8 experts (standard and batched), when CUTLASS FP8 is supported.
if cutlass_fp8_supported():
    from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
                                                      CutlassExpertsFp8)
    register_experts(
        CutlassExpertsFp8,
        standard_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=True,
        supports_expert_map=False,
    )
    register_experts(
        CutlassBatchedExpertsFp8,
        batched_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=False,
        supports_expert_map=False,
    )
else:
    # Placeholders so the type comparisons in make_fused_experts do not
    # raise NameError on platforms without CUTLASS FP8 support.
    CutlassExpertsFp8 = None
    CutlassBatchedExpertsFp8 = None
|
|
|
|
# CUTLASS NVFP4 experts, when CUTLASS FP4 is supported.
if cutlass_fp4_supported():
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        CutlassExpertsFp4)
    register_experts(
        CutlassExpertsFp4,
        standard_format,
        nvfp4_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=False,
    )
else:
    # Placeholder so the type comparisons in make_fused_experts do not raise
    # NameError on platforms without CUTLASS FP4 support.
    CutlassExpertsFp4 = None
|
|
|
|
# Quantization schemes exercised by the modular-kernel tests. The leading
# ``None`` entry carries no TestMoEQuantConfig (presumably an unquantized
# run — confirm with the test harness).
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [None]

# FP8 variants. Positional fields are
# (quant_dtype, per_out_ch_quant, per_act_token_quant, block_shape).
MK_QUANT_CONFIGS.extend([
    # per-channel / per-column weights and per-tensor activations
    TestMoEQuantConfig(torch.float8_e4m3fn, True, False, None),
    # per-channel / per-column weights and per-token activations
    TestMoEQuantConfig(torch.float8_e4m3fn, True, True, None),
    # per-tensor weights and per-tensor activations
    TestMoEQuantConfig(torch.float8_e4m3fn, False, False, None),
    # per-tensor weights and per-token activations
    TestMoEQuantConfig(torch.float8_e4m3fn, False, True, None),
    # block-quantized weights and 128 block per-token activations
    TestMoEQuantConfig(torch.float8_e4m3fn, False, False, [128, 128]),
    # TODO (varun) : Should we test the following combinations ?
    # block-quantized weights and per-token activations
    # block-quantized weights and per-tensor activations
])

# NVFP4 is only tested when an NVFP4-capable backend may be present.
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
    MK_QUANT_CONFIGS.append(TestMoEQuantConfig("nvfp4", False, False, None))
|
|
|
|
|
|
def make_prepare_finalize(
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    backend: Optional[str],
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
    """Build the prepare/finalize object to use for a test case.

    - A named backend other than ``"naive"`` is delegated to
      ``FusedMoEMethodBase._maybe_make_prepare_finalize``, whose result must
      not be None.
    - The FlashInfer CUTLASS type is built via
      ``create_flashinfer_prepare_finalize``, enabling DP when
      ``dp_size > 1``.
    - Anything else gets a plain ``MoEPrepareAndFinalizeNoEP``.
    """
    has_named_backend = backend is not None and backend != "naive"
    if has_named_backend:
        made = FusedMoEMethodBase._maybe_make_prepare_finalize(
            moe, quant_config)
        assert made is not None
        return made

    if prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
        dp_enabled = moe.moe_parallel_config.dp_size > 1
        return create_flashinfer_prepare_finalize(use_dp=dp_enabled)

    return MoEPrepareAndFinalizeNoEP()
|
|
|
|
|
|
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
|
|
s = rank * num_local_experts
|
|
e = s + num_local_experts
|
|
return t[s:e]
|
|
|
|
|
|
def make_cutlass_strides(
    e: int,
    n: int,
    k: int,
    device: Union[str, torch.device] = "cuda",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the per-expert int64 stride tensors for the CUTLASS FP8 experts.

    Args:
        e: number of experts; every returned tensor has length ``e``.
        n: fill value for ``ab_strides2``; ``c_strides1`` uses ``2 * n``.
        k: fill value for ``ab_strides1`` and ``c_strides2``.
        device: where to allocate the tensors. Defaults to ``"cuda"``
            (the previous hard-coded behavior); pass e.g. ``"cpu"`` to build
            them without a GPU.

    Returns:
        ``(ab_strides1, ab_strides2, c_strides1, c_strides2)``, each a 1-D
        ``torch.int64`` tensor of length ``e`` filled with a single value.
    """

    def _filled(value: int) -> torch.Tensor:
        # One int64 entry per expert, all equal to ``value``.
        return torch.full((e, ), value, device=device, dtype=torch.int64)

    ab_strides1 = _filled(k)
    ab_strides2 = _filled(n)
    # NOTE(review): presumably 2 * n because the first GEMM's output is twice
    # as wide as the second GEMM's input — confirm against the kernel.
    c_strides1 = _filled(2 * n)
    c_strides2 = _filled(k)
    return ab_strides1, ab_strides2, c_strides1, c_strides2
|
|
|
|
|
|
def make_fused_experts(
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    num_dispatchers: int,
    N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Instantiate a fused-experts object of ``fused_experts_type``.

    Args:
        fused_experts_type: the experts class to build; must be one of the
            types handled below.
        moe: MoE config supplying ``max_num_tokens``, expert counts,
            ``hidden_dim``, ``in_dtype`` and EP/TP ranks and sizes.
        quant_config: quantization config forwarded to every experts type.
        num_dispatchers: forwarded to the batched / CUTLASS experts.
        N: forwarded to ``make_cutlass_strides`` for the CUTLASS FP8 experts.

    Returns:
        The constructed experts object.

    Raises:
        RuntimeError: if ``fused_experts_type`` is not a handled type.
    """
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "quant_config": quant_config,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    # Keep the kwargs dumps printed below on one line and summarize any
    # tensors they contain (e.g. the CUTLASS stride tensors).
    torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
    try:
        if fused_experts_type == BatchedDeepGemmExperts:
            kwargs = batch_kwargs | quant_kwargs
            print(f"Making BatchedDeepGemmExperts {kwargs} ...")
            experts = BatchedDeepGemmExperts(**kwargs)
        elif fused_experts_type == BatchedTritonExperts:
            kwargs = batch_kwargs | quant_kwargs
            print(f"Making BatchedTritonExperts {kwargs} ...")
            experts = BatchedTritonExperts(**kwargs)
        elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
            kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
            print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
            experts = BatchedTritonOrDeepGemmExperts(**kwargs)
        elif fused_experts_type == DeepGemmExperts:
            # Was missing the ``f`` prefix and printed "{quant_config}".
            print(f"Making DeepGemmExperts {quant_config} ...")
            experts = DeepGemmExperts(quant_config)
        elif fused_experts_type == TritonExperts:
            kwargs = quant_kwargs
            print(f"Making TritonExperts {kwargs} ...")
            experts = TritonExperts(**kwargs)
        elif fused_experts_type == TritonOrDeepGemmExperts:
            kwargs = quant_kwargs | deepgemm_kwargs
            print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
            experts = TritonOrDeepGemmExperts(**kwargs)
        elif fused_experts_type == NaiveBatchedExperts:
            kwargs = batch_kwargs | quant_kwargs
            print(f"Making NaiveBatchedExperts {kwargs} ...")
            experts = NaiveBatchedExperts(**kwargs)
        elif fused_experts_type == CutlassExpertsFp8:
            strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
            kwargs = {
                "out_dtype": moe.in_dtype,
                "ab_strides1": strides[0],
                "ab_strides2": strides[1],
                "c_strides1": strides[2],
                "c_strides2": strides[3],
            } | quant_kwargs
            print(f"Making CutlassExpertsFp8 {kwargs} ...")
            experts = CutlassExpertsFp8(**kwargs)
        elif fused_experts_type == CutlassBatchedExpertsFp8:
            strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
            kwargs = {
                "max_experts_per_worker": moe.num_local_experts,
                "num_dispatchers": num_dispatchers,
                "out_dtype": moe.in_dtype,
                "ab_strides1": strides[0],
                "ab_strides2": strides[1],
                "c_strides1": strides[2],
                "c_strides2": strides[3],
            } | quant_kwargs
            print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
            experts = CutlassBatchedExpertsFp8(**kwargs)
        elif fused_experts_type == CutlassExpertsFp4:
            kwargs = {
                "max_experts_per_worker": moe.num_local_experts,
                "num_dispatchers": num_dispatchers,
                "out_dtype": moe.in_dtype,
            } | quant_kwargs
            print(f"Making CutlassExpertsFp4 {kwargs} ...")
            experts = CutlassExpertsFp4(**kwargs)
        elif fused_experts_type == FlashInferExperts:
            kwargs = {
                "out_dtype": moe.in_dtype,
                "ep_rank": moe.ep_rank,
                "ep_size": moe.ep_size,
                "tp_rank": moe.tp_rank,
                "tp_size": moe.tp_size,
            } | quant_kwargs
            print(f"Making FlashInferExperts {kwargs} ...")
            experts = FlashInferExperts(**kwargs)
        else:
            raise RuntimeError(
                f"Unknown fused experts type: {fused_experts_type}")
    finally:
        # Reset print options even when construction raises, so a failing
        # case does not leave the compact settings active for later output.
        # NOTE(review): edgeitems=5 differs from torch's default of 3 —
        # confirm this is intended.
        torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)

    return experts
|