mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:35:01 +08:00
Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: Seiji Eicher <seiji@anyscale.com> Signed-off-by: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com> Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Signed-off-by: Kosseila (CloudThrill) <klouddude@gmail.com> Signed-off-by: frankwang28 <frank.wbb@hotmail.com> Signed-off-by: Frank Wang <41319051+frankwang28@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: zixi-qi <qizixi@meta.com> Signed-off-by: Bram Wasti <bwasti@meta.com> Signed-off-by: Naman Lalit <nl2688@nyu.edu> Signed-off-by: Chenheli Hua <huachenheli@outlook.com> Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com> Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Patrick Toulme <ptoulme@meta.com> Signed-off-by: Patrick Toulme <pctoulme+1@gmail.com> Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Signed-off-by: Clayton Coleman <smarterclayton@gmail.com> Signed-off-by: Jialin Ouyang <jialino@meta.com> Signed-off-by: Jialin Ouyang 
<Jialin.Ouyang@gmail.com> Signed-off-by: Weiliang Liu <weiliangl@nvidia.com> Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Juechen Liu <jueliu@meta.com> Signed-off-by: simon-mo <simon.mo@hey.com> Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: yingjun-mou <renzomou@gmail.com> Signed-off-by: zhoukz <me@zhoukz.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Lee Nau <lnau@nvidia.com> Signed-off-by: adabeyta <aabeyta@redhat.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Lucia Fang <fanglu@meta.com> Signed-off-by: a120092009 <zhaoty0121@gmail.com> Signed-off-by: sergiopaniego <sergiopaniegoblanco@gmail.com> Signed-off-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> Signed-off-by: wangyafeng <wangyafeng@baidu.com> Signed-off-by: Lehua Ding <lehuading@tencent.com> Signed-off-by: lyd1992 <liuyudong@iscas.ac.cn> Signed-off-by: ihb2032 <1355790728@qq.com> Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com> Signed-off-by: anion <1005128408@qq.com> Signed-off-by: Anion <123177548+Anionex@users.noreply.github.com> Signed-off-by: Pavani Majety <pmajety@nvidia.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com> Signed-off-by: David Ben-David <davidb@pliops.com> Signed-off-by: Andrew Xia <axia@meta.com> 
Signed-off-by: Andrew Xia <axia@fb.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Salvatore Cena <cena@cenas.it> Signed-off-by: padg9912 <phone.and.desktop@gmail.com> Signed-off-by: nadathurv <work.vnadathur@gmail.com> Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: billishyahao <bill.he@amd.com> Signed-off-by: Nathan Scott <nathans@redhat.com> Signed-off-by: Kenichi Maehashi <maehashi@preferred.jp> Signed-off-by: Johnny <johnnynuca14@gmail.com> Signed-off-by: johnnynunez <johnnynuca14@gmail.com> Signed-off-by: Johnny <johnnync13@gmail.com> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: Hosang Yoon <hosang.yoon@amd.com> Signed-off-by: Jerry Zhang <jerryzh168@gmail.com> Signed-off-by: Peter Schuurman <psch@google.com> Signed-off-by: Huy Do <huydhn@gmail.com> Signed-off-by: leo-pony <nengjunma@outlook.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: zhewenli <zhewenli@meta.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: huijjj <huijong.jeong@squeezebits.com> Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com> Signed-off-by: kyt <eluban4532@gmail.com> Signed-off-by: Egor <e.a.krivov@gmail.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Paul Pak <paulpak58@gmail.com> Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: Xiang Si <sixiang@google.com> Signed-off-by: Aleksandr Samarin 
<astrlrd@nebius.com> Signed-off-by: Jun Jiang <jasl9187@hotmail.com> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: Chendi.Xue <chendi.xue@intel.com> Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Jonas M. 
Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: Seiji Eicher 
<58963096+eicherseiji@users.noreply.github.com> Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: 阿丹(adan) <47373076+LDLINGLINGLING@users.noreply.github.com> Co-authored-by: liudan <adan@minicpm.com> Co-authored-by: liudan <liudan@qq.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Clouddude <kouss.hd@gmail.com> Co-authored-by: Frank Wang <41319051+frankwang28@users.noreply.github.com> Co-authored-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: qizixi <22851944+zixi-qi@users.noreply.github.com> Co-authored-by: Bram Wasti <bwasti@fb.com> Co-authored-by: Naman Lalit <nl2688@nyu.edu> Co-authored-by: Chenheli Hua <huachenheli@outlook.com> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com> Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Xiaohan Zou <renovamenzxh@gmail.com> Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Patrick C. 
Toulme <135739773+patrick-toulme@users.noreply.github.com> Co-authored-by: Clayton Coleman <smarterclayton@gmail.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: Jialin Ouyang <jialino@meta.com> Co-authored-by: weiliang <weiliangl@nvidia.com> Co-authored-by: Yuxuan Zhang <2448370773@qq.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Juechen Liu <grinchcoder@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Yingjun Mou <renzomou@gmail.com> Co-authored-by: Zhou Jiahao <me@zhoukz.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Lee Nau <lee.nau@gmail.com> Co-authored-by: Adrian Abeyta <aabeyta@redhat.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: acisseJZhong <40467976+acisseJZhong@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucia Fang <fanglu@meta.com> Co-authored-by: Siyuan Fu <siyuanf@nvidia.com> Co-authored-by: Xiaozhu Meng <mxz297@gmail.com> Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Co-authored-by: a120092009 <33205509+a120092009@users.noreply.github.com> Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> Co-authored-by: CSWYF3634076 <wangyafeng@baidu.com> Co-authored-by: Lehua Ding <lehuading@tencent.com> Co-authored-by: Reza Barazesh <3146276+rzabarazesh@users.noreply.github.com> Co-authored-by: ihb2032 <40718643+ihb2032@users.noreply.github.com> Co-authored-by: Asaf 
Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Co-authored-by: Anion <123177548+Anionex@users.noreply.github.com> Co-authored-by: Pavani Majety <pmajety@nvidia.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: cjackal <44624812+cjackal@users.noreply.github.com> Co-authored-by: David Ben-David <sdavidbd@gmail.com> Co-authored-by: David Ben-David <davidb@pliops.com> Co-authored-by: Andrew Xia <axia@mit.edu> Co-authored-by: Andrew Xia <axia@fb.com> Co-authored-by: Salvatore Cena <cena@cenas.it> Co-authored-by: Param <psch@cs.unc.edu> Co-authored-by: Zhewen Li <zhewenli@meta.com> Co-authored-by: nadathurv <work.vnadathur@gmail.com> Co-authored-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: billishyahao <bill.he@amd.com> Co-authored-by: Nathan Scott <natoscott@users.noreply.github.com> Co-authored-by: Kenichi Maehashi <939877+kmaehashi@users.noreply.github.com> Co-authored-by: Johnny <johnnync13@gmail.com> Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: Hosang <156028780+hyoon1@users.noreply.github.com> Co-authored-by: Jerry Zhang <jerryzh168@gmail.com> Co-authored-by: pwschuurman <psch@google.com> Co-authored-by: Huy Do <huydhn@gmail.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: Varun Sundar Rabindranath 
<varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Liu-congo <1502632128@qq.com> Co-authored-by: HUIJONG JEONG <64083281+huijjj@users.noreply.github.com> Co-authored-by: Yannick Schnider <Yannick.Schnider1@ibm.com> Co-authored-by: kyt <eluban4532@gmail.com> Co-authored-by: Egor <e.a.krivov@gmail.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Paul Pak <52512091+paulpak58@users.noreply.github.com> Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com> Co-authored-by: Xiang Si <sixiang@google.com> Co-authored-by: Aleksandr Samarin <samarin_ad@mail.ru> Co-authored-by: Jun Jiang <jasl9187@hotmail.com> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Nikhil G <nrghosh@users.noreply.github.com>
672 lines
20 KiB
Plaintext
672 lines
20 KiB
Plaintext
#pragma once
|
|
#include <hip/hip_fp8.h>
|
|
|
|
#include <hip/hip_fp16.h>
|
|
#include <hip/hip_bf16.h>
|
|
#include <hip/hip_bfloat16.h>
|
|
|
|
#include "../../../../attention/attention_dtypes.h"
|
|
|
|
namespace vllm {
|
|
#ifdef USE_ROCM
|
|
|
|
namespace fp8 {
|
|
#ifdef ENABLE_FP8
|
|
|
|
// Use hardware cvt instruction for fp8 on rocm
|
|
// Primary template: fallback for fp8 types without a specialization below.
// Deliberately returns a zero-initialized value — it exists only so the
// symbol can be instantiated for every dispatched fp8 type; the real
// conversions are the explicit specializations that follow.
template <typename fp8_type>
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
  return {};
}
|
|
|
|
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
|
|
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
|
|
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
|
|
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
|
|
// the new HW cvt with something reasonable that doesn't rely on the
|
|
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
|
|
// float -> OCP fp8 (e4m3fn).
// HIP_FP8_TYPE_OCP (and __hip_fp8_e4m3) first appear in ROCm 6.3, so the
// hardware cvt is guarded; older ROCm falls back to PyTorch's software cast.
template <>
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
#if HIP_FP8_TYPE_OCP
  // HW saturating cvt, then wrap the resulting byte without re-converting
  // (from_bits() tag constructor).
  return c10::Float8_e4m3fn(
      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
                             __hip_fp8_e4m3::__default_interpret),
      c10::Float8_e4m3fn::from_bits());
#else
  // Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
  // HW cvt above is faster when it is available (ROCm 6.3 or newer).
  return static_cast<c10::Float8_e4m3fn>(r);
#endif
}
|
|
|
|
// float -> FNUZ fp8 (e4m3fnuz): HW saturating cvt, wrapped via the
// from_bits() tag so the byte is not converted a second time.
template <>
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
  return c10::Float8_e4m3fnuz(
      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
                             __hip_fp8_e4m3_fnuz::__default_interpret),
      c10::Float8_e4m3fnuz::from_bits());
}
|
|
|
|
// Primary template: identity/implicit-conversion passthrough used when Tout
// and Tin coincide (or are implicitly convertible). All real fp8 packing and
// unpacking lives in the explicit specializations below.
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
  return x;
}
|
|
|
|
// Primary template of the scaled variant: mirrors vec_conversion above; the
// scale is intentionally ignored in the unspecialized (identity) case.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
                                                 const float scale) {
  return x;
}
|
|
|
|
#if HIP_FP8_TYPE_OCP
|
|
using fp8_type = __hip_fp8_e4m3;
|
|
using fp8x2_type = __hip_fp8x2_e4m3;
|
|
#else
|
|
using fp8_type = __hip_fp8_e4m3_fnuz;
|
|
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
|
|
#endif
|
|
|
|
// fp8 -> half
|
|
// One fp8 byte -> one half, returned as its raw uint16_t bit pattern.
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
  return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
}
|
|
|
|
// fp8x2 -> half2
|
|
template <>
|
|
__inline__ __device__ uint32_t
|
|
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
|
union {
|
|
__half2_raw h2r;
|
|
uint32_t ui32;
|
|
} tmp;
|
|
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
|
return tmp.ui32;
|
|
}
|
|
|
|
// fp8x4 -> half2x2
|
|
template <>
|
|
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
|
union {
|
|
uint2 u32x2;
|
|
uint32_t u32[2];
|
|
} tmp;
|
|
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
|
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
|
return tmp.u32x2;
|
|
}
|
|
|
|
// fp8x8 -> half2x4
|
|
template <>
|
|
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
|
union {
|
|
uint4 u64x2;
|
|
uint2 u64[2];
|
|
} tmp;
|
|
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
|
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
|
return tmp.u64x2;
|
|
}
|
|
|
|
using __nv_bfloat16 = __hip_bfloat16;
|
|
|
|
// fp8 -> __nv_bfloat16
|
|
template <>
|
|
__inline__ __device__ __nv_bfloat16
|
|
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
|
fp8_type f8;
|
|
f8.__x = a;
|
|
return __float2bfloat16(static_cast<float>(f8));
|
|
}
|
|
|
|
using __nv_bfloat162 = __hip_bfloat162;
|
|
|
|
// fp8x2 -> __nv_bfloat162
|
|
template <>
|
|
__inline__ __device__ __nv_bfloat162
|
|
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
|
__nv_bfloat162 res;
|
|
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
|
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
|
return res;
|
|
}
|
|
|
|
// fp8x4 -> bf16_4_t
|
|
template <>
|
|
__inline__ __device__ bf16_4_t
|
|
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
|
bf16_4_t res;
|
|
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
|
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
|
return res;
|
|
}
|
|
|
|
// fp8x8 -> bf16_8_t
|
|
template <>
|
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
|
bf16_4_t tmp1, tmp2;
|
|
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
|
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
|
bf16_8_t res;
|
|
res.x = tmp1.x;
|
|
res.y = tmp1.y;
|
|
res.z = tmp2.x;
|
|
res.w = tmp2.y;
|
|
return res;
|
|
}
|
|
|
|
// fp8 -> float
|
|
template <>
|
|
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
|
fp8_type f8;
|
|
f8.__x = a;
|
|
return static_cast<float>(f8);
|
|
}
|
|
|
|
// fp8x2 -> float2
|
|
template <>
|
|
__inline__ __device__ float2
|
|
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
|
fp8x2_type f8x2;
|
|
f8x2.__x = a;
|
|
return static_cast<float2>(f8x2);
|
|
}
|
|
|
|
// fp8x4 -> float4
|
|
template <>
|
|
__inline__ __device__ Float4_
|
|
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
|
Float4_ res;
|
|
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
|
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
|
return res;
|
|
}
|
|
|
|
// fp8x4 -> float4
|
|
template <>
|
|
__inline__ __device__ float4
|
|
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
|
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
|
return res;
|
|
}
|
|
|
|
// fp8x8 -> float8
|
|
template <>
|
|
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
|
Float4_ tmp1, tmp2;
|
|
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
|
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
|
Float8_ res;
|
|
res.x = tmp1.x;
|
|
res.y = tmp1.y;
|
|
res.z = tmp2.x;
|
|
res.w = tmp2.y;
|
|
return res;
|
|
}
|
|
|
|
// half -> fp8
|
|
// half (raw uint16_t bits) -> fp8 via the saturating HW cvt.
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
  // Reinterpret the incoming bits as a raw half before converting.
  __half_raw tmp;
  tmp.x = a;
  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
                                  fp8_type::__default_interpret);
}
|
|
|
|
template <>
|
|
__inline__ __device__ uint16_t
|
|
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
|
|
union {
|
|
uint32_t ui32;
|
|
__half2_raw h2r;
|
|
} tmp;
|
|
tmp.ui32 = a;
|
|
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
|
fp8_type::__default_interpret);
|
|
}
|
|
|
|
// bf16 -> fp8
|
|
// bf16 -> fp8: widen to fp32 first (no direct bf16->fp8 cvt), then the
// saturating HW cvt.
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
  return __hip_cvt_float_to_fp8(__bfloat162float(a),
                                fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
}
|
|
|
|
// float -> fp8
|
|
// float -> fp8 via the saturating HW cvt.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
  return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
}
|
|
|
|
// float2 -> half2
|
|
template <>
|
|
__inline__ __device__ uint32_t
|
|
vec_conversion<uint32_t, float2>(const float2& a) {
|
|
union {
|
|
half2 float16;
|
|
uint32_t uint32;
|
|
};
|
|
|
|
float16 = __float22half2_rn(a);
|
|
return uint32;
|
|
}
|
|
|
|
// Float4 -> half2x2
|
|
template <>
|
|
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
|
|
uint2 b;
|
|
float2 val;
|
|
val.x = a.x.x;
|
|
val.y = a.x.y;
|
|
b.x = vec_conversion<uint32_t, float2>(val);
|
|
|
|
val.x = a.y.x;
|
|
val.y = a.y.y;
|
|
b.y = vec_conversion<uint32_t, float2>(val);
|
|
return b;
|
|
}
|
|
|
|
// Float4 -> float4
|
|
template <>
|
|
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
|
|
float4 b;
|
|
b.x = a.x.x;
|
|
b.y = a.x.y;
|
|
b.z = a.y.x;
|
|
b.w = a.y.y;
|
|
return b;
|
|
}
|
|
|
|
// Float8 -> half2x4
|
|
template <>
|
|
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
|
|
uint4 b;
|
|
b.x = vec_conversion<uint32_t, float2>(a.x);
|
|
b.y = vec_conversion<uint32_t, float2>(a.y);
|
|
b.z = vec_conversion<uint32_t, float2>(a.z);
|
|
b.w = vec_conversion<uint32_t, float2>(a.w);
|
|
return b;
|
|
}
|
|
|
|
// float2 -> bfloat162
|
|
template <>
|
|
__inline__ __device__ __nv_bfloat162
|
|
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
|
|
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
|
return b;
|
|
}
|
|
|
|
// Float4 -> bfloat162x2
|
|
template <>
|
|
__inline__ __device__ bf16_4_t
|
|
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
|
|
bf16_4_t b;
|
|
b.x = __float22bfloat162_rn(a.x);
|
|
b.y = __float22bfloat162_rn(a.y);
|
|
return b;
|
|
}
|
|
|
|
// Float8 -> bfloat162x4
|
|
template <>
|
|
__inline__ __device__ bf16_8_t
|
|
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
|
bf16_8_t b;
|
|
b.x = __float22bfloat162_rn(a.x);
|
|
b.y = __float22bfloat162_rn(a.y);
|
|
b.z = __float22bfloat162_rn(a.z);
|
|
b.w = __float22bfloat162_rn(a.w);
|
|
return b;
|
|
}
|
|
|
|
/* Scaled and vectorized conversions, for data exchange between the
   high-precision and low-precision (fp8) domains.

   Scale convention: FP8_data = Quantize(High_Precision_data / scale),
   i.e. quantization divides by the scale and dequantization multiplies
   it back: Dequant(FP8) * scale => High_Precision_data.
*/
|
|
|
|
using __nv_bfloat16 = __hip_bfloat16;
|
|
|
|
// fp8 -> __nv_bfloat16
|
|
template <>
|
|
__inline__ __device__ __nv_bfloat16
|
|
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
|
|
fp8_type f8;
|
|
f8.__x = a;
|
|
return __float2bfloat16(static_cast<float>(f8) * scale);
|
|
}
|
|
|
|
// fp8x2 -> __nv_bfloat162
|
|
template <>
|
|
__inline__ __device__ __nv_bfloat162
|
|
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
|
float scale) {
|
|
__nv_bfloat162 res;
|
|
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
|
res.y =
|
|
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
|
return res;
|
|
}
|
|
|
|
// fp8x4 -> bf16_4_t
|
|
template <>
|
|
__inline__ __device__ bf16_4_t
|
|
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
|
|
bf16_4_t res;
|
|
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
|
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
|
scale);
|
|
return res;
|
|
}
|
|
|
|
// fp8x8 -> bf16_8_t
|
|
template <>
|
|
__inline__ __device__ bf16_8_t
|
|
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
|
|
bf16_4_t tmp1, tmp2;
|
|
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
|
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
|
bf16_8_t res;
|
|
res.x = tmp1.x;
|
|
res.y = tmp1.y;
|
|
res.z = tmp2.x;
|
|
res.w = tmp2.y;
|
|
return res;
|
|
}
|
|
|
|
// fp8 -> float
|
|
template <>
|
|
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
|
const uint8_t& a, float scale) {
|
|
fp8_type f8;
|
|
f8.__x = a;
|
|
return static_cast<float>(f8) * scale;
|
|
}
|
|
|
|
// fp8x2 -> float2
|
|
// fp8x2 -> float2, dequantized.
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
  // Reinterpret the packed bits as an fp8x2 vector, widen to float2, then
  // scale. NOTE(review): float2 * float relies on an operator overload
  // provided elsewhere in vLLM's attention dtype headers — not shown here.
  fp8x2_type f8x2;
  f8x2.__x = a;
  return static_cast<float2>(f8x2) * scale;
}
|
|
|
|
// fp8x4 -> float4
|
|
template <>
|
|
__inline__ __device__ Float4_
|
|
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
|
Float4_ res;
|
|
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
|
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
|
return res;
|
|
}
|
|
|
|
// fp8x4 -> float4
|
|
template <>
|
|
__inline__ __device__ float4
|
|
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
|
|
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
|
return {res.x.x, res.x.y, res.y.x, res.y.y};
|
|
}
|
|
|
|
// fp8x8 -> float8
|
|
template <>
|
|
__inline__ __device__ Float8_
|
|
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
|
|
Float4_ tmp1, tmp2;
|
|
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
|
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
|
Float8_ res;
|
|
res.x = tmp1.x;
|
|
res.y = tmp1.y;
|
|
res.z = tmp2.x;
|
|
res.w = tmp2.y;
|
|
return res;
|
|
}
|
|
|
|
// fp8 -> half
|
|
// fp8 -> half (raw bits), dequantized.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
  __half_raw res;
  // Dequantize in fp32, narrow into the raw half via its arithmetic `data`
  // view, then return the aliased bit view `x` (both name the same 16 bits).
  res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
  return res.x;
}
|
|
|
|
// fp8x2 -> half2
|
|
// fp8x2 -> half2 (raw bits), dequantized in half precision.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
  // Scale each lane in place through the arithmetic `data` view of the raw
  // halves, then return the punned 32-bit pattern.
  tmp.h2r.x.data *= scale;
  tmp.h2r.y.data *= scale;
  return tmp.ui32;
}
|
|
|
|
// fp8x4 -> half2x2
|
|
template <>
|
|
__inline__ __device__ uint2
|
|
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
|
|
union {
|
|
uint2 u32x2;
|
|
uint32_t u32[2];
|
|
} tmp;
|
|
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
|
tmp.u32[1] =
|
|
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
|
return tmp.u32x2;
|
|
}
|
|
|
|
// fp8x8 -> half2x4
|
|
template <>
|
|
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
|
|
float scale) {
|
|
union {
|
|
uint4 u64x2;
|
|
uint2 u64[2];
|
|
} tmp;
|
|
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
|
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
|
return tmp.u64x2;
|
|
}
|
|
|
|
// half -> fp8
|
|
// half (raw bits) -> fp8, dividing by the scale before quantization.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
  __half_raw tmp;
  tmp.x = a;  // load via the raw bit view
  // Divide through the arithmetic `data` view (same 16 bits as `x`).
  tmp.data /= scale;
  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
                                  fp8_type::__default_interpret);
}
|
|
|
|
// halfx2 -> fp8x2
|
|
// half2 (raw uint32_t bits) -> two packed fp8 bytes, dividing each lane by
// the scale before quantization.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
  union {
    uint32_t ui32;
    __half2_raw h2r;
  } tmp;
  tmp.ui32 = a;
  // Divide each lane through the arithmetic `data` view of the raw halves.
  tmp.h2r.x.data /= scale;
  tmp.h2r.y.data /= scale;
  return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
                                     fp8_type::__default_interpret);
}
|
|
|
|
// half2x2 -> fp8x4
|
|
template <>
|
|
__inline__ __device__ uint32_t
|
|
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
|
|
union {
|
|
uint16_t ui16[2];
|
|
uint32_t ui32;
|
|
} tmp;
|
|
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
|
|
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
|
|
return tmp.ui32;
|
|
}
|
|
|
|
// half2x4 -> fp8x8
|
|
template <>
|
|
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
|
|
float scale) {
|
|
union {
|
|
uint2 ui2[2];
|
|
uint4 ui4;
|
|
} tmp;
|
|
tmp.ui4 = a;
|
|
uint2 res;
|
|
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
|
|
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
|
|
return res;
|
|
}
|
|
|
|
// bf16 -> fp8
|
|
template <>
|
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
|
const __nv_bfloat16& a, float scale) {
|
|
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
|
fp8_type::__default_saturation,
|
|
fp8_type::__default_interpret);
|
|
}
|
|
|
|
// bf16x2 -> fp8x2
|
|
template <>
|
|
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
|
|
const __nv_bfloat162& a, float scale) {
|
|
union {
|
|
uint8_t ui8[2];
|
|
uint16_t ui16;
|
|
} tmp;
|
|
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
|
|
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
|
|
return tmp.ui16;
|
|
}
|
|
|
|
// bf16x4 -> fp8x4
|
|
template <>
|
|
__inline__ __device__ uint32_t
|
|
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
|
|
union {
|
|
uint16_t ui16[2];
|
|
uint32_t ui32;
|
|
} tmp;
|
|
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
|
|
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
|
|
return tmp.ui32;
|
|
}
|
|
|
|
// bf16x8 -> fp8x8
|
|
template <>
|
|
__inline__ __device__ uint2
|
|
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
|
|
uint2 res;
|
|
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
|
|
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
|
|
return res;
|
|
}
|
|
|
|
// float -> fp8
|
|
template <>
|
|
__inline__ __device__ uint8_t
|
|
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
|
|
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
|
|
fp8_type::__default_interpret);
|
|
}
|
|
|
|
// floatx2 -> fp8x2
|
|
template <>
|
|
__inline__ __device__ uint16_t
|
|
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
|
|
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
|
|
fp8_type::__default_interpret);
|
|
}
|
|
|
|
// floatx4 -> fp8x4
|
|
template <>
|
|
__inline__ __device__ uint32_t
|
|
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
|
|
union {
|
|
uint16_t ui16[2];
|
|
uint32_t ui32;
|
|
} tmp;
|
|
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
|
|
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
|
|
return tmp.ui32;
|
|
}
|
|
#endif // ENABLE_FP8
|
|
|
|
// Unscaled dtype conversion used by the KV-cache kernels. Dispatches to
// vec_conversion<Tout, Tin> when the cache holds fp8 E4M3 data; any other
// kv_dt value is unsupported and trips the assert (a no-op in release
// builds, after which the value-initialized {} is returned).
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}
|
|
|
|
// Scaled dtype conversion used by the KV-cache kernels. Dispatches to
// scaled_vec_conversion<Tout, Tin> (which applies `scale`) when the cache
// holds fp8 E4M3 data; any other kv_dt value is unsupported and trips the
// assert (a no-op in release builds, after which the value-initialized {}
// is returned).
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}
|
|
|
|
// Dispatches FN based on the requested KV-cache dtype (KV_DTYPE, a string)
// and the model's tensor dtype (SRC_DTYPE, an at::ScalarType). FN is a macro
// invoked as FN(scalar_t, cache_t, kv_dt), matching a function
// template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
// "auto" keeps the cache in the source dtype; "fp8" / "fp8_e4m3" stores the
// cache as uint8_t-packed fp8 E4M3 values. Any other combination raises via
// TORCH_CHECK.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                   \
  if (KV_DTYPE == "auto") {                                                   \
    if (SRC_DTYPE == at::ScalarType::Float) {                                 \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                      \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                           \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                       \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);      \
    } else {                                                                  \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);  \
    }                                                                         \
  } else {                                                                    \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                        \
      if (SRC_DTYPE == at::ScalarType::Float) {                               \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);               \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                         \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);            \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                     \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);       \
      } else {                                                                \
        TORCH_CHECK(false,                                                    \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);       \
      }                                                                       \
    } else {                                                                  \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);    \
    }                                                                         \
  }
|
|
|
|
} // namespace fp8
|
|
#endif // USE_ROCM
|
|
} // namespace vllm
|