mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-16 23:57:09 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
320 lines
9.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from tests.kernels.moe.utils import (batched_moe,
|
|
make_quantized_test_activations,
|
|
make_test_weights, naive_batched_moe)
|
|
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
|
|
from tests.kernels.utils import torch_experts
|
|
from vllm.config import VllmConfig, set_current_vllm_config
|
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|
invoke_moe_batched_triton_kernel)
|
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
|
from vllm.platforms import current_platform
|
|
from vllm.triton_utils import tl
|
|
|
|
# (m, n, k) problem sizes for test_fused_moe_batched_experts: ``m`` tokens of
# hidden size ``k``; ``n`` is forwarded to make_test_weights. Grouped by m.
MNK_FACTORS = [
    # m == 1
    (1, 128, 128), (1, 128, 2048), (1, 512, 512),
    (1, 1024, 128), (1, 1024, 2048),
    # m == 32
    (32, 128, 128), (32, 512, 512), (32, 1024, 2048),
    # m == 45
    (45, 128, 128), (45, 128, 2048), (45, 512, 512),
    (45, 1024, 128), (45, 1024, 2048),
    # m == 64
    (64, 512, 512), (64, 1024, 2048),
    # m == 222
    (222, 128, 128), (222, 128, 2048),
    (222, 1024, 128), (222, 1024, 2048),
]

# Expert counts and top-k values swept by test_fused_moe_batched_experts.
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
|
|
|
|
# Module-level config entered via set_current_vllm_config in
# test_fused_moe_batched_experts; only two scheduler limits are overridden.
vllm_config = VllmConfig()
_sched_cfg = vllm_config.scheduler_config
_sched_cfg.max_num_seqs = 128
_sched_cfg.max_model_len = 8192
del _sched_cfg  # keep the module namespace identical to before
|
|
|
|
|
|
@dataclass
class BatchedMMConfig:
    """Description of one batched, per-expert matmul test problem.

    Field order below is also the positional constructor order.
    """

    # Dtype of the unquantized operands.
    in_dtype: torch.dtype
    # Quantized storage dtype; presumably None means "not quantized".
    quant_dtype: Optional[torch.dtype]
    # Dtype of the output buffer.
    out_dtype: torch.dtype

    # Problem dimensions.
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int
|
|
|
|
|
|
@dataclass
class BatchedMMTensors:
    """Random operands and an output buffer for a per-expert batched matmul."""

    # [E, max_tokens_per_expert, K] activations, scaled down by 10.
    A: torch.Tensor
    # [E, N, K], i.e. each expert's K x N operand stored column major.
    B: torch.Tensor
    # [E, max_tokens_per_expert, N] zero-initialised output buffer.
    C: torch.Tensor
    # [E] int32 per-expert token counts in [0, max_tokens_per_expert).
    num_expert_tokens: torch.Tensor

    @staticmethod
    def make_tensors(config: "BatchedMMConfig",
                     device: str = "cuda") -> "BatchedMMTensors":
        """Allocate random tensors sized by ``config``.

        Args:
            config: Object exposing ``num_experts``,
                ``max_tokens_per_expert``, ``K``, ``N``, ``in_dtype`` and
                ``out_dtype`` (normally a ``BatchedMMConfig``).
                ``quant_dtype`` is not read.
            device: Device to allocate on. Defaults to ``"cuda"``; pass
                ``"cpu"`` to build the tensors without a GPU.

        Returns:
            A new ``BatchedMMTensors`` holding A, B, C and the token counts.
        """
        num_experts = config.num_experts
        max_tokens = config.max_tokens_per_expert

        A = torch.randn((num_experts, max_tokens, config.K),
                        device=device,
                        dtype=config.in_dtype) / 10
        # Weight layout is [E, N, K] (not [E, K, N]).
        B = torch.randn((num_experts, config.N, config.K),
                        device=device,
                        dtype=config.in_dtype)
        C = torch.zeros((num_experts, max_tokens, config.N),
                        device=device,
                        dtype=config.out_dtype)

        # torch.randint's ``high`` is exclusive, so every count is strictly
        # below max_tokens_per_expert.
        num_expert_tokens = torch.randint(low=0,
                                          high=max_tokens,
                                          size=(num_experts, ),
                                          device=device,
                                          dtype=torch.int32)

        return BatchedMMTensors(A, B, C, num_expert_tokens)
|
|
|
|
|
|
@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype,
                    block_shape: Optional[list[int]],
                    per_act_token_quant: bool):
    """Check invoke_moe_batched_triton_kernel against native references.

    The kernel output is compared with ``native_batched_masked_quant_matmul``
    run on the quantized operands, and that quantized reference is in turn
    compared with the same helper run on the unquantized operands.
    """
    current_platform.seed_everything(7)

    is_fp8 = dtype == torch.float8_e4m3fn
    has_quant_scheme = per_act_token_quant or block_shape is not None

    # Guard clauses: per-token / block-wise scales only apply to fp8, and the
    # two scale layouts are never combined.
    if has_quant_scheme and not is_fp8:
        pytest.skip("Don't test blocking for non-quantized types.")
    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # A 1-byte dtype is the quantized storage type with bfloat16 activations;
    # any other dtype is used directly, unquantized.
    one_byte = dtype.itemsize == 1
    act_dtype = torch.bfloat16 if one_byte else dtype
    quant_dtype = dtype if one_byte else None

    # Per-expert token counts drawn from [0, max_tokens_per_expert).
    tokens_per_expert = torch.randint(
        low=0,
        high=max_tokens_per_expert,
        size=(num_experts, ),
        device="cuda",
        dtype=torch.int32,
    )

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
    )

    # Only the first weight bundle is used as B. NOTE(review): passing N // 2
    # assumes make_test_weights gives w1 2 * n rows so B ends up N wide --
    # confirm against tests/kernels/moe/utils.py.
    w1_bundle, _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )
    B, B_q, B_scale, _ = w1_bundle

    out_shape = (num_experts, max_tokens_per_expert, N)
    kernel_out, ref_out, q_ref_out = (
        torch.zeros(out_shape, dtype=act_dtype, device="cuda")
        for _ in range(3))
    out_dtype = kernel_out.dtype

    tl_dtype_for = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
    }
    compute_tl_dtype = tl_dtype_for[out_dtype]

    assert A_q.dtype == B_q.dtype

    # K tile is 16 for multi-byte dtypes and 32 for 1-byte dtypes.
    kernel_config = {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
    }

    invoke_moe_batched_triton_kernel(
        A_q,
        B_q,
        kernel_out,
        tokens_per_expert,
        compute_tl_dtype,
        # Quantization data
        A_scale,
        B_scale,
        None,
        # Quantization schemes
        is_fp8,
        False,
        False,
        config=kernel_config,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    ref_out = native_batched_masked_quant_matmul(A, B, ref_out,
                                                 tokens_per_expert)
    q_ref_out = native_batched_masked_quant_matmul(
        A_q,
        B_q,
        q_ref_out,
        tokens_per_expert,
        A_scale,
        B_scale,
        block_shape,
        per_act_token_quant,
    )

    tolerances = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }
    rtol, atol = tolerances[out_dtype]

    torch.testing.assert_close(ref_out, q_ref_out, atol=atol, rtol=rtol)
    torch.testing.assert_close(kernel_out, q_ref_out, atol=atol, rtol=rtol)
|
|
|
|
|
|
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    input_scales: bool,
):
    """Compare naive and Triton batched MoE against ``torch_experts``.

    ``naive_batched_moe`` is checked against the ``torch_experts`` baseline,
    and ``batched_moe`` (Triton) is checked against ``naive_batched_moe``.
    All three receive identical inputs and quantization settings.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if topk > e:
        pytest.skip("topk > e")

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # A 1-byte dtype is the quantized storage type with bfloat16 activations;
    # any other dtype is used directly, unquantized.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Activations use act_dtype so they always match the weights' in_dtype.
    # The RNG draw order (a, score, weights) is unchanged.
    a = torch.randn((m, k), device="cuda", dtype=act_dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # The unquantized copies of the weights are not needed here.
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        block_shape=block_shape,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        per_out_ch_quant=per_act_token_quant,
    )

    if input_scales and quant_dtype is not None:
        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    # Shared by all three implementations so they provably see the same
    # scales and quantization scheme.
    moe_kwargs = {
        "w1_scale": w1_s,
        "w2_scale": w2_s,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "quant_dtype": quant_dtype,
        "per_act_token_quant": per_act_token_quant,
        "block_shape": block_shape,
    }

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
                                        **moe_kwargs)
        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids,
                                           **moe_kwargs)
        triton_output = batched_moe(a, w1, w2, topk_weight, topk_ids,
                                    **moe_kwargs)

    torch.testing.assert_close(batched_output,
                               baseline_output,
                               atol=3e-2,
                               rtol=2e-2)

    torch.testing.assert_close(triton_output,
                               batched_output,
                               atol=2e-2,
                               rtol=2e-2)
|