vllm/tests/kernels/moe/test_pplx_moe.py
HAIAI aee76334d9
[amd_dev] branch rebase (#25753)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: boyuanfeng <boyuan@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: Manoel Marques <manoel.marques@ibm.com>
Signed-off-by: Manoel Marques <manoelmrqs@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: pengdrumli <pengdrumli@tencent.com>
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Huamin Li <3ericli@gmail.com>
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Signed-off-by: Yang <lymailforjob@gmail.com>
Signed-off-by: Debolina Roy <debroy@redhat.com>
Signed-off-by: David Chen <530634352@qq.com>
Signed-off-by: wangzi <3220100013@zju.edu.cn>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com>
Signed-off-by: Csrayz <jover@cmbchina.com>
Signed-off-by: ivyilike <pww123@cmbchina.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Bowen Wang <abmfy@icloud.com>
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: Johnny Yang <johnnyyang@google.com>
Signed-off-by: Alec Solder <alecs@fb.com>
Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Zhikaiiii <1658973216@qq.com>
Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: wuxibin <wuxibin@bytedance.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
Signed-off-by: Peter Pan <peter.pan@daocloud.io>
Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com>
Signed-off-by: Weida Hong <wdhongtw@google.com>
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: rouchenzi <ruochenwen@gmail.com>
Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
Signed-off-by: Andrew Xia <axia@meta.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Corey Lowman <clowman1993@gmail.com>
Signed-off-by: jpvillam <jpvillam@amd.com>
Signed-off-by: dougbtv <dosmith@redhat.com>
Signed-off-by: Chenxi Yang <cxyang@fb.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Yan Lu <luyan@nvidia.com>
Signed-off-by: baxingpiaochong <771405853@qq.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Jackmin801 <ongjackm@gmail.com>
Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com>
Signed-off-by: taohui <taohui3@gmail.com>
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
Signed-off-by: Wei Wei <wwei6@meta.com>
Signed-off-by: Saman Keon <samanamp@outlook.com>
Signed-off-by: yangxurui <yangxurui@meituan.com>
Signed-off-by: nicole-lihui <nicole.li@daocloud.io>
Signed-off-by: courage17340 <courage17340@163.com>
Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com>
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Signed-off-by: zxw <1020938856@qq.com>
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: chenlang <chen.lang5@zte.com.cn>
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
Signed-off-by: Tao Hui <taohui3@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io>
Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com>
Signed-off-by: Iceber Gu <caiwei95@hotmail.com>
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Boyuan Feng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: JartX <sagformas@epdcenter.es>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: xin.li <xin.li@daocloud.io>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com>
Co-authored-by: Manoel Marques <manoelmrqs@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com>
Co-authored-by: Michael Yao <haifeng.yao@daocloud.io>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Huamin Li <3ericli@gmail.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com>
Co-authored-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com>
Co-authored-by: Deboleina <debroy@redhat.com>
Co-authored-by: yinz-aizip <yinz@aizip.ai>
Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
Co-authored-by: wangzi <3220100013@zju.edu.cn>
Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com>
Co-authored-by: Csrayz <jover@cmbchina.com>
Co-authored-by: ivyilike <pww123@cmbchina.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Bowen Wang <abmfy@icloud.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Daisy-Ma-coder <daisy.ma.0117@gmail.com>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com>
Co-authored-by: Alec Solder <alecs@fb.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Chris Bamford <chrisbam4d@gmail.com>
Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Ming Yang <yming@meta.com>
Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com>
Co-authored-by: Andreas Hartel <andreas@hartel.me>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Joel <wuxibin89@163.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Peter Pan <peter.pan@daocloud.io>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Fanli Lin <fanli.lin@intel.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com>
Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Weida Hong <wdhongtw@gmail.com>
Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
Co-authored-by: Amir Samani <samani@ualberta.ca>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Ilya Markov <markovilya197@gmail.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
Co-authored-by: Andrew Xia <axia@meta.com>
Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com>
Co-authored-by: Corey Lowman <clowman1993@gmail.com>
Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com>
Co-authored-by: jpvillam <jpvillam@amd.com>
Co-authored-by: Doug Smith <dosmith@redhat.com>
Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu>
Co-authored-by: Chenxi Yang <cxyang@fb.com>
Co-authored-by: ahao-anyscale <ahao@anyscale.com>
Co-authored-by: 0xNullPath <luyanfcp@foxmail.com>
Co-authored-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Co-authored-by: lhsjohn <huashuoli@tencent.com>
Co-authored-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com>
Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com>
Co-authored-by: Tao Hui <taohui3@gmail.com>
Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Shiyan Deng <dsy842974287@meta.com>
Co-authored-by: Wei Wei <wwei6@meta.com>
Co-authored-by: Saman A. Pour <samanamp@outlook.com>
Co-authored-by: XuruiYang <530534756@qq.com>
Co-authored-by: yangxurui <yangxurui@meituan.com>
Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com>
Co-authored-by: courage17340 <courage17340@users.noreply.github.com>
Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com>
Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io>
Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com>
Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com>
Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: chenlang <chen.lang5@zte.com.cn>
Co-authored-by: chenlang <10346245@zte.com.cn>
Co-authored-by: AlonKejzman <alonkeizman@gmail.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Iceber Gu <caiwei95@hotmail.com>
Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Icey <1790571317@qq.com>
Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com>
Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
2025-09-26 17:14:31 +01:00

971 lines
30 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE layers.
Run `pytest tests/kernels/moe/test_pplx_moe.py`.
"""
import copy
import itertools
import textwrap
import traceback
from typing import Callable, Optional, Union
import pytest
import torch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
has_pplx = True
except ImportError:
has_pplx = False
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
_set_vllm_config)
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
naive_batched_moe)
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.platforms import current_platform
from vllm.utils import round_up
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
# Skip marker for tests that need the optional `pplx_kernels` package.
# `has_pplx` is set by the guarded import at the top of this module.
requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)
# (m, n, k) problem sizes for test_fused_moe_batched_experts: activations are
# (m, k) and the expert weights are (e, 2 * n, k) and (e, k, n).
BATCHED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (33, 2048, 128),
    (64, 128, 2048),
    (222, 128, 128),
    (222, 2048, 1024),
]

# (m, n, k) problem sizes for the PPLX tests (unpacked as `m, n, k = mnk`).
PPLX_COMBOS = [
    # TODO(bnell): figure out why this fails, seems to be test problem
    #(1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (4, 128, 128),
    (32, 1024, 512),
    (45, 512, 2048),
    (64, 1024, 512),
    (222, 2048, 1024),
    (256, 1408, 2048),
]

# Values swept by the "e" (number of experts) and "topk" parametrizations.
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
# torch.float8_e4m3fn makes the PPLX tests use bfloat16 activations with fp8
# quantization; any other dtype is used directly with no quantization.
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]

# Shared config that tests install through set_current_vllm_config().
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scatter tokens into a zero-padded, per-expert batched layout.

    Args:
        a: 2-D activations, one row per token.
        topk_ids: 2-D tensor with one row per token listing the experts the
            token is routed to.
        num_experts: Number of experts (size of the leading output dim).
        max_num_tokens: Rows reserved per expert. When ``None`` the largest
            per-expert token count is used.

    Returns:
        ``(b_a, tokens_per_expert)`` where ``b_a`` has shape
        ``(num_experts, max_num_tokens, hidden_dim)`` and row ``i`` of expert
        ``e`` holds the ``i``-th token (in token order) routed to ``e``;
        ``tokens_per_expert`` is the ``bincount`` of ``topk_ids``.
    """
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]

    _, hidden_dim = a.shape

    tokens_per_expert = torch.bincount(topk_ids.view(-1),
                                       minlength=num_experts)
    assert tokens_per_expert.numel() == num_experts

    if max_num_tokens is None:
        max_num_tokens = int(tokens_per_expert.max().item())

    b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
                      dtype=a.dtype,
                      device=a.device)

    # next_slot[e] is the row of expert e that the next routed token fills.
    next_slot = [0] * num_experts
    for token, routed_experts in enumerate(topk_ids.tolist()):
        for expert_id in routed_experts:
            slot = next_slot[expert_id]
            # One-row slice: a slot past max_num_tokens selects nothing, so
            # the assignment is then a no-op rather than an index error.
            b_a[expert_id, slot:slot + 1, :] = a[token, :]
            next_slot[expert_id] = slot + 1

    return b_a, tokens_per_expert
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
                   topk_ids: torch.Tensor) -> torch.Tensor:
    """Gather per-expert batched outputs back into per-token rows.

    For every token, the rows it occupies in ``b_out`` (same slot order that
    ``topk_ids`` is walked in: token by token, expert by expert) are scaled
    by the matching ``topk_weight`` entry and summed.

    Args:
        b_out: Expert outputs laid out as (experts, slots, hidden).
        topk_weight: Per-token routing weights, indexed ``[token, i]``.
        topk_ids: Per-token expert ids, indexed ``[token, i]``.

    Returns:
        Tensor of shape ``(num_tokens, hidden)`` with ``b_out``'s dtype and
        device.
    """
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    hidden = b_out.shape[-1]

    out = torch.zeros((num_tokens, hidden),
                      dtype=b_out.dtype,
                      device=b_out.device)

    # next_slot[e] is the next unread row of expert e.
    next_slot = [0] * num_experts
    for token, routed_experts in enumerate(topk_ids.tolist()):
        for i, expert_id in enumerate(routed_experts):
            slot = next_slot[expert_id]
            weighted = (b_out[expert_id, slot:slot + 1, :] *
                        topk_weight[token, i])
            out[token, :] = out[token, :] + weighted
            next_slot[expert_id] = slot + 1

    return out
def torch_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference MoE computed expert-by-expert on a batched layout.

    Tokens are scattered per expert with ``torch_prepare``, each expert runs
    ``x @ w1[e].T`` -> ``silu_and_mul`` -> ``@ w2[e].T`` on its occupied rows,
    and ``torch_finalize`` applies the top-k weights and sums per token.

    Args:
        a: Token activations.
        w1: First-layer expert weights; dim 0 is the expert count and dim 1
            is twice the intermediate width consumed by ``silu_and_mul``.
        w2: Second-layer expert weights; ``w2.shape[1]`` must equal the
            hidden size of ``a``.
        topk_weight: Per-token routing weights.
        topk_ids: Per-token expert ids.

    Returns:
        Per-token output as produced by ``torch_finalize``.
    """
    num_experts = w1.shape[0]
    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3

    _, max_num_tokens, K = b_a.shape
    assert num_experts == b_a.shape[0] and w2.shape[1] == K

    out = torch.zeros((num_experts, max_num_tokens, K),
                      dtype=b_a.dtype,
                      device=b_a.device)
    # Scratch buffer for the activation output, reused by every expert.
    tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
                      dtype=b_a.dtype,
                      device=b_a.device)

    for expert, num in enumerate(tokens_per_expert.tolist()):
        if num <= 0:
            # No tokens were routed here; the expert's rows stay zero.
            continue
        gate_up = b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
        torch.ops._C.silu_and_mul(tmp[:num], gate_up)
        out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

    return torch_finalize(out, topk_weight, topk_ids)
@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    """Check both batched reference MoE paths against ``torch_experts``.

    ``torch_batched_moe`` and ``naive_batched_moe`` must each match the
    ``torch_experts`` baseline within an absolute tolerance of 2e-2.
    """
    current_platform.seed_everything(7)

    # Keep the draw order (a, w1, w2, score) so the seeded data is stable.
    tensor_opts = dict(device="cuda", dtype=dtype)
    a = torch.randn((m, k), **tensor_opts) / 10
    w1 = torch.randn((e, 2 * n, k), **tensor_opts) / 10
    w2 = torch.randn((e, k, n), **tensor_opts) / 10
    score = torch.randn((m, e), **tensor_opts)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        # only for baseline
        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
        # pick torch_experts or this
        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)

    for candidate in (torch_output, batched_output):
        torch.testing.assert_close(baseline_output,
                                   candidate,
                                   atol=2e-2,
                                   rtol=0)
def create_pplx_prepare_finalize(
    num_tokens: int,
    hidden_dim: int,
    topk: int,
    num_experts: int,
    rank: int,
    dp_size: int,
    world_size: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    group_name: Optional[str],
):
    """Build a PPLX all-to-all handle and its ``PplxPrepareAndFinalize``.

    Sizes are derived from rank 0's share of the tokens and experts (rank 0
    never has fewer than any other rank under ``rank_chunk``), with at least
    one token slot.

    Args:
        group_name: When ``None`` an ``AllToAll.internode`` handle is
            created; otherwise ``AllToAll.intranode`` is created and given
            this group name.

    Returns:
        ``(prepare_finalize, ata)``. The caller owns ``ata``; call sites in
        this module release it with ``ata.destroy()``.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)

    max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
    num_local_experts = rank_chunk(num_experts, 0, world_size)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        in_dtype,
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    ata_kwargs = {
        "max_num_tokens": max_num_tokens,
        "num_experts": num_experts,
        "experts_per_token": topk,
        "rank": rank,
        "world_size": world_size,
        "dp_size": dp_size,
        "hidden_dim": hidden_dim,
        "hidden_dim_bytes": hidden_dim_bytes,
        "hidden_dim_scale_bytes": scale_bytes,
    }

    if group_name is not None:
        ata = AllToAll.intranode(**ata_kwargs, group_name=group_name)
    else:
        ata = AllToAll.internode(**ata_kwargs)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=world_size // dp_size,
    )

    return prepare_finalize, ata
def rank_chunk(num: int, r: int, w: int) -> int:
    """Return how many of ``num`` items rank ``r`` receives among ``w`` ranks.

    Every rank gets ``num // w`` items and the first ``num % w`` ranks get
    one extra.
    """
    base, rem = divmod(num, w)
    if r < rem:
        return base + 1
    return base
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    """Return rank ``r``'s contiguous slice of ``t`` along dim 0.

    The rows are split over ``w`` ranks so that every rank gets
    ``t.shape[0] // w`` rows and the first ``t.shape[0] % w`` ranks get one
    extra. The slices of ranks ``0..w-1`` are disjoint and together cover
    all of ``t``.

    Args:
        t: Tensor to split along its first dimension.
        r: Rank whose slice is returned.
        w: Total number of ranks.

    Returns:
        A view of ``t`` containing rank ``r``'s rows.
    """
    base, rem = divmod(t.shape[0], w)
    # Ranks before `r` hold `base` rows each, plus one extra for each of the
    # first `rem` ranks. Using `r * <own chunk size>` as the offset is only
    # right for even splits: otherwise later ranks overlap earlier ones and
    # the tail rows are never returned.
    start = r * base + min(r, rem)
    end = start + base + (1 if r < rem else 0)
    return t[start:end]
def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int,
                        w: int) -> Optional[torch.Tensor]:
    """``chunk_by_rank`` that passes ``None`` through unchanged."""
    if t is None:
        return None
    return chunk_by_rank(t, r, w)
def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int,
                         w: int) -> Optional[torch.Tensor]:
    """Return rank ``r``'s slice of a scale tensor, if it is sliceable.

    ``None`` and tensors with at most one element (e.g. a per-tensor scale)
    are returned unchanged. Anything larger is split along dim 0 over ``w``
    ranks: every rank gets ``t.shape[0] // w`` rows and the first
    ``t.shape[0] % w`` ranks get one extra, so the slices are disjoint and
    cover all rows.

    Args:
        t: Scale tensor or ``None``.
        r: Rank whose slice is returned.
        w: Total number of ranks.

    Returns:
        ``t`` itself, or a view of rank ``r``'s rows.
    """
    if t is None or t.numel() <= 1:
        return t

    base, rem = divmod(t.shape[0], w)
    # Offset by the rows actually held by the preceding ranks. Using
    # `r * <own chunk size>` would overlap earlier ranks and drop the tail
    # whenever the split is uneven.
    start = r * base + min(r, rem)
    end = start + base + (1 if r < rem else 0)
    return t[start:end]
def chunk_scales(t: Optional[torch.Tensor], start: int,
                 end: int) -> Optional[torch.Tensor]:
    """Slice rows ``[start, end)`` of a scale tensor when it is sliceable.

    ``None`` and tensors with at most one element are returned unchanged.
    """
    if t is None or t.numel() <= 1:
        return t
    return t[start:end]
def dummy_work(a: torch.Tensor) -> torch.Tensor:
    """Stand-in for the expert computation: scale the input by 1.1."""
    scale = 1.1
    return a * scale
def pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    group_name: Optional[str],
) -> torch.Tensor:
    """Round-trip this rank's tokens through PPLX prepare and finalize.

    The rank-local chunk of ``a`` is dispatched with ``prepare()``, the
    batched result is dequantized and passed through ``dummy_work``, and
    ``finalize()`` combines it into an output buffer shaped like the chunk.

    Args:
        pgi: Process-group info for the calling worker; its ``local_rank``
            must be the current CUDA device.
        dp_size: Data-parallel size forwarded to the all-to-all setup.
        a: Full (un-chunked) activations.
        topk_weight: Full per-token routing weights.
        topk_ids: Full per-token expert ids (converted to ``uint32``).
        num_experts: Global number of experts.
        quant_dtype: Quantized activation dtype, or ``None``.
        block_shape: Block-quantization shape, or ``None``.
        per_act_token_quant: Whether activations use per-token scales.
        group_name: Forwarded to ``create_pplx_prepare_finalize``.

    Returns:
        The combined output for this rank's chunk of tokens.
    """
    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    device, rank, world_size = pgi.device, pgi.rank, pgi.world_size

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    assert a.shape[0] == topk_ids.shape[0]

    def local_chunk(t: torch.Tensor) -> torch.Tensor:
        # This rank's rows of `t`, moved to the worker's device.
        return chunk_by_rank(t, rank, world_size).to(device)

    a_chunk = local_chunk(a)
    chunk_topk_weight = local_chunk(topk_weight)
    chunk_topk_ids = local_chunk(topk_ids)
    assert a_chunk.shape[0] == chunk_topk_ids.shape[0]

    # NaN-filled so that (presumably) any row finalize() fails to write shows
    # up when the caller compares against the reference.
    out = torch.full(a_chunk.shape, torch.nan, dtype=a.dtype, device=device)

    # Unit activation scales are supplied only for quantization that is
    # neither per-token nor block-wise.
    uses_tensor_scale = (quant_dtype is not None and not per_act_token_quant
                         and block_shape is None)
    if uses_tensor_scale:
        a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=False,
        block_shape=block_shape,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # NOTE(review): the positional `None, False` are presumably the expert
    # map and "apply router weight on input" flag - confirm against
    # PplxPrepareAndFinalize.prepare.
    b_a, b_a_scale, _, _, _ = prepare_finalize.prepare(
        a_chunk,
        chunk_topk_weight,
        chunk_topk_ids,
        num_experts,
        None,
        False,
        quant_config,
    )

    dequantized = dequant(b_a, b_a_scale, block_shape, per_act_token_quant,
                          a.dtype)
    b_a = dummy_work(dequantized)

    prepare_finalize.finalize(
        out,
        b_a,
        chunk_topk_weight,
        chunk_topk_ids,
        False,
        weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
    )

    torch.cuda.synchronize()

    ata.destroy()

    return out[:a_chunk.shape[0]]
def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    use_internode: bool,
):
    """Per-rank worker that checks PPLX prepare/finalize against a reference.

    Passed to ``parallel_launch`` by ``test_pplx_prepare_finalize_slow``.

    Args:
        pgi: Process-group info for this worker.
        dp_size: Data-parallel size.
        a: Full activations.
        score: Router scores used by ``fused_topk``.
        topk: Number of experts selected per token (an ``int``; it is used
            as a repeat count and as a ``view`` dimension below).
        num_experts: Global number of experts.
        quant_dtype: Quantized activation dtype, or ``None``.
        block_shape: Block-quantization shape, or ``None``.
        per_act_token_quant: Whether activations use per-token scales.
        use_internode: Initialize NVSHMEM and use the internode all-to-all
            when true; otherwise create a gloo group and pass its name on.

    Raises:
        AssertionError: If the PPLX output differs from the reference beyond
            ``atol=3e-2, rtol=3e-2``.
    """
    try:
        if use_internode:
            # Rank 0 creates the NVSHMEM id; everyone else receives it via
            # the broadcast below.
            uid = nvshmem_get_unique_id(
            ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks,
                                                    backend="gloo")
            group_name = cpu_group.group_name

        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        m, k = a.shape

        # Reference: every selected expert "returns" dummy_work(token), so
        # the combined row is the weighted sum of `topk` copies of it.
        a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)

        torch_output = (a_rep.view(m, topk, k) *
                        topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(
                            dim=1)

        pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight,
                                            topk_ids, num_experts, quant_dtype,
                                            block_shape, per_act_token_quant,
                                            group_name)

        # Compare only the rows that belong to this rank.
        torch_output = chunk_by_rank(torch_output, pgi.rank,
                                     pgi.world_size).to(pgi.device)

        torch.testing.assert_close(pplx_output,
                                   torch_output,
                                   atol=3e-2,
                                   rtol=3e-2)
    finally:
        if use_internode:
            nvshmem_finalize()
@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    use_internode: bool,
):
    """Launch ``_pplx_prepare_finalize`` on every rank for one parameter set.

    ``torch.float8_e4m3fn`` means bfloat16 activations quantized to fp8; any
    other dtype runs unquantized. Quantization options on an unquantized
    dtype, and per-token combined with block quantization, are skipped.
    """
    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
    act_dtype = torch.bfloat16 if use_fp8_w8a8 else dtype
    quant_dtype = dtype if use_fp8_w8a8 else None

    wants_quant_options = per_act_token_quant or block_shape is not None
    if wants_quant_options and not use_fp8_w8a8:
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    current_platform.seed_everything(7)

    # `n` is not needed here: only activations and router scores are built.
    m, _, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"

    a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
    score = torch.randn((m, e), device=device, dtype=act_dtype)

    parallel_launch(
        world_size,
        _pplx_prepare_finalize,
        dp_size,
        a,
        score,
        topk,
        e,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        use_internode,
    )
def pplx_moe(
    group_name: Optional[str],
    rank: int,
    world_size: int,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
    shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Run this rank's share of a fused MoE layer through the PPLX kernels.

    Builds a ``FusedMoEModularKernel`` from a PPLX prepare/finalize object and
    ``BatchedTritonExperts``, slices the inputs and expert weights for
    ``rank``, runs the kernel once eagerly and, when ``use_cudagraphs`` is
    set, again under CUDA graph capture followed by a replay.

    Args:
        group_name: Forwarded to ``create_pplx_prepare_finalize``.
        rank: Rank of the calling worker.
        world_size: Total number of ranks.
        dp_size: Data-parallel size.
        a: Full activations; dim 0 is the token count.
        w1: Full first-layer expert weights; dim 0 is the expert count.
        w2: Full second-layer expert weights.
        topk_weight: Full per-token routing weights.
        topk_ids: Full per-token expert ids (converted to ``uint32``).
        w1_scale: Optional weight scales, chunked by rank like the weights.
        w2_scale: Optional weight scales, chunked by rank like the weights.
        a1_scale: Optional activation scales; chunked only when they have
            more than one element.
        a2_scale: Optional activation scales; chunked only when they have
            more than one element.
        quant_dtype: Quantized dtype, or ``None``.
        per_act_token_quant: Whether activations use per-token scales.
        block_shape: Block-quantization shape, or ``None``.
        use_compile: Wrap the kernel in ``torch.compile`` (see note below).
        use_cudagraphs: Capture and replay the kernel in a CUDA graph.
        shared_experts: Optional module handed to the modular kernel.

    Returns:
        The kernel output for this rank's tokens. The kernel may return a
        tuple of two tensors (presumably when ``shared_experts`` is given -
        confirm against ``FusedMoEModularKernel``).
    """
    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
    topk = topk_ids.shape[1]
    # Rank 0's token share passed through round_up(..., 16); used as the
    # expert kernel's per-expert token capacity.
    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    topk_ids = topk_ids.to(dtype=torch.uint32)

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size)

    # Chunking weights like this only works for batched format
    w1_chunk = chunk_by_rank(w1, rank, world_size)
    w2_chunk = chunk_by_rank(w2, rank, world_size)
    w1_scale_chunk = maybe_chunk_by_rank(w1_scale, rank, world_size)
    w2_scale_chunk = maybe_chunk_by_rank(w2_scale, rank, world_size)
    a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
    a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
        w1_scale=w1_scale_chunk,
        w2_scale=w2_scale_chunk,
        a1_scale=a1_scale_chunk,
        a2_scale=a2_scale_chunk,
    )

    experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=quant_config,
    )

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    if use_compile:
        _fused_experts = torch.compile(fused_experts,
                                       backend='inductor',
                                       fullgraph=True)
        # Mark the token dimension of the per-rank inputs as dynamic.
        torch._dynamo.mark_dynamic(a_chunk, 0)
        torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
        torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
    else:
        _fused_experts = fused_experts

    # Eager run; its result is returned as-is when CUDA graphs are disabled.
    out = _fused_experts(a_chunk,
                         w1_chunk,
                         w2_chunk,
                         chunk_topk_weight,
                         chunk_topk_ids,
                         global_num_experts=num_experts)

    if use_cudagraphs:
        # Zero the eager result before capture (presumably so stale values
        # cannot mask a replay that writes nothing - confirm).
        if isinstance(out, tuple):
            out[0].fill_(0)
            out[1].fill_(0)
        else:
            out.fill_(0)

        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            # `out` is rebound to the tensors produced under capture.
            out = _fused_experts(a_chunk,
                                 w1_chunk,
                                 w2_chunk,
                                 chunk_topk_weight,
                                 chunk_topk_ids,
                                 global_num_experts=num_experts)

        torch.cuda.synchronize()
        graph.replay()

    torch.cuda.synchronize()

    ata.destroy()

    return out
def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    w1_s: Optional[torch.Tensor] = None,
    w2_s: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    use_internode: bool = False,
    shared_experts: Optional[torch.nn.Module] = None,
):
    """Per-rank worker: check the pplx MoE path against reference results.

    Routing is computed with ``fused_topk`` and the same inputs are then run
    through ``torch_experts``, ``naive_batched_moe`` and ``pplx_moe``.  The
    batched output is compared with the torch output, and the pplx output is
    compared with this rank's chunk (``chunk_by_rank``) of the batched output.
    When ``shared_experts`` is given, its output on ``a`` is chunked the same
    way and compared with the shared output returned by ``pplx_moe``.  Every
    comparison uses ``atol=rtol=3e-2``.

    Args:
        pgi: Process group info; ``rank`` and ``world_size`` are read from it.
        dp_size: Forwarded unchanged to ``pplx_moe``.
        a: Activations, 2-D (unpacked as ``m, k = a.shape``).
        w1: First expert weight tensor.
        w2: Second expert weight tensor, 3-D (unpacked as
            ``e, _, n = w2.shape``).
        score: Routing scores passed to ``fused_topk``.
        topk: Number of experts selected per token.
        num_experts: Not read in this function; the expert count used for the
            MoE config is taken from ``w2.shape``.
        w1_s: Optional scale for ``w1``.
        w2_s: Optional scale for ``w2``.
        quant_dtype: Optional quantization dtype forwarded to all three
            implementations.
        per_act_token_quant: Forwarded to all three implementations.
        block_shape: Optional block shape forwarded to all three
            implementations.
        use_internode: If true, initialise nvshmem from a unique id broadcast
            from rank 0 and finalize it on exit; otherwise create a gloo
            group over all ranks and pass its name to ``pplx_moe``.
        shared_experts: Optional module evaluated on ``a`` and also handed to
            ``pplx_moe``.

    Raises:
        AssertionError: If any of the output comparisons fails.
    """
    try:
        if use_internode:
            uid = nvshmem_get_unique_id(
            ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks,
                                                    backend="gloo")
            group_name = cpu_group.group_name

        m, k = a.shape
        e, _, n = w2.shape

        moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

        device = torch.device("cuda", pgi.rank)
        rank = pgi.rank
        world_size = pgi.world_size

        # Put every input on this rank's device.  `score` is moved together
        # with the other tensors so that `fused_topk` never receives operands
        # living on different GPUs (callers may create the inputs on another
        # device before launching the workers).
        a = a.to(device)
        score = score.to(device)
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_s = w1_s.to(device) if w1_s is not None else None
        w2_s = w2_s.to(device) if w2_s is not None else None

        # Scalar activation scales are only supplied for quantization that is
        # neither per-token nor block-wise.  They are created on the rank's
        # device explicitly rather than on whatever the current CUDA device
        # happens to be.
        if (quant_dtype is not None and not per_act_token_quant
                and block_shape is None):
            a1_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
            a2_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
        else:
            a1_scale = None
            a2_scale = None

        with set_current_vllm_config(vllm_config), override_config(moe_config):
            topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

            if shared_experts is not None:
                shared_output = shared_experts(a)
            else:
                shared_output = None

            # Reference result.
            torch_output = torch_experts(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            # Batched result computed on the full (unchunked) input.
            batched_output = naive_batched_moe(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            # Implementation under test.
            pplx_outputs = pplx_moe(
                group_name,
                rank,
                world_size,
                dp_size,
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                shared_experts=shared_experts,
            )

            # With shared experts the result is unpacked as a
            # (shared_output, output) pair; otherwise it must be one tensor.
            if shared_experts is None:
                pplx_shared_output = None
                pplx_output = pplx_outputs
                assert isinstance(pplx_output, torch.Tensor)
            else:
                pplx_shared_output, pplx_output = pplx_outputs

            if shared_output is not None:
                assert pplx_shared_output is not None
                chunked_shared_output = chunk_by_rank(
                    shared_output, pgi.rank,
                    pgi.world_size).to(pplx_shared_output.device)
            else:
                chunked_shared_output = None

        # The pplx output is compared with this rank's chunk of the full
        # batched output.
        chunked_batch_output = chunk_by_rank(
            batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)

        torch.testing.assert_close(batched_output,
                                   torch_output,
                                   atol=3e-2,
                                   rtol=3e-2)

        torch.testing.assert_close(pplx_output,
                                   chunked_batch_output,
                                   atol=3e-2,
                                   rtol=3e-2)

        if shared_experts is not None:
            assert chunked_shared_output is not None
            assert pplx_shared_output is not None
            torch.testing.assert_close(pplx_shared_output,
                                       chunked_shared_output,
                                       atol=3e-2,
                                       rtol=3e-2)
    finally:
        if use_internode:
            nvshmem_finalize()
@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    use_internode: bool,
):
    """Launch ``_pplx_moe`` once per parametrized combination.

    One pytest case (and therefore one ``parallel_launch``) is produced for
    every parameter combination.  Quantization options are skipped when
    ``dtype`` is not ``torch.float8_e4m3fn``, and so is the combination of
    per-token quantization with a block shape.
    """
    current_platform.seed_everything(7)

    world_size, dp_size = world_dp_size
    m_dim, n_dim, k_dim = mnk

    # fp8 is the only dtype treated as quantized here.
    is_fp8 = dtype == torch.float8_e4m3fn
    quant_dtype = dtype if is_fp8 else None

    quant_requested = per_act_token_quant or block_shape is not None
    if quant_requested and not is_fp8:
        pytest.skip("Skip quantization test for non-quantized type")

    if block_shape is not None and per_act_token_quant:
        pytest.skip("Skip illegal quantization combination")

    # Random draws stay in this order (activations, scores, weights) so the
    # seeded values match across runs.
    a = torch.randn((m_dim, k_dim), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m_dim, e), device="cuda", dtype=torch.bfloat16)

    w1_pack, w2_pack = make_test_weights(
        e,
        n_dim,
        k_dim,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )
    _, w1, w1_s, _ = w1_pack
    _, w2, w2_s, _ = w2_pack

    # Arguments after the worker are passed positionally, in the order of
    # _pplx_moe's parameters following `pgi`.
    parallel_launch(
        world_size,
        _pplx_moe,
        dp_size,
        a,
        w1,
        w2,
        score,
        topk,
        e,
        w1_s,
        w2_s,
        quant_dtype,
        per_act_token_quant,
        block_shape,
        use_internode,
    )
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
                    use_shared_experts: bool, make_weights: bool,
                    test_fn: Callable):
    """Run ``test_fn`` over every parameter combination inside one worker.

    The combinations are the product of ``PPLX_COMBOS``, ``NUM_EXPERTS``,
    ``TOP_KS``, ``DTYPES``, per-token quantization on/off and block shape
    ``None`` / ``[128, 128]``.  Combinations that request quantization for a
    non-fp8 dtype, or that combine per-token quantization with a block shape,
    are skipped and are not counted as tests.  A failure in one combination is
    printed and recorded, and the loop continues with the next one.

    Args:
        pgi: Process group info for this worker.
        dp_size: Forwarded to ``test_fn``.
        use_internode: Forwarded to ``test_fn``.
        use_shared_experts: If true, install a copy of ``vllm_config`` with
            expert parallelism enabled and pass a ``shared_experts`` module
            to ``test_fn``.
        make_weights: If true, create expert weights and scales and pass them
            to ``test_fn`` as ``w1``, ``w2``, ``w1_s`` and ``w2_s``.
        test_fn: Callable invoked with keyword arguments once per executed
            combination.

    Raises:
        RuntimeError: If at least one executed combination raised.
    """

    def format_result(msg, ex=None):
        # Must be called from inside the `except` block when `ex` is given:
        # traceback.format_exc() reports the exception currently being
        # handled.
        if ex is not None:
            x = str(ex)
            newx = x.strip(" \n\t")[:16]
            if len(newx) < len(x):
                newx = newx + " ..."

            prefix = "E\t"
            print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
            print(f"FAILED {msg} - {newx}\n")
        else:
            print(f"PASSED {msg}")

    if use_shared_experts:
        # Note: this config is only needed for the non-naive shared experts.
        new_vllm_config = copy.deepcopy(vllm_config)
        new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
        new_vllm_config.parallel_config.enable_expert_parallel = True
        _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
                         pgi.local_rank)

    current_platform.seed_everything(7)
    combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
                               [False, True], [None, [128, 128]])
    exceptions = []
    # Number of combinations actually executed; skipped ones are excluded so
    # the final summary does not report them as passed.
    count = 0
    for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
        m, n, k = mnk

        if dtype == torch.float8_e4m3fn:
            use_fp8_w8a8 = True
            quant_dtype = dtype
        else:
            use_fp8_w8a8 = False
            quant_dtype = None

        test_desc = (
            f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
            f"dtype={dtype}, per_act_token={per_act_token_quant}, "
            f"block_shape={block_shape}, use_internode={use_internode}, "
            f"use_shared_experts={use_shared_experts}]")

        if not use_fp8_w8a8 and (per_act_token_quant
                                 or block_shape is not None):
            print(
                f"{test_desc} - Skip quantization test for non-quantized type."
            )
            continue

        if per_act_token_quant and block_shape is not None:
            print(f"{test_desc} - Skip illegal quantization combination.")
            continue

        count += 1

        a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

        # Optional keyword arguments; only the ones relevant to this run are
        # passed to test_fn.
        args = {}
        if make_weights:
            (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
                e,
                n,
                k,
                quant_dtype=quant_dtype,
                block_shape=block_shape,
                per_out_ch_quant=per_act_token_quant,
            )
            args["w1"] = w1
            args["w2"] = w2
            args["w1_s"] = w1_s
            args["w2_s"] = w2_s

        if use_shared_experts:
            args["shared_experts"] = make_shared_experts(
                n,
                k,
                in_dtype=a.dtype,
                quant_dtype=quant_dtype,
            )

        try:
            test_fn(
                pgi=pgi,
                dp_size=dp_size,
                a=a,
                score=score,
                topk=topk,
                num_experts=e,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                use_internode=use_internode,
                **args,
            )
            format_result(test_desc)
        except Exception as ex:
            # Record the failure and keep going so that every combination is
            # reported in a single launch.
            format_result(test_desc, ex)
            exceptions.append(ex)

    if len(exceptions) > 0:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}.")
    else:
        print(f"{count} of {count} tests passed in child process, "
              f"rank={pgi.rank}.")
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Run ``_pplx_prepare_finalize`` through ``_pplx_test_loop``.

    The loop is launched on ``world_size * dp_size`` workers, without shared
    experts and without generated weights.
    """
    current_platform.seed_everything(7)

    world_size, dp_size = world_dp_size
    num_workers = world_size * dp_size

    parallel_launch(
        num_workers,
        _pplx_test_loop,
        dp_size,
        use_internode,
        False,  # use_shared_experts
        False,  # make_weights
        _pplx_prepare_finalize,
    )
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.parametrize("use_shared_experts", [False, True])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe(
    world_dp_size: tuple[int, int],
    use_internode: bool,
    use_shared_experts: bool,
):
    """Run ``_pplx_moe`` through ``_pplx_test_loop``.

    The loop is launched on ``world_size`` workers with generated weights,
    once with and once without shared experts.
    """
    current_platform.seed_everything(7)

    world_size, dp_size = world_dp_size

    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        use_shared_experts,
        True,  # make_weights
        _pplx_moe,
    )