#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
#include "core/math.hpp"
#include "../cuda_compat.h"
#include "dispatch_utils.h"

#include "quantization/w8a8/fp8/common.cuh"

#include <c10/util/Float8_e4m3fn.h>

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
  #include <cuda_fp8.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>
  #include <hip/hip_fp8.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
  #if defined(HIP_FP8_TYPE_OCP)
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
  #else
// ROCm 6.2 fallback: only *_fnuz types exist
typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3;
  #endif
#endif

#include "core/registration.h"

namespace vllm {

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          typename fp8_type>
__global__ void act_and_mul_quant_kernel(
    fp8_type* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const float* scale, const int d) {
  const int32_t blocks_per_token = gridDim.y;

  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

  // We don't expect the hidden dimension to exceed 32 bits so int32 should
  // be safe here.
  const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
  const int32_t elems_per_block =
      round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
  const int32_t block_start = blockIdx.y * elems_per_block;
  int32_t block_end = block_start + elems_per_block;
  block_end = block_end > d ? d : block_end;

  // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
  // is very large
  const int64_t token_idx = blockIdx.x;
  const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
  const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
  fp8_type* __restrict__ out_ptr = out + token_idx * d;

  // 128-bit vectorized code
  const int32_t vec_loop_end =
      round_to_previous_multiple_of(elems_per_128bit_load, block_end);
  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
  const int32_t vec_start_idx = block_start / elems_per_128bit_load;

  const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
  const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
  int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);

  float inverted_scale = 1 / *scale;
#pragma unroll
  for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
       vec_idx += blockDim.x) {
    const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
    const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
    using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
    using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;

    scalar_64bit_vec_t out_vec;
    const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
    const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);

#pragma unroll
    for (int i = 0; i < elems_per_128bit_load; i++) {
      out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
          ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
    }

    out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
  }

  // Scalar cleanup code
  if (block_end > vec_loop_end) {
    for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
         idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
      out_ptr[idx] =
          scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
    }
  }
}

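// Illustrative reference (not part of this file's build): per output element,
// act_and_mul_quant_kernel above computes silu(x[i]) * y[i] and converts the
// product to fp8 using the inverted scale. A minimal host-side sketch of that
// math, assuming a hypothetical helper fp8_quant() that rounds and clamps to
// the fp8 range (the kernel uses scaled_fp8_conversion from common.cuh):
//
//   for (int i = 0; i < d; ++i) {
//     float x = input[token * 2 * d + i];      // gate half
//     float y = input[token * 2 * d + d + i];  // up half
//     float act = x / (1.0f + expf(-x));       // silu
//     out[token * d + i] = fp8_quant(act * y * (1.0f / scale));
//   }
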
__device__ __forceinline__ float silu(float x) {
  return (__fdividef(x, (1.f + expf(-x))));
}

__device__ __forceinline__ float2 silu2(float2 x) {
  return make_float2(silu(x.x), silu(x.y));
}

#ifndef USE_ROCM
__device__ __forceinline__ float warp_max(float v) {
  static constexpr unsigned FULL_MASK = 0xffffffffu;
  for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
    v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset));
  }
  return v;
}

__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) {
  static constexpr unsigned FULL_MASK = 0xffffffffu;
  for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
    v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset));
  }
  return v;
}
#endif

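// Note on warp_max: the XOR-shuffle loop is a butterfly reduction. With
// WARP_SIZE == 32 it runs with offsets 1, 2, 4, 8, 16; after the final step
// every lane in the warp holds the maximum over all 32 lanes, so no extra
// broadcast is needed before the per-group scale is computed below.
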
template <typename T, typename U>
__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  auto smem_ptr = reinterpret_cast<void*>(_smem_ptr);
  auto glob_ptr = reinterpret_cast<const void*>(_glob_ptr);
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
#else
  _smem_ptr[0] = _glob_ptr[0];
#endif
}

__device__ __forceinline__ void cp_async_fence() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.commit_group;\n" ::);
#else
#endif
}

template <int N>
__device__ __forceinline__ void cp_async_wait() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#else
#endif
}

template <>
__device__ __forceinline__ void cp_async_wait<0>() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.wait_all;\n" ::);
#else
#endif
}

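// Usage pattern (as in silu_mul_fp8_quant_deep_gemm_kernel below): each stage
// issues 16-byte cp_async4() copies into shared memory, cp_async_fence()
// commits them as one group, and cp_async_wait<NUM_STAGES - 2>() followed by
// __syncthreads() waits until at most NUM_STAGES - 2 copy groups are still in
// flight, so compute on the oldest stage can overlap with loads for the next.
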
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
  // fminf/fmaxf are available on all targets, so no architecture guard is
  // needed here (a guarded body with no fallback return would be UB).
  return fminf(mmax, fmaxf(v, mmin));
}

__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
                                              __nv_bfloat16 mmin,
                                              __nv_bfloat16 mmax) {
  return __hmin(mmax, __hmax(v, mmin));
}

__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v,
                                               __nv_bfloat162 mmin,
                                               __nv_bfloat162 mmax) {
  return __hmin2(mmax, __hmax2(v, mmin));
}

// We use the following values for fp8 min/max:
//   __nv_fp8_e4m3      = (-448, +448)
//   __nv_fp8_e4m3_fnuz = (-240.0, +240.0)
// It is currently assumed that only these two e4m3 variants are used.
template <class T>
constexpr __nv_bfloat16 get_fp8_max() {
  static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
                std::is_same_v<T, c10::Float8_e4m3fnuz>);
  if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376});
  } else {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264});
  }
}

template <class T>
constexpr __nv_bfloat16 get_fp8_min() {
  static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
                std::is_same_v<T, c10::Float8_e4m3fnuz>);
  if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144});
  } else {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
  }
}

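// For reference, the raw bfloat16 bit patterns above decode as follows:
//   17376 = 0x43E0 ->  448.0    17264 = 0x4370 ->  240.0
//   50144 = 0xC3E0 -> -448.0    50032 = 0xC370 -> -240.0
// i.e. get_fp8_max()/get_fp8_min() return the representable range of the
// selected fp8 e4m3 variant as bf16 constants usable in constexpr context.
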
#ifndef USE_ROCM
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
          int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
          int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
    const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
    float* __restrict__ _y_s, const int32_t* __restrict__ counts,

    // sizes
    int H, int G,

    // strides (in elements)
    Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
    Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
    Idx_t stride_ys_g, Idx_t stride_counts_e) {
  static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
  static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
  // We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
  static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});

  // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
  static constexpr int32_t BFLOAT16_PER_GROUP = 8;

  // We split the shared memory in half, corresponding to gate and up matrices:
  // [...gate_i, ...up_i] where 0 <= i < stages.
  static constexpr int32_t S_NUM_128 =
      2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
  static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
  static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
  static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
  __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];

  const int32_t tid = threadIdx.x;
  const int32_t warp_id = tid / WARP_SIZE;
  const int32_t lane_id = tid % WARP_SIZE;

  auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);

  // block handles one (expert e, group g)
  int32_t pid = blockIdx.x;
  int32_t e = pid / G;
  int32_t g = pid % G;

  const int32_t n_tokens = counts[e * stride_counts_e];

  if (!n_tokens) {
    return;  // Exit ASAP.
  }

  const Idx_t stride_i_t_128 = stride_i_t / 8u;

  int32_t n_tokens_lower, n_tokens_upper;

  // Each block i iterates over tokens of a slice of n_tokens =
  // expert_counts[i], with the size of chunk being
  // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
  // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
  if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
    // Specialize this, but can be likely fused.
    if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
      return;
    }
    n_tokens_lower = blockIdx.y;
    n_tokens_upper = blockIdx.y + 1;
  } else {
    auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
    auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
    auto calc_id = [&](int32_t id) {
      if (id < residual) {
        return min(n_tokens, id * (chunk_size + 1));
      } else {
        return min(n_tokens, id * chunk_size + residual);
      }
    };
    n_tokens_lower = calc_id(blockIdx.y);
    n_tokens_upper = calc_id(blockIdx.y + 1);
  }

  if (n_tokens_lower >= n_tokens_upper) {
    return;
  }

  // We do calculations here, using constexpr wherever possible.
  const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
  const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
  const Idx_t base_yq =
      e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
  Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
  auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
  auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
                      stride_i_t_128 * n_tokens_lower;
  auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
  auto y_s_ptr =
      _y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
  auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
                 stride_yq_t * n_tokens_lower + 4 * lane_id;
  int32_t t_load = n_tokens_lower, load_stage_id = 0;
  auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
  auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
  int32_t stage_offset{};

  static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
  static constexpr int32_t LOAD_STAGE_MOD =
      NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);

  // Two halves of all threads in a block conduct global loads for gate and
  // up, respectively.
  auto load_and_advance_y_pred = [&] {
    if (t_load < n_tokens_upper) {
      auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
      auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;

      // It is very important that LOAD_STAGE_SIZE is constexpr to avoid
      // unnecessary ALU ops.
      stage_offset += LOAD_STAGE_SIZE;
      stage_offset %= LOAD_STAGE_MOD;

      if (tid < HALF_THREAD_COUNT) {
        cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
        gate_128_ptr += stride_i_t_128;
      } else {
        cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
        up_128_ptr += stride_i_t_128;
      }
      ++t_load;
      ++load_stage_id;
    }
    // We fence even if there is nothing to load to simplify pipelining.
    cp_async_fence();
  };

#pragma unroll
  for (int i = 0; i < NUM_STAGES - 1; i++) {
    load_and_advance_y_pred();
  }

  __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
                              s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
                          lane_id;
  __int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;

  static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
  static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;

  int32_t compute_pipeline_offset_64 = 0;

  for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
    __nv_bfloat162 results_bf162[2];

    cp_async_wait<NUM_STAGES - 2>();
    __syncthreads();

    // We double-buffer pipelined loads so that the next load will
    // concurrently run with compute without overwrites.
    load_and_advance_y_pred();

    auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
    auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;

    // STAGE_SIZE must also be constexpr!
    compute_pipeline_offset_64 += STAGE_SIZE;
    compute_pipeline_offset_64 %= STAGE_MOD;

    // Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
    __int64_t gate64 = *s_gate_compute_64;
    __nv_bfloat162* s_gate_compute_32 =
        reinterpret_cast<__nv_bfloat162*>(&gate64);

    __int64_t up64 = *s_up_compute_64;
    __nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);

#pragma unroll
    for (int i = 0; i < 2; i++) {
      // For silu, we make sure that div is emitted.
      float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
      results_bf162[i] = __float22bfloat162_rn(gate);
    }

#pragma unroll
    for (int i = 0; i < 2; i++) {
      results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
    }

    auto _y_max2 =
        __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));

    __nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y));

    // An entire group is assigned to a single warp, so a simple warp reduce
    // is used.
    __nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;

    if constexpr (USE_UE8M0) {
      y_s = hexp2(hceil(hlog2(y_s)));
    }

    auto inv_y = __float2bfloat16_rn(1.f) / y_s;

    auto y_s2 = make_bfloat162(inv_y, inv_y);

#pragma unroll
    for (int32_t i = 0; i < 2; ++i) {
      results_bf162[i] =
          clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
               __bfloat162bfloat162(fp8_max));
    }

    auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
    *reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
    y_q_ptr += stride_yq_t;

    if (lane_id == 0) {
      *y_s_ptr = y_s;
      y_s_ptr += stride_ys_t;
    }
  }
}
#endif

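// Illustrative note on the token partitioning above: with n_tokens = 10 and
// NUM_PARALLEL_TOKENS = 4, chunk_size = 2 and residual = 2, so calc_id()
// yields the boundaries 0, 3, 6, 8, 10 and the four blockIdx.y slices cover
// [0, 3), [3, 6), [6, 8), [8, 10). Chunk sizes differ by at most one token,
// which keeps the per-block work balanced.
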
}  // namespace vllm

// Launch activation, gating, and quantize kernel.
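// Grid heuristic used by the macro below: one block in x per token, and the
// y dimension splits each token's hidden dimension across 4 blocks when there
// are at most 16 tokens, 2 blocks for 17-32 tokens, and a single block once
// there are more than 32 tokens, with up to 512 threads per block.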
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                               \
  int d = input.size(-1) / 2;                                               \
  int64_t num_tokens = input.numel() / input.size(-1);                      \
  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);     \
  dim3 block(std::min(d, 512));                                             \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));         \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();             \
  VLLM_DISPATCH_FLOATING_TYPES(                                             \
      input.scalar_type(), "act_and_mul_kernel", [&] {                      \
        VLLM_DISPATCH_FP8_TYPES(                                            \
            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {  \
              vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>,    \
                                             fp8_t>                         \
                  <<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(),       \
                                               input.data_ptr<scalar_t>(),  \
                                               scale.data_ptr<float>(), d); \
            });                                                             \
      });

void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
                        torch::Tensor& scale) {
  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
              out.dtype() == torch::kFloat8_e4m3fnuz);
  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
              input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.size(-1) % 2 == 0);
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

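// Hypothetical host-side usage sketch (assumes a CUDA device is present;
// the tensor names and shapes here are illustrative only):
//
//   auto opts = torch::TensorOptions().device(torch::kCUDA);
//   torch::Tensor input = torch::randn({num_tokens, 2 * d},
//                                      opts.dtype(torch::kBFloat16));
//   torch::Tensor out = torch::empty({num_tokens, d},
//                                    opts.dtype(torch::kFloat8_e4m3fn));
//   torch::Tensor scale = torch::ones({1}, opts.dtype(torch::kFloat32));
//   silu_and_mul_quant(out, input, scale);
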
void silu_mul_fp8_quant_deep_gemm_cuda(
    const at::Tensor& input,   // (E, T, 2*H)
    const at::Tensor& counts,  // (E)
    at::Tensor& y_q,           // (E, T, H) [OUT]
    at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
    int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
#ifndef USE_ROCM
  // This kernel relies heavily on cp.async and fp8 support.
  // This kernel currently only supports H % 128 == 0 and assumes a
  // fixed GROUP_SIZE of 128.
  TORCH_CHECK(input.dtype() == torch::kBFloat16);
  TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
              y_q.dtype() == torch::kFloat8_e4m3fnuz);
  TORCH_CHECK(y_s.dtype() == torch::kFloat32);
  TORCH_CHECK(input.size(-1) % 256 == 0);

  // Check that num_parallel_tokens is a power of 2 between 1 and 64.
  TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
  TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));

  using Idx_t = int64_t;

  Idx_t E = input.size(0);
  Idx_t T = input.size(1);
  Idx_t H = input.size(2) / 2;
  Idx_t stride_i_e = input.stride(0);
  Idx_t stride_i_t = input.stride(1);
  Idx_t stride_i_h = input.stride(2);
  Idx_t stride_yq_e = y_q.stride(0);
  Idx_t stride_yq_t = y_q.stride(1);
  Idx_t stride_yq_h = y_q.stride(2);
  Idx_t stride_ys_e = y_s.stride(0);
  Idx_t stride_ys_t = y_s.stride(1);
  Idx_t stride_ys_g = y_s.stride(2);

  Idx_t stride_counts_e = counts.stride(0);

  static constexpr int GROUP_SIZE = 128;

#define KERNEL_FN                                                           \
  if (use_ue8m0) {                                                          \
    vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t,      \
                                              NUM_PARALLEL_TOKENS, true>    \
        <<<grid, block, 0, stream>>>(                                       \
            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),             \
            (fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(),                  \
            reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G,       \
            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,   \
            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g,             \
            stride_counts_e);                                               \
  } else {                                                                  \
    vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t,      \
                                              NUM_PARALLEL_TOKENS, false>   \
        <<<grid, block, 0, stream>>>(                                       \
            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),             \
            (fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(),                  \
            reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G,       \
            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,   \
            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g,             \
            stride_counts_e);                                               \
  }

#define KERNEL_CALL_H                                       \
  if (H % (4 * GROUP_SIZE) == 0) {                          \
    static constexpr int NUM_WARPS = 4;                     \
    populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
    KERNEL_FN                                               \
  } else {                                                  \
    static constexpr int NUM_WARPS = 1;                     \
    populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
    KERNEL_FN                                               \
  }

#define KERNEL_CALL_TOP_LEVEL                      \
  if (num_parallel_tokens == 1) {                  \
    static constexpr int NUM_PARALLEL_TOKENS = 1;  \
    KERNEL_CALL_H                                  \
  } else if (num_parallel_tokens == 2) {           \
    static constexpr int NUM_PARALLEL_TOKENS = 2;  \
    KERNEL_CALL_H                                  \
  } else if (num_parallel_tokens == 4) {           \
    static constexpr int NUM_PARALLEL_TOKENS = 4;  \
    KERNEL_CALL_H                                  \
  } else if (num_parallel_tokens == 8) {           \
    static constexpr int NUM_PARALLEL_TOKENS = 8;  \
    KERNEL_CALL_H                                  \
  } else if (num_parallel_tokens == 16) {          \
    static constexpr int NUM_PARALLEL_TOKENS = 16; \
    KERNEL_CALL_H                                  \
  } else if (num_parallel_tokens == 32) {          \
    static constexpr int NUM_PARALLEL_TOKENS = 32; \
    KERNEL_CALL_H                                  \
  } else if (num_parallel_tokens == 64) {          \
    static constexpr int NUM_PARALLEL_TOKENS = 64; \
    KERNEL_CALL_H                                  \
  }

  Idx_t G;
  dim3 block, grid;
  auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
    G = H / Idx_t(group_size * num_warps);
    grid = dim3(E * G, _num_parallel_tokens);
    block = dim3(num_warps * WARP_SIZE);
  };

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
                          "silu_mul_fp8_quant_deep_gemm_kernel",
                          [&] { KERNEL_CALL_TOP_LEVEL });

#endif
}