mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 13:05:01 +08:00
Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: Seiji Eicher <seiji@anyscale.com> Signed-off-by: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com> Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Signed-off-by: Kosseila (CloudThrill) <klouddude@gmail.com> Signed-off-by: frankwang28 <frank.wbb@hotmail.com> Signed-off-by: Frank Wang <41319051+frankwang28@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: zixi-qi <qizixi@meta.com> Signed-off-by: Bram Wasti <bwasti@meta.com> Signed-off-by: Naman Lalit <nl2688@nyu.edu> Signed-off-by: Chenheli Hua <huachenheli@outlook.com> Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com> Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Patrick Toulme <ptoulme@meta.com> Signed-off-by: Patrick Toulme <pctoulme+1@gmail.com> Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Signed-off-by: Clayton Coleman <smarterclayton@gmail.com> Signed-off-by: Jialin Ouyang <jialino@meta.com> Signed-off-by: Jialin Ouyang 
<Jialin.Ouyang@gmail.com> Signed-off-by: Weiliang Liu <weiliangl@nvidia.com> Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Juechen Liu <jueliu@meta.com> Signed-off-by: simon-mo <simon.mo@hey.com> Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: yingjun-mou <renzomou@gmail.com> Signed-off-by: zhoukz <me@zhoukz.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Lee Nau <lnau@nvidia.com> Signed-off-by: adabeyta <aabeyta@redhat.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Lucia Fang <fanglu@meta.com> Signed-off-by: a120092009 <zhaoty0121@gmail.com> Signed-off-by: sergiopaniego <sergiopaniegoblanco@gmail.com> Signed-off-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> Signed-off-by: wangyafeng <wangyafeng@baidu.com> Signed-off-by: Lehua Ding <lehuading@tencent.com> Signed-off-by: lyd1992 <liuyudong@iscas.ac.cn> Signed-off-by: ihb2032 <1355790728@qq.com> Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com> Signed-off-by: anion <1005128408@qq.com> Signed-off-by: Anion <123177548+Anionex@users.noreply.github.com> Signed-off-by: Pavani Majety <pmajety@nvidia.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com> Signed-off-by: David Ben-David <davidb@pliops.com> Signed-off-by: Andrew Xia <axia@meta.com> 
Signed-off-by: Andrew Xia <axia@fb.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Salvatore Cena <cena@cenas.it> Signed-off-by: padg9912 <phone.and.desktop@gmail.com> Signed-off-by: nadathurv <work.vnadathur@gmail.com> Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: billishyahao <bill.he@amd.com> Signed-off-by: Nathan Scott <nathans@redhat.com> Signed-off-by: Kenichi Maehashi <maehashi@preferred.jp> Signed-off-by: Johnny <johnnynuca14@gmail.com> Signed-off-by: johnnynunez <johnnynuca14@gmail.com> Signed-off-by: Johnny <johnnync13@gmail.com> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: Hosang Yoon <hosang.yoon@amd.com> Signed-off-by: Jerry Zhang <jerryzh168@gmail.com> Signed-off-by: Peter Schuurman <psch@google.com> Signed-off-by: Huy Do <huydhn@gmail.com> Signed-off-by: leo-pony <nengjunma@outlook.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: zhewenli <zhewenli@meta.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: huijjj <huijong.jeong@squeezebits.com> Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com> Signed-off-by: kyt <eluban4532@gmail.com> Signed-off-by: Egor <e.a.krivov@gmail.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Paul Pak <paulpak58@gmail.com> Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: Xiang Si <sixiang@google.com> Signed-off-by: Aleksandr Samarin 
<astrlrd@nebius.com> Signed-off-by: Jun Jiang <jasl9187@hotmail.com> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: Chendi.Xue <chendi.xue@intel.com> Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Jonas M. 
Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: Seiji Eicher 
<58963096+eicherseiji@users.noreply.github.com> Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: 阿丹(adan) <47373076+LDLINGLINGLING@users.noreply.github.com> Co-authored-by: liudan <adan@minicpm.com> Co-authored-by: liudan <liudan@qq.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Clouddude <kouss.hd@gmail.com> Co-authored-by: Frank Wang <41319051+frankwang28@users.noreply.github.com> Co-authored-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: qizixi <22851944+zixi-qi@users.noreply.github.com> Co-authored-by: Bram Wasti <bwasti@fb.com> Co-authored-by: Naman Lalit <nl2688@nyu.edu> Co-authored-by: Chenheli Hua <huachenheli@outlook.com> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com> Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Xiaohan Zou <renovamenzxh@gmail.com> Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Patrick C. 
Toulme <135739773+patrick-toulme@users.noreply.github.com> Co-authored-by: Clayton Coleman <smarterclayton@gmail.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: Jialin Ouyang <jialino@meta.com> Co-authored-by: weiliang <weiliangl@nvidia.com> Co-authored-by: Yuxuan Zhang <2448370773@qq.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Juechen Liu <grinchcoder@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Yingjun Mou <renzomou@gmail.com> Co-authored-by: Zhou Jiahao <me@zhoukz.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Lee Nau <lee.nau@gmail.com> Co-authored-by: Adrian Abeyta <aabeyta@redhat.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: acisseJZhong <40467976+acisseJZhong@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucia Fang <fanglu@meta.com> Co-authored-by: Siyuan Fu <siyuanf@nvidia.com> Co-authored-by: Xiaozhu Meng <mxz297@gmail.com> Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Co-authored-by: a120092009 <33205509+a120092009@users.noreply.github.com> Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> Co-authored-by: CSWYF3634076 <wangyafeng@baidu.com> Co-authored-by: Lehua Ding <lehuading@tencent.com> Co-authored-by: Reza Barazesh <3146276+rzabarazesh@users.noreply.github.com> Co-authored-by: ihb2032 <40718643+ihb2032@users.noreply.github.com> Co-authored-by: Asaf 
Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Co-authored-by: Anion <123177548+Anionex@users.noreply.github.com> Co-authored-by: Pavani Majety <pmajety@nvidia.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: cjackal <44624812+cjackal@users.noreply.github.com> Co-authored-by: David Ben-David <sdavidbd@gmail.com> Co-authored-by: David Ben-David <davidb@pliops.com> Co-authored-by: Andrew Xia <axia@mit.edu> Co-authored-by: Andrew Xia <axia@fb.com> Co-authored-by: Salvatore Cena <cena@cenas.it> Co-authored-by: Param <psch@cs.unc.edu> Co-authored-by: Zhewen Li <zhewenli@meta.com> Co-authored-by: nadathurv <work.vnadathur@gmail.com> Co-authored-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: billishyahao <bill.he@amd.com> Co-authored-by: Nathan Scott <natoscott@users.noreply.github.com> Co-authored-by: Kenichi Maehashi <939877+kmaehashi@users.noreply.github.com> Co-authored-by: Johnny <johnnync13@gmail.com> Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: Hosang <156028780+hyoon1@users.noreply.github.com> Co-authored-by: Jerry Zhang <jerryzh168@gmail.com> Co-authored-by: pwschuurman <psch@google.com> Co-authored-by: Huy Do <huydhn@gmail.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: Varun Sundar Rabindranath 
<varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Liu-congo <1502632128@qq.com> Co-authored-by: HUIJONG JEONG <64083281+huijjj@users.noreply.github.com> Co-authored-by: Yannick Schnider <Yannick.Schnider1@ibm.com> Co-authored-by: kyt <eluban4532@gmail.com> Co-authored-by: Egor <e.a.krivov@gmail.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Paul Pak <52512091+paulpak58@users.noreply.github.com> Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com> Co-authored-by: Xiang Si <sixiang@google.com> Co-authored-by: Aleksandr Samarin <samarin_ad@mail.ru> Co-authored-by: Jun Jiang <jasl9187@hotmail.com> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Nikhil G <nrghosh@users.noreply.github.com>
215 lines
9.0 KiB
Plaintext
215 lines
9.0 KiB
Plaintext
#include <ATen/cuda/CUDAContext.h>
|
|
|
|
#include "quantization/w8a8/per_token_group_quant_8bit.h"
|
|
|
|
#include <cmath>
|
|
|
|
#include <cuda_fp8.h>
|
|
|
|
#include <torch/all.h>
|
|
|
|
#include "quantization/vectorization.cuh"
|
|
#include "quantization/vectorization_utils.cuh"
|
|
#include "dispatch_utils.h"
|
|
|
|
// Butterfly max-reduction over the 16 lanes of the calling thread's half-warp.
// After the four xor-shuffle steps every lane of that half holds the maximum
// of the 16 input values. All 16 lanes named by the mask must reach this call
// together, as required by __shfl_xor_sync.
__device__ __forceinline__ float GroupReduceMax(float val) {
  // Lanes 0-15 and 16-31 of a warp reduce independently; pick our half.
  const unsigned half_warp_mask =
      (threadIdx.x & 31u) < 16u ? 0x0000ffffu : 0xffff0000u;

#pragma unroll
  for (int offset = 8; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(half_warp_mask, val, offset));
  }
  return val;
}
|
|
|
|
// Quantizes contiguous groups of `group_size` input elements to DST_DTYPE.
//
// Each group is handled by 16 consecutive threads (threads_per_group), and a
// block handles `groups_per_block` groups. The dynamic shared-memory allocation
// must hold `groups_per_block * group_size` elements of T, because each group is
// staged there so that global memory is read only once.
//
// Per group:
//   absmax = max(eps, max |x|)
//   scale  = absmax / max_8bit   (rounded up to a power of two if SCALE_UE8M0)
//   q      = clamp(x / scale, min_8bit, max_8bit), converted to DST_DTYPE
//
// The scale is written by lane 0 of the group:
//   - row-major: at element index `global_group_id` of output_s.
//   - IS_COLUMN_MAJOR: at a strided location derived from `scale_num_rows`
//     (the launcher in this file passes output_s.size(1), i.e. scales per row)
//     and `scale_stride` (output_s.stride(1)).
//
// `num_groups` is unused here. The launcher sizes the grid so that every
// block is full, and no bounds check is performed.
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
          bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    scale_packed_t* __restrict__ output_s, const int group_size,
    const int num_groups, const int groups_per_block, const float eps,
    const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
    const int scale_stride = 0) {
  // Must match the 16-lane width reduced over by GroupReduceMax.
  constexpr int threads_per_group = 16;
  const int64_t local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  // Widen before multiplying so the product cannot wrap in 32-bit arithmetic.
  const int64_t block_group_id =
      static_cast<int64_t>(blockIdx.x) * groups_per_block;
  const int64_t global_group_id = block_group_id + local_group_id;
  const int64_t block_group_offset = global_group_id * group_size;

  // Seeding with eps keeps the scale strictly positive for all-zero groups.
  float local_absmax = eps;

  using scale_element_t = float;
  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output =
      static_cast<DST_DTYPE*>(output_q) + block_group_offset;

  // Both layouts address output_s in units of scale elements (floats), so a
  // packed scale_packed_t is handled consistently in either branch.
  scale_element_t* scale_output;
  if constexpr (IS_COLUMN_MAJOR) {
    constexpr int64_t num_elems_per_pack =
        static_cast<int64_t>(sizeof(scale_packed_t) / sizeof(scale_element_t));
    const int64_t scale_num_rows_element =
        static_cast<int64_t>(scale_num_rows) * num_elems_per_pack;
    const int64_t row_idx = global_group_id / scale_num_rows_element;
    const int64_t col_idx_raw = global_group_id % scale_num_rows_element;
    const int64_t col_idx = col_idx_raw / num_elems_per_pack;
    const int64_t pack_idx = col_idx_raw % num_elems_per_pack;
    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
                   (col_idx * scale_stride * num_elems_per_pack +
                    row_idx * num_elems_per_pack + pack_idx);
  } else {
    scale_output =
        reinterpret_cast<scale_element_t*>(output_s) + global_group_id;
  }

  // Shared memory caches each group's data to avoid a second DRAM read.
  extern __shared__ __align__(16) char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);
  T* smem_group = smem + local_group_id * group_size;

  constexpr int vec_size = 16 / sizeof(T);
  using vec_t = vllm::vec_n_t<T, vec_size>;

  // Pass 1: copy global -> shared while accumulating this thread's absmax.
  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
    float abs_v = fabsf(static_cast<float>(src));
    local_absmax = fmaxf(local_absmax, abs_v);
    dst = src;
  };

  vllm::vectorize_with_alignment<vec_size>(
      group_input,        // in
      smem_group,         // out (shared)
      group_size,         // elements per group
      lane_id,            // thread id
      threads_per_group,  // stride in group
      scalar_op_cache);   // scalar handler

  local_absmax = GroupReduceMax(local_absmax);

  float y_s = local_absmax / max_8bit;
  if constexpr (SCALE_UE8M0) {
    // Round the scale up to the next power of two.
    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
  }

  scale_element_t y_s_quant = y_s;

  if (lane_id == 0) {
    *scale_output = y_s_quant;
  }

  // The second pass may map shared-memory elements to different threads than
  // the first (alignment of the destination can differ), so make all shared
  // writes visible before reading them back.
  __syncthreads();

  // Pass 2: quantize shared -> global 8-bit.
  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
    dst = DST_DTYPE(q);
  };

  vllm::vectorize_with_alignment<vec_size>(
      smem_group,         // in (shared)
      group_output,       // out (global quant tensor)
      group_size,         // elements
      lane_id,            // tid
      threads_per_group,  // stride
      scalar_op_quant);   // scalar handler
}
|
|
|
|
// Host launcher for per-token-group 8-bit quantization.
//
// Splits `input` into contiguous groups of `group_size` elements. It writes one
// float scale per group into `output_s` and the quantized values into
// `output_q`. The kernel instantiation is chosen from:
//   - output_q's dtype: Float8_e4m3fn -> __nv_fp8_e4m3, Char -> int8_t.
//   - the layout of output_s: column-major when stride(0) < stride(1).
//   - scale_ue8m0: round each scale up to a power of two.
//
// Parameters:
//   input       contiguous tensor of a floating dtype accepted by
//               VLLM_DISPATCH_FLOATING_TYPES; numel must be a multiple of
//               group_size.
//   output_q    contiguous destination with the same numel as input.
//   output_s    2-D float32 tensor receiving the per-group scales.
//   group_size  elements sharing one scale; must be > 0.
//   eps         lower bound applied to each group's absmax.
//   min_8bit    lower clamp bound for quantized values.
//   max_8bit    upper clamp bound; scale = absmax / max_8bit.
//   scale_ue8m0 whether to round scales up to a power of two.
//
// Throws (via TORCH_CHECK) on invalid shapes, dtypes or group_size.
void per_token_group_quant_8bit(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double min_8bit, double max_8bit,
                                bool scale_ue8m0) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(output_q.is_contiguous());
  // Validate group_size before it is used as a divisor below.
  TORCH_CHECK(group_size > 0,
              "per_token_group_quant_8bit: group_size must be positive, got ",
              group_size);
  TORCH_CHECK(input.numel() % group_size == 0);
  // The kernel writes exactly input.numel() quantized values.
  TORCH_CHECK(output_q.numel() == input.numel(),
              "per_token_group_quant_8bit: output_q must have the same number "
              "of elements as input");
  TORCH_CHECK(output_s.dim() == 2);
  // Every launch below passes output_s to the kernel as float*.
  TORCH_CHECK(output_s.scalar_type() == at::ScalarType::Float,
              "per_token_group_quant_8bit: output_s must be float32, got ",
              output_s.scalar_type());

  const auto dst_type = output_q.scalar_type();
  // Any other dtype would fall through the dispatch below without launching
  // a kernel, silently leaving output_q and output_s unwritten.
  TORCH_CHECK(dst_type == at::ScalarType::Float8_e4m3fn ||
                  dst_type == at::ScalarType::Char,
              "per_token_group_quant_8bit: unsupported output dtype ",
              dst_type);

  const int num_groups = input.numel() / group_size;
  // A zero-sized grid is an invalid launch configuration, and there is nothing
  // to quantize anyway.
  if (num_groups == 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Must match threads_per_group inside the kernel.
  constexpr int THREADS_PER_GROUP = 16;

  // Largest power-of-two divisor of num_groups (up to 16), so that every block
  // is full. The kernel performs no bounds check.
  int groups_per_block = 1;
  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  // Each block stages groups_per_block groups in dynamic shared memory. Halve
  // the block (which keeps it a divisor of num_groups) until the request fits
  // the 48 KiB available without opting into a larger carve-out.
  constexpr size_t kMaxDynamicSmemBytes = 48 * 1024;
  const size_t bytes_per_group = static_cast<size_t>(group_size) *
                                 static_cast<size_t>(input.element_size());
  while (groups_per_block > 1 &&
         static_cast<size_t>(groups_per_block) * bytes_per_group >
             kMaxDynamicSmemBytes) {
    groups_per_block /= 2;
  }

  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
  const int scale_num_rows = output_s.size(1);
  const int scale_stride = output_s.stride(1);

  const int group_size_i = static_cast<int>(group_size);
  const float eps_f = static_cast<float>(eps);
  const float min_8bit_f = static_cast<float>(min_8bit);
  const float max_8bit_f = static_cast<float>(max_8bit);

#define LAUNCH_KERNEL(T, DST_DTYPE)                                          \
  do {                                                                       \
    dim3 grid(num_blocks);                                                   \
    dim3 block(num_threads);                                                 \
    size_t smem_bytes =                                                      \
        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);     \
    const T* input_ptr = static_cast<const T*>(input.data_ptr());            \
    float* scale_ptr = static_cast<float*>(output_s.data_ptr());             \
    if (is_column_major) {                                                   \
      if (scale_ue8m0) {                                                     \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true>          \
            <<<grid, block, smem_bytes, stream>>>(                           \
                input_ptr, output_q.data_ptr(), scale_ptr, group_size_i,     \
                num_groups, groups_per_block, eps_f, min_8bit_f, max_8bit_f, \
                scale_num_rows, scale_stride);                               \
      } else {                                                               \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false>         \
            <<<grid, block, smem_bytes, stream>>>(                           \
                input_ptr, output_q.data_ptr(), scale_ptr, group_size_i,     \
                num_groups, groups_per_block, eps_f, min_8bit_f, max_8bit_f, \
                scale_num_rows, scale_stride);                               \
      }                                                                      \
    } else {                                                                 \
      if (scale_ue8m0) {                                                     \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true>         \
            <<<grid, block, smem_bytes, stream>>>(                           \
                input_ptr, output_q.data_ptr(), scale_ptr, group_size_i,     \
                num_groups, groups_per_block, eps_f, min_8bit_f,             \
                max_8bit_f);                                                 \
      } else {                                                               \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false>        \
            <<<grid, block, smem_bytes, stream>>>(                           \
                input_ptr, output_q.data_ptr(), scale_ptr, group_size_i,     \
                num_groups, groups_per_block, eps_f, min_8bit_f,             \
                max_8bit_f);                                                 \
      }                                                                      \
    }                                                                        \
  } while (0)

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "per_token_group_quant_8bit", ([&] {
        if (dst_type == at::ScalarType::Float8_e4m3fn) {
          LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
        } else if (dst_type == at::ScalarType::Char) {
          LAUNCH_KERNEL(scalar_t, int8_t);
        }
      }));

#undef LAUNCH_KERNEL
}
|
|
|
|
// FP8-named entry point kept for callers that bind to it.
// It forwards every argument unchanged to per_token_group_quant_8bit, which
// picks the FP8 kernel from output_q's dtype. fp8_min and fp8_max are the
// clamp bounds of the quantized values.
void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,
                               double fp8_max, bool scale_ue8m0) {
  return per_token_group_quant_8bit(input, output_q, output_s, group_size,
                                    eps, /*min_8bit=*/fp8_min,
                                    /*max_8bit=*/fp8_max, scale_ue8m0);
}