mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 17:27:03 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
771 lines
27 KiB
Plaintext
771 lines
27 KiB
Plaintext
/*
|
|
* Adapted from
|
|
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
|
|
* Copyright (c) 2025, The vLLM team.
|
|
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
|
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include <c10/cuda/CUDAStream.h>
|
|
#include <torch/all.h>
|
|
#include <cuda_fp16.h>
|
|
#include <cuda_bf16.h>
|
|
#include <cuda/std/limits>
|
|
#include <cooperative_groups.h>
|
|
#include <cooperative_groups/reduce.h>
|
|
namespace cg = cooperative_groups;
|
|
|
|
namespace vllm {
|
|
namespace moe {
|
|
|
|
// Lane mask with all 32 bits set; passed as the participation mask to the
// *_sync warp intrinsics used below (__shfl_xor_sync, __shfl_sync,
// __ballot_sync).
constexpr unsigned FULL_WARP_MASK = 0xffffffffu;

// Number of lanes in one warp.
constexpr int32_t WARP_SIZE = 32;

// Thread count per block.
// NOTE(review): presumably the launch block size of the kernels in this file;
// the launch code is not visible here -- confirm.
constexpr int32_t BLOCK_SIZE = 512;

// Warps per block (512 / 32 = 16). The kernels below use it to map
// (blockIdx.x, warp_id) to a case id.
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
|
|
|
namespace warp_topk {
|
|
|
|
// Rounds `len` up to the nearest multiple of `size`.
// Zero maps to zero; it is handled separately so that `len - 1` is never
// evaluated for len == 0 (which would wrap around for unsigned T).
template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
  if (len != 0) {
    // Number of `size`-wide chunks needed to cover `len` elements.
    auto const chunks = (len - 1) / size + 1;
    return chunks * size;
  }
  return 0;
}
|
|
|
|
// Returns true when `v` is non-zero and has exactly one bit set.
template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
  if (!v) {
    return false;
  }
  // Clearing the lowest set bit leaves zero only for a single-bit value.
  return !(v & (v - 1));
}
|
|
|
|
// Compares `val` against `baseline` in the direction selected by `greater`:
// strictly larger when greater == true, strictly smaller otherwise.
// Equal values are never "better".
template <bool greater, typename T>
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
  if constexpr (greater) {
    return val > baseline;
  } else {
    return val < baseline;
  }
}
|
|
|
|
// Index-aware variant of is_better_than, used by the stable code paths.
// When the values differ, the result is the plain directional comparison
// (larger wins for greater == true, smaller wins otherwise). When the values
// compare equal, the entry with the smaller index wins, regardless of
// `greater`.
template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
                                               idxT baseline_index) {
  if (val == baseline) {
    // Tie on value: prefer the lower index.
    return index < baseline_index;
  }
  if constexpr (greater) {
    return val > baseline;
  } else {
    return val < baseline;
  }
}
|
|
|
|
template <typename T, typename idxT>
|
|
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
|
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
|
|
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
|
|
return max(cache_topk,
|
|
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
|
|
}
|
|
|
|
template <int size, bool ascending, bool reverse, typename T, typename idxT,
|
|
bool is_stable>
|
|
struct BitonicMerge {
|
|
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
|
__device__ static void merge(T* __restrict__ val_arr,
|
|
idxT* __restrict__ idx_arr) {
|
|
static_assert(isPowerOf2(size));
|
|
static_assert(size >= 2 * WARP_SIZE);
|
|
constexpr int arr_len = size / WARP_SIZE;
|
|
|
|
constexpr int stride = arr_len / 2;
|
|
for (int i = 0; i < stride; ++i) {
|
|
int const other_i = i + stride;
|
|
T& val = val_arr[i];
|
|
T& other_val = val_arr[other_i];
|
|
bool is_better;
|
|
if constexpr (is_stable) {
|
|
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
|
|
idx_arr[other_i]);
|
|
} else {
|
|
is_better = is_better_than<ascending>(val, other_val);
|
|
}
|
|
|
|
if (is_better) {
|
|
T tmp = val;
|
|
val = other_val;
|
|
other_val = tmp;
|
|
|
|
idxT tmp2 = idx_arr[i];
|
|
idx_arr[i] = idx_arr[other_i];
|
|
idx_arr[other_i] = tmp2;
|
|
}
|
|
}
|
|
|
|
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
|
val_arr, idx_arr);
|
|
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
|
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
|
}
|
|
};
|
|
|
|
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
|
|
struct BitonicSort {
|
|
__device__ static void sort(T* __restrict__ val_arr,
|
|
idxT* __restrict__ idx_arr) {
|
|
static_assert(isPowerOf2(size));
|
|
static_assert(size >= 2 * WARP_SIZE);
|
|
constexpr int arr_len = size / WARP_SIZE;
|
|
|
|
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
|
|
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
|
|
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
|
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
|
|
val_arr, idx_arr);
|
|
}
|
|
};
|
|
|
|
template <bool ascending, typename T, typename idxT, bool is_stable>
|
|
struct BitonicSort<32, ascending, T, idxT, is_stable> {
|
|
__device__ static void sort(T* __restrict__ val_arr,
|
|
idxT* __restrict__ idx_arr) {
|
|
int const lane = threadIdx.x % WARP_SIZE;
|
|
|
|
// ascending doesn't matter before merging since all we need is a bitonic
|
|
// sequence
|
|
for (int stage = 0; stage < 4; ++stage) {
|
|
for (int stride = (1 << stage); stride > 0; stride /= 2) {
|
|
bool reverse = (lane >> stage) & 2;
|
|
bool is_second = lane & stride;
|
|
|
|
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
|
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
|
|
|
bool is_better;
|
|
if constexpr (is_stable) {
|
|
if constexpr (ascending) {
|
|
is_better = ((*val_arr > other) ||
|
|
((*val_arr == other) && (*idx_arr < other_idx))) !=
|
|
(reverse != is_second);
|
|
} else {
|
|
is_better = ((*val_arr > other) ||
|
|
((*val_arr == other) && (*idx_arr > other_idx))) !=
|
|
(reverse != is_second);
|
|
}
|
|
} else {
|
|
is_better = (*val_arr != other &&
|
|
(*val_arr > other) != (reverse != is_second));
|
|
}
|
|
if (is_better) {
|
|
*val_arr = other;
|
|
*idx_arr = other_idx;
|
|
}
|
|
}
|
|
}
|
|
|
|
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
|
|
idx_arr);
|
|
}
|
|
};
|
|
|
|
template <bool ascending, bool reverse, typename T, typename idxT,
|
|
bool is_stable>
|
|
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
|
__device__ static void merge(T* __restrict__ val_arr,
|
|
idxT* __restrict__ idx_arr) {
|
|
int const lane = threadIdx.x % WARP_SIZE;
|
|
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
|
|
bool is_second = lane & stride;
|
|
T& val = *val_arr;
|
|
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
|
idxT& idx = *idx_arr;
|
|
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
|
|
|
bool is_better;
|
|
if constexpr (is_stable) {
|
|
if constexpr (ascending) {
|
|
is_better = ((*val_arr > other) ||
|
|
((*val_arr == other) && (*idx_arr < other_idx))) ==
|
|
(reverse != is_second); // for min
|
|
} else {
|
|
is_better = ((*val_arr > other) ||
|
|
((*val_arr == other) && (*idx_arr > other_idx))) ==
|
|
(reverse != is_second); // for max
|
|
}
|
|
} else {
|
|
is_better =
|
|
(val != other && ((val > other) == (ascending != is_second)));
|
|
}
|
|
|
|
if (is_better) {
|
|
val = other;
|
|
idx = other_idx;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
|
class WarpSort {
|
|
public:
|
|
__device__ WarpSort(idxT k, T dummy)
|
|
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
|
|
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
|
|
|
|
for (int i = 0; i < max_arr_len_; ++i) {
|
|
val_arr_[i] = dummy_;
|
|
idx_arr_[i] = 0;
|
|
}
|
|
}
|
|
|
|
// load and merge k sorted values
|
|
__device__ void load_sorted(T const* __restrict__ in,
|
|
idxT const* __restrict__ in_idx, idxT start) {
|
|
idxT idx = start + WARP_SIZE - 1 - lane_;
|
|
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
|
|
if (idx < start + k_) {
|
|
T t = in[idx];
|
|
bool is_better;
|
|
if constexpr (is_stable) {
|
|
is_better =
|
|
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
|
|
} else {
|
|
is_better = is_better_than<greater>(t, val_arr_[i]);
|
|
}
|
|
if (is_better) {
|
|
val_arr_[i] = t;
|
|
idx_arr_[i] = in_idx[idx];
|
|
}
|
|
}
|
|
}
|
|
|
|
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
|
val_arr_, idx_arr_);
|
|
}
|
|
|
|
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
|
|
for (int i = 0; i < max_arr_len_; ++i) {
|
|
idxT out_i = i * WARP_SIZE + lane_;
|
|
if (out_i < k_) {
|
|
out[out_i] = val_arr_[i];
|
|
out_idx[out_i] = idx_arr_[i];
|
|
}
|
|
}
|
|
}
|
|
|
|
__device__ void dumpIdx(idxT* __restrict__ out_idx) const {
|
|
for (int i = 0; i < max_arr_len_; ++i) {
|
|
idxT out_i = i * WARP_SIZE + lane_;
|
|
if (out_i < k_) {
|
|
out_idx[out_i] = idx_arr_[i];
|
|
}
|
|
}
|
|
}
|
|
|
|
protected:
|
|
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
|
|
|
|
T val_arr_[max_arr_len_];
|
|
idxT idx_arr_[max_arr_len_];
|
|
|
|
int const lane_;
|
|
idxT const k_;
|
|
T const dummy_;
|
|
|
|
}; // end class WarpSort
|
|
|
|
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
|
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
|
public:
|
|
__device__ WarpSelect(idxT k, T dummy)
|
|
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
|
|
k_th_(dummy),
|
|
k_th_lane_((k - 1) % WARP_SIZE) {
|
|
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
|
|
|
|
int const num_of_warp = blockDim.x / WARP_SIZE;
|
|
int const warp_id = threadIdx.x / WARP_SIZE;
|
|
val_smem_ = reinterpret_cast<T*>(smem_buf);
|
|
val_smem_ += warp_id * WARP_SIZE;
|
|
idx_smem_ = reinterpret_cast<idxT*>(
|
|
smem_buf +
|
|
round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
|
|
idx_smem_ += warp_id * WARP_SIZE;
|
|
}
|
|
|
|
__device__ void add(T const* in, idxT start, idxT end) {
|
|
idxT const end_for_fullwarp =
|
|
round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
|
|
for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
|
|
T val = (i < end) ? in[i] : dummy_;
|
|
add(val, i);
|
|
}
|
|
}
|
|
|
|
__device__ void add(T val, idxT idx) {
|
|
bool do_add;
|
|
if constexpr (is_stable) {
|
|
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
|
|
} else {
|
|
do_add = is_better_than<greater>(val, k_th_);
|
|
}
|
|
|
|
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
|
|
if (mask == 0) {
|
|
return;
|
|
}
|
|
|
|
int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
|
|
if (do_add && pos < WARP_SIZE) {
|
|
val_smem_[pos] = val;
|
|
idx_smem_[pos] = idx;
|
|
do_add = false;
|
|
}
|
|
smem_buf_len_ += __popc(mask);
|
|
if (smem_buf_len_ >= WARP_SIZE) {
|
|
__syncwarp();
|
|
merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
|
|
smem_buf_len_ -= WARP_SIZE;
|
|
}
|
|
if (do_add) {
|
|
pos -= WARP_SIZE;
|
|
val_smem_[pos] = val;
|
|
idx_smem_[pos] = idx;
|
|
}
|
|
__syncwarp();
|
|
}
|
|
|
|
__device__ void done() {
|
|
if (smem_buf_len_) {
|
|
T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
|
|
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
|
|
merge_buf_(val, idx);
|
|
}
|
|
|
|
// after done(), smem is used for merging results among warps
|
|
__syncthreads();
|
|
}
|
|
|
|
private:
|
|
__device__ void set_k_th_() {
|
|
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
|
|
if constexpr (is_stable) {
|
|
k_th_idx_ =
|
|
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
|
|
}
|
|
}
|
|
|
|
__device__ void merge_buf_(T val, idxT idx) {
|
|
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
|
|
|
|
T& old = val_arr_[max_arr_len_ - 1];
|
|
|
|
bool is_better;
|
|
if constexpr (is_stable) {
|
|
is_better =
|
|
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
|
|
} else {
|
|
is_better = is_better_than<greater>(val, old);
|
|
}
|
|
|
|
if (is_better) {
|
|
old = val;
|
|
idx_arr_[max_arr_len_ - 1] = idx;
|
|
}
|
|
|
|
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
|
val_arr_, idx_arr_);
|
|
|
|
set_k_th_();
|
|
}
|
|
|
|
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
|
|
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
|
|
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
|
|
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
|
|
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
|
|
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
|
|
|
|
T* val_smem_;
|
|
idxT* idx_smem_;
|
|
int smem_buf_len_ = 0;
|
|
|
|
T k_th_;
|
|
idxT k_th_idx_;
|
|
int const k_th_lane_;
|
|
}; // end class WarpSelect
|
|
} // namespace warp_topk
|
|
|
|
// Generic conversion helper: converts `val` to T_OUT through the implicit
// conversion from T_IN. Type pairs that need a dedicated intrinsic are
// handled by explicit specializations.
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
  T_OUT converted = val;
  return converted;
}
|
|
|
|
template <>
|
|
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
|
|
return __bfloat162float(val);
|
|
}
|
|
|
|
template <typename T>
|
|
__device__ inline T neg_inf() {
|
|
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
|
|
// so we need to cast from fp32
|
|
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
|
}
|
|
|
|
template <typename T>
|
|
__device__ inline bool is_finite(const T val) {
|
|
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
|
|
return cuda::std::isfinite(val);
|
|
#else
|
|
return isfinite(cuda_cast<float, T>(val));
|
|
#endif
|
|
}
|
|
|
|
template <typename T>
|
|
__device__ void topk_with_k2(T* output, T const* input,
|
|
cg::thread_block_tile<32> const& tile,
|
|
int32_t const lane_id,
|
|
int const num_experts_per_group) {
|
|
// Get the top2 per thread
|
|
T largest = neg_inf<T>();
|
|
T second_largest = neg_inf<T>();
|
|
|
|
if (num_experts_per_group > WARP_SIZE) {
|
|
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
|
T value = input[i];
|
|
if (value > largest) {
|
|
second_largest = largest;
|
|
largest = value;
|
|
} else if (value > second_largest) {
|
|
second_largest = value;
|
|
}
|
|
}
|
|
} else {
|
|
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
|
largest = input[i];
|
|
}
|
|
}
|
|
|
|
__syncwarp(); // Ensure all threads have valid data before reduction
|
|
// Get the top2 warpwise
|
|
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
|
|
|
T max2 = max1;
|
|
bool equal_to_max1 = (max1 == largest);
|
|
|
|
int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));
|
|
|
|
if (count_max1 == 1) {
|
|
largest = (largest == max1) ? second_largest : largest;
|
|
max2 = cg::reduce(tile, largest, cg::greater<T>());
|
|
}
|
|
|
|
if (lane_id == 0) {
|
|
*output = max1 + max2;
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
__global__ void topk_with_k2_kernel(T* output, T* input,
|
|
int64_t const num_tokens,
|
|
int64_t const num_cases,
|
|
int64_t const n_group,
|
|
int64_t const num_experts_per_group) {
|
|
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
|
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
|
|
|
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
|
|
if (case_id < num_cases) {
|
|
input += case_id * num_experts_per_group;
|
|
output += case_id;
|
|
|
|
cg::thread_block block = cg::this_thread_block();
|
|
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
|
|
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
|
asm volatile("griddepcontrol.wait;");
|
|
#endif
|
|
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
|
}
|
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
|
asm volatile("griddepcontrol.launch_dependents;");
|
|
#endif
|
|
}
|
|
|
|
template <typename T, typename IdxT>
|
|
__global__ void group_idx_and_topk_idx_kernel(
|
|
T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
|
|
T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
|
|
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
|
int64_t const num_experts_per_group, bool renormalize,
|
|
double routed_scaling_factor) {
|
|
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
|
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
|
int32_t case_id =
|
|
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
|
|
scores_with_bias += case_id * num_experts;
|
|
scores += case_id * num_experts;
|
|
group_scores += case_id * n_group;
|
|
topk_values += case_id * topk;
|
|
topk_indices += case_id * topk;
|
|
|
|
int32_t align_num_experts_per_group =
|
|
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
|
|
|
cg::thread_block block = cg::this_thread_block();
|
|
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
|
|
|
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
|
// store the target topk idx
|
|
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
|
|
T* s_topk_value =
|
|
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
|
warp_id * topk;
|
|
s_topk_idx += warp_id * topk;
|
|
|
|
T value = neg_inf<T>();
|
|
T topk_group_value = neg_inf<T>();
|
|
int32_t num_equalto_topkth_group;
|
|
|
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
|
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
|
|
// acqbulk because it's ptr arithmetic
|
|
#endif
|
|
|
|
if (case_id < num_tokens) {
|
|
// calculate group_idx
|
|
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
|
// The check is necessary to avoid abnormal input
|
|
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
|
|
value = group_scores[lane_id];
|
|
}
|
|
|
|
int count_equal_to_top_value = WARP_SIZE - n_group;
|
|
int pre_count_equal_to_top_value = 0;
|
|
// Use loop to find the largset top_group
|
|
while (count_equal_to_top_value < target_num_min) {
|
|
__syncwarp(); // Ensure all threads have valid data before reduction
|
|
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
|
if (value == topk_group_value) {
|
|
value = neg_inf<T>();
|
|
}
|
|
pre_count_equal_to_top_value = count_equal_to_top_value;
|
|
count_equal_to_top_value =
|
|
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
|
|
}
|
|
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
|
}
|
|
__syncthreads();
|
|
|
|
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
|
/* is_stable */ true>
|
|
queue((int32_t)topk, neg_inf<T>());
|
|
|
|
int count_equalto_topkth_group = 0;
|
|
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
|
|
if (case_id < num_tokens && if_proceed_next_topk) {
|
|
for (int i_group = 0; i_group < n_group; i_group++) {
|
|
if ((group_scores[i_group] > topk_group_value) ||
|
|
((group_scores[i_group] == topk_group_value) &&
|
|
(count_equalto_topkth_group < num_equalto_topkth_group))) {
|
|
int32_t offset = i_group * num_experts_per_group;
|
|
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
|
i += WARP_SIZE) {
|
|
T candidates = (i < num_experts_per_group) &&
|
|
is_finite(scores_with_bias[offset + i])
|
|
? scores_with_bias[offset + i]
|
|
: neg_inf<T>();
|
|
queue.add(candidates, offset + i);
|
|
}
|
|
if (group_scores[i_group] == topk_group_value) {
|
|
count_equalto_topkth_group++;
|
|
}
|
|
}
|
|
}
|
|
queue.done();
|
|
__syncwarp();
|
|
// Get the topk_idx
|
|
queue.dumpIdx(s_topk_idx);
|
|
__syncwarp();
|
|
}
|
|
|
|
// Load the valid score value
|
|
// Calculate the summation
|
|
float topk_sum = 1e-20;
|
|
if (case_id < num_tokens && if_proceed_next_topk) {
|
|
for (int i = lane_id;
|
|
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
|
i += WARP_SIZE) {
|
|
T value =
|
|
i < topk
|
|
? scores[s_topk_idx[i]]
|
|
: cuda_cast<T, float>(0.0f); // Load the valid value of expert
|
|
if (i < topk) {
|
|
s_topk_value[i] = value;
|
|
}
|
|
topk_sum +=
|
|
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
|
}
|
|
}
|
|
|
|
__syncthreads();
|
|
|
|
if (case_id < num_tokens) {
|
|
if (if_proceed_next_topk) {
|
|
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
|
float value;
|
|
if (renormalize) {
|
|
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
|
|
routed_scaling_factor;
|
|
} else {
|
|
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
|
}
|
|
topk_indices[i] = s_topk_idx[i];
|
|
topk_values[i] = cuda_cast<T, float>(value);
|
|
}
|
|
} else {
|
|
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
|
topk_indices[i] = i;
|
|
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
|
|
}
|
|
}
|
|
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
|
// default result.
|
|
}
|
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
|
asm volatile("griddepcontrol.launch_dependents;");
|
|
#endif
|
|
}
|
|
|
|
template <typename T, typename IdxT>
|
|
void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
|
|
IdxT* topk_indices, T* scores_with_bias,
|
|
int64_t const num_tokens, int64_t const num_experts,
|
|
int64_t const n_group, int64_t const topk_group,
|
|
int64_t const topk, bool const renormalize,
|
|
double const routed_scaling_factor, bool enable_pdl = false,
|
|
cudaStream_t const stream = 0) {
|
|
int64_t num_cases = num_tokens * n_group;
|
|
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
|
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
|
|
cudaLaunchConfig_t config;
|
|
config.gridDim = topk_with_k2_num_blocks;
|
|
config.blockDim = BLOCK_SIZE;
|
|
config.dynamicSmemBytes = 0;
|
|
config.stream = stream;
|
|
cudaLaunchAttribute attrs[1];
|
|
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
|
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
|
config.numAttrs = 1;
|
|
config.attrs = attrs;
|
|
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
|
|
num_tokens, num_cases, n_group, num_experts / n_group);
|
|
|
|
int64_t topk_with_k_group_num_blocks =
|
|
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
|
size_t dynamic_smem_in_bytes =
|
|
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
|
topk);
|
|
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
|
|
config.gridDim = topk_with_k_group_num_blocks;
|
|
config.blockDim = BLOCK_SIZE;
|
|
config.dynamicSmemBytes = dynamic_smem_in_bytes;
|
|
config.stream = stream;
|
|
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
|
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
|
config.numAttrs = 1;
|
|
config.attrs = attrs;
|
|
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
|
topk_values, topk_indices, scores_with_bias, num_tokens,
|
|
n_group, topk_group, topk, num_experts,
|
|
num_experts / n_group, renormalize, routed_scaling_factor);
|
|
}
|
|
|
|
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
|
template void invokeNoAuxTc<T, IdxT>( \
|
|
T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \
|
|
T * scores_with_bias, int64_t const num_tokens, \
|
|
int64_t const num_experts, int64_t const n_group, \
|
|
int64_t const topk_group, int64_t const topk, bool const renormalize, \
|
|
double const routed_scaling_factor, bool enable_pdl, \
|
|
cudaStream_t const stream);
|
|
|
|
INSTANTIATE_NOAUX_TC(float, int32_t);
|
|
INSTANTIATE_NOAUX_TC(half, int32_t);
|
|
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
|
|
} // end namespace moe
|
|
} // namespace vllm
|
|
|
|
/// Host entry point for grouped top-k expert routing.
///
/// Validates the arguments, allocates the scratch and output tensors on the
/// device of `scores_with_bias`, and dispatches vllm::moe::invokeNoAuxTc for
/// float16, float32 or bfloat16 on that device's current CUDA stream
/// (programmatic dependent launch disabled).
///
/// @param scores            tensor with the same dtype and element count as
///                          `scores_with_bias`; the returned weights are
///                          gathered from it
/// @param scores_with_bias  2-D CUDA tensor [num_tokens, num_experts] used for
///                          the selection itself
/// @param n_group           number of expert groups (1..32, divides
///                          num_experts)
/// @param topk_group        number of groups kept per token (<= n_group)
/// @param topk              number of experts kept per token (<= 32)
/// @param renormalize       forwarded to the kernel; divides the selected
///                          weights by their sum
/// @param routed_scaling_factor  multiplier applied to the selected weights
/// @return {topk_values [num_tokens, topk] in the input dtype,
///          topk_indices [num_tokens, topk] as int32}
/// @throws c10::Error on invalid arguments, std::invalid_argument on an
///         unsupported dtype
///
/// NOTE(review): the kernels index both inputs as dense row-major buffers;
/// contiguity is presumably guaranteed by the caller and is not checked here.
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
    double routed_scaling_factor) {
  // Check the rank before reading sizes 0 and 1.
  TORCH_CHECK(scores_with_bias.dim() == 2,
              "scores_with_bias must be a 2D Tensor");
  TORCH_CHECK(scores_with_bias.is_cuda(),
              "scores_with_bias must be a CUDA Tensor");

  auto const data_type = scores_with_bias.scalar_type();
  int64_t const num_tokens = scores_with_bias.size(0);
  int64_t const num_experts = scores_with_bias.size(1);

  // Both buffers are reinterpreted with the same element type and indexed
  // with the same [token, expert] offsets.
  TORCH_CHECK(scores.scalar_type() == data_type,
              "scores and scores_with_bias must have the same dtype");
  TORCH_CHECK(scores.numel() == scores_with_bias.numel(),
              "scores and scores_with_bias must have the same number of "
              "elements");

  // Guard the modulo below against division by zero.
  TORCH_CHECK(n_group > 0, "n_group should be greater than 0");
  TORCH_CHECK(num_experts % n_group == 0,
              "num_experts should be divisible by n_group");
  TORCH_CHECK(n_group <= 32,
              "n_group should be smaller than or equal to 32 for now");
  TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
  // The kernel's group-selection loop cannot terminate when more groups are
  // requested than exist.
  TORCH_CHECK(topk_group <= n_group,
              "topk_group should be smaller than or equal to n_group");

  // Allocate on the input's device so the buffers match the stream used below.
  auto const device = scores_with_bias.device();
  torch::Tensor group_scores = torch::empty(
      {num_tokens, n_group}, torch::dtype(data_type).device(device));
  torch::Tensor topk_values =
      torch::empty({num_tokens, topk}, torch::dtype(data_type).device(device));
  torch::Tensor topk_indices = torch::empty(
      {num_tokens, topk}, torch::dtype(torch::kInt32).device(device));

  auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());

  switch (data_type) {
    case torch::kFloat16:
      vllm::moe::invokeNoAuxTc<half, int32_t>(
          reinterpret_cast<half*>(scores.mutable_data_ptr()),
          reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    case torch::kFloat32:
      vllm::moe::invokeNoAuxTc<float, int32_t>(
          reinterpret_cast<float*>(scores.mutable_data_ptr()),
          reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    case torch::kBFloat16:
      vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
          reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
          num_tokens, num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    default:
      throw std::invalid_argument(
          "Invalid dtype, only supports float16, float32, and bfloat16");
  }
  return {topk_values, topk_indices};
}
|