#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <cmath>

#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#include "cub_helpers.h"

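// Scalar conversion helpers used by the quantization kernels below.
// float_to_int8_rn and float_to_int32_rn round to nearest-even and saturate to
// the destination range, matching the PTX cvt.rni.sat instructions used on the
// CUDA path; int32_to_int8 only saturates. For example, float_to_int8_rn(127.8f)
// returns 127 (saturated) and float_to_int8_rn(-3.5f) returns -4 (ties to even).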
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // Saturate manually instead of using std::clamp:
  // See https://github.com/pytorch/pytorch/issues/127666
  // See https://github.com/llvm/llvm-project/issues/95183
  // hip-clang's std::clamp pulls in the host-only __glibcxx_assert_fail when
  // building on Arch/gcc14, so the line below replaces
  // dst = std::clamp(dst, i8_min, i8_max); with equivalent logic.
  dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  // int32_max is not exactly representable as float.
  // Therefore, we need to be careful and manually return int32_max on
  // overflow. For symmetry, we also do the same for int32_min, even though it
  // is exactly representable as float and the conversion should be exact.
  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
  static constexpr auto i32_min_f = static_cast<float>(i32_min);
  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
  static constexpr auto i32_max_f = static_cast<float>(i32_max);

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // saturate on the higher end.
  if (dst >= i32_max_f) {
    return i32_max;
  }
  // saturate on the lower end.
  if (dst <= i32_min_f) {
    return i32_min;
  }

  return static_cast<int32_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int32_t&>(dst);
#endif
}

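// float_to_int32_rn and int32_to_int8 are used together on the zero-point
// (azp) paths below: the scaled value is first rounded to int32 so the zero
// point can be added without overflow, and only then saturated down to int8.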
static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<int32_t>(std::numeric_limits<int8_t>::max());

  // Saturate manually instead of using std::clamp:
  // See https://github.com/pytorch/pytorch/issues/127666
  // See https://github.com/llvm/llvm-project/issues/95183
  // hip-clang's std::clamp pulls in the host-only __glibcxx_assert_fail when
  // building on Arch/gcc14, so the line below replaces
  // int32_t dst = std::clamp(x, i8_min, i8_max); with equivalent logic.
  int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

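// Symmetric quantization with a precomputed (static) scale: one block per
// token (row); each thread strides across hidden_size and writes
// q = round_to_nearest(x / scale) saturated to [-128, 127].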
template <typename scalar_t, typename scale_t>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    const scale_t* scale_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;
  const float scale = *scale_ptr;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        dst = float_to_int8_rn(static_cast<float>(src) / scale);
      });
}

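// Asymmetric variant with a precomputed scale and zero point (azp): each
// element is quantized as q = saturate(round(x / scale) + azp).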
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void static_scaled_int8_azp_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;
  const float scale = *scale_ptr;
  const azp_t azp = *azp_ptr;
  const float inv_s = 1.0f / scale;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        const auto v = static_cast<float>(src) * inv_s;
        dst = int32_to_int8(float_to_int32_rn(v) + azp);
      });
}

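// Dynamic symmetric quantization: each block reduces the absolute maximum of
// its token's row, writes the per-token scale absmax / 127 to scale_out, and
// then quantizes the row with the reciprocal scale.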
template <typename scalar_t, typename scale_t>
__global__ void dynamic_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    scale_t* scale_out, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  // Calculate the absolute maximum of this token's row.
  float thread_max = 0.f;
  vectorize_read_with_alignment<16>(
      row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
        const float v = fabsf(static_cast<float>(src));
        thread_max = fmaxf(thread_max, v);
      });
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float block_max = BlockReduce(tmp).Reduce(thread_max, CubMaxOp{}, blockDim.x);
  __shared__ float absmax;
  if (tid == 0) {
    absmax = block_max;
    scale_out[blockIdx.x] = absmax / 127.f;
  }
  __syncthreads();

  float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
      });
}

// MinMax structure to hold min and max values in one go
struct MinMax {
  float min, max;

  __host__ __device__ MinMax()
      : min(std::numeric_limits<float>::max()),
        max(std::numeric_limits<float>::lowest()) {}

  __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}

  __host__ __device__ MinMax& operator+=(float v) {
    min = fminf(min, v);
    max = fmaxf(max, v);
    return *this;
  }

  // merge two MinMax objects
  __host__ __device__ MinMax& operator&=(const MinMax& other) {
    min = fminf(min, other.min);
    max = fmaxf(max, other.max);
    return *this;
  }
};

__host__ __device__ inline MinMax operator+(MinMax a, float v) {
  return a += v;
}
__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
  return a &= b;
}

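// Dynamic asymmetric quantization: each block reduces the (min, max) of its
// token's row, derives scale = (max - min) / 255 and zero point
// azp = round(-128 - min / scale), then quantizes the row as
// q = saturate(round(x / scale) + azp). Illustrative example: a row spanning
// [-1.0, 3.0] gives scale = 4/255 and azp = -64, so -1.0 maps to -128 and
// 3.0 maps to 127.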
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  MinMax thread_mm;
  vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
                                    [&] __device__(const scalar_t& src) {
                                      thread_mm += static_cast<float>(src);
                                    });

  using BlockReduce = cub::BlockReduce<MinMax, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;

  MinMax mm = BlockReduce(tmp).Reduce(
      thread_mm,
      [] __device__(MinMax a, const MinMax& b) {
        a &= b;
        return a;
      },
      blockDim.x);

  __shared__ float scale_sh;
  __shared__ azp_t azp_sh;
  if (tid == 0) {
    float s = (mm.max - mm.min) / 255.f;
    float zp = nearbyintf(-128.f - mm.min / s);  // round-to-even
    scale_sh = s;
    azp_sh = azp_t(zp);
    scale_out[blockIdx.x] = s;
    azp_out[blockIdx.x] = azp_sh;
  }
  __syncthreads();

  const float inv_s = 1.f / scale_sh;
  const azp_t azp = azp_sh;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        const auto v = static_cast<float>(src) * inv_s;
        dst = int32_to_int8(float_to_int32_rn(v) + azp);
      });
}

}  // namespace vllm

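// Host-side entry points. Both launch one block per token with
// min(hidden_size, 256) threads on the current CUDA stream; passing an azp
// tensor selects the asymmetric (zero-point) kernel.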
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale,
                              std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp || azp->numel() == 1);

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 256));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), hidden_size);
        } else {
          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}

void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(!azp || azp->is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 256));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), hidden_size);
        } else {
          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}
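
// Minimal host-side usage sketch (illustrative only; the tensor names, shapes,
// and the direct call are assumptions, and in vLLM these entry points are
// normally reached through the project's registered custom ops rather than
// called directly):
//
//   auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
//   auto input = torch::randn({num_tokens, hidden_size}, opts);
//   auto out = torch::empty_like(input, input.options().dtype(torch::kInt8));
//   auto scales = torch::empty(
//       {num_tokens, 1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   dynamic_scaled_int8_quant(out, input, scales, std::nullopt);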