/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../cuda_compat.h"

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  #include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
  #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return VLLM_SHFL_SYNC(sum, 0);
}
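// Usage note (added for clarity): block_sum expects a shared-memory scratch
// area of at least NUM_WARPS floats and must be called by every thread in the
// block. The kernels below carve this scratch space out of red_smem, which is
// sized 2 * NUM_WARPS so that the max reduction and the sum reduction can use
// disjoint halves, e.g.:
//
//   __shared__ float red_smem[2 * NUM_WARPS];
//   ...
//   exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);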
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions,
                                     // head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int seq_len = seq_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
  const int num_blocks_per_partition =
      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx =
      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx =
      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx =
      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS =
      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
                                        // divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;
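  // Worked example (added for clarity; values are illustrative): with
  // WARP_SIZE = 32, BLOCK_SIZE = 16, and NUM_THREADS = 128, the constants
  // above become THREAD_GROUP_SIZE = 2, NUM_THREAD_GROUPS = 64,
  // NUM_TOKENS_PER_THREAD_GROUP = 1, and NUM_WARPS = 4: each warp holds 16
  // thread groups, so one warp covers one 16-token KV block per iteration.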
  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope =
      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread
  // group fetch or compute 16 bytes at a time. For example, if the size of a
  // thread group is 4 and the data type is half, then the vector size is 16 /
  // (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in
  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
  // q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a
                    // memory wall right before we use q_vecs
  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(cache_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;

  // blocksparse specific vars
  int bs_block_offset;
  int q_bs_block_id;
  if constexpr (IS_BLOCK_SPARSE) {
    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
    // blocksparse_block_size);
    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
    if (blocksparse_head_sliding_step >= 0)
      // sliding on q heads
      bs_block_offset =
          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
    else
      // sliding on kv heads
      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
                            (-blocksparse_head_sliding_step) +
                        1;
  }
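  // Layout note (added for clarity): the key cache is laid out as
  // [num_blocks, num_kv_heads, head_size/x, block_size, x], where
  // x = 16 / sizeof(cache_t) (e.g. x = 8 for half). Inside the loop below,
  // a vector starting at element vec_idx * VEC_SIZE of the key is addressed
  // by splitting that element index into offset1 = (vec_idx * VEC_SIZE) / x
  // (which x-wide chunk of the head dimension) and
  // offset2 = (vec_idx * VEC_SIZE) % x (position inside that chunk), giving
  // k_ptr + offset1 * BLOCK_SIZE * x + offset2.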
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    // For blocksparse attention: skip computation on blocks that are not
    // attended
    if constexpr (IS_BLOCK_SPARSE) {
      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      const bool is_remote =
          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
      const bool is_local =
          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
      if (!is_remote && !is_local) {
        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
          const int physical_block_offset =
              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;

          if (thread_group_offset == 0) {
            // NOTE(linxihui): assign very large number to skipped tokens to
            // avoid contribution to the sumexp softmax normalizer. This will
            // not be used at computing sum(softmax*v) as the blocks will be
            // skipped.
            logits[token_idx - start_token_idx] = -FLT_MAX;
          }
        }
        continue;
      }
    }
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in
    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
    // has 1, 5, 9, ... th vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr =
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        } else {
          // Vector conversion from Quant_vec to K_vec.
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
              k_vec_quant, *k_scale);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        const bool mask = token_idx >= seq_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }
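  // Numerical-stability note (added for clarity): the code below computes
  // softmax(qk_i) = exp(qk_i - qk_max) / sum_j exp(qk_j - qk_max), i.e. the
  // maximum logit is subtracted before exponentiation so that __expf never
  // overflows; the result is mathematically identical to the plain softmax.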
  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();
  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits +
                            seq_idx * num_heads * max_num_partitions +
                            head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
                          head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD =
      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
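  // Worked example (added for clarity; values are illustrative): with
  // scalar_t = half, BLOCK_SIZE = 16, WARP_SIZE = 32, and HEAD_SIZE = 128,
  // V_VEC_SIZE = 8, NUM_V_VECS_PER_ROW = 2, NUM_ROWS_PER_ITER = 16, and
  // NUM_ROWS_PER_THREAD = 8: each warp covers 16 rows of the head dimension
  // per iteration, and each thread accumulates 8 output rows in accs[].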
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    // For blocksparse attention: skip computation on blocks that are not
    // attended
    if constexpr (IS_BLOCK_SPARSE) {
      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
        continue;
      }
    }
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
                                                           start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                           kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        } else {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
                                                                    *v_scale);
        }
        if (block_idx == num_seq_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the
          // context, we should explicitly zero out the values since they may
          // contain NaNs. See
          // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }
  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for
  // logits is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }
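  // Reduction shape (added for clarity): the loop above halves the number of
  // active warps each round. For example, with NUM_WARPS = 4: in the first
  // round warps 2 and 3 write their partial rows to out_smem while warps 0
  // and 1 add them in; in the second round warp 1 writes and warp 0 adds, so
  // warp 0 ends up holding the fully reduced accumulators.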
  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr =
        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads,
                                           // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads,
                                           // head_size, block_size]
    const int num_kv_heads,                // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
      v_cache, num_kv_heads, scale, block_tables, seq_lens,
      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
      blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                     // max_num_partitions, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = seq_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr =
        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr =
        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);
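  // Recombination note (added for clarity): each partition j produced a max
  // logit m_j and an exp sum s_j = sum_i exp(qk_i - m_j). With the global max
  // M = max_j m_j, the global softmax denominator is
  // S = sum_j s_j * exp(m_j - M), and the final output is
  // sum_j tmp_out_j * (s_j * exp(m_j - M)) / S, which is what the code below
  // computes via shared_exp_sums and inv_global_exp_sum.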
  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums =
      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums +
                              seq_idx * num_heads * max_num_partitions +
                              head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr =
      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}
}  // namespace vllm

#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
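// Illustrative usage (added for clarity; not part of the upstream file): a
// host-side launcher elsewhere in the codebase is expected to drive these
// kernels roughly as sketched below, with grids matching the "Grid:" comments
// above and dynamic shared memory sized for the float logits (v1/v2) or the
// 2 * num_partitions floats used by the reduce kernel. Names such as
// shared_mem_size and stream are placeholders.
//
//   // V1: no partitioning; one thread block per (head, sequence).
//   paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
//                             KV_DTYPE, IS_BLOCK_SPARSE>
//       <<<dim3(num_heads, num_seqs, 1), NUM_THREADS, shared_mem_size,
//          stream>>>(/* ... kernel arguments ... */);
//
//   // V2: compute each PARTITION_SIZE-token partition independently, then
//   // combine the per-partition results with the reduce kernel.
//   paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
//                             KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>
//       <<<dim3(num_heads, num_seqs, max_num_partitions), NUM_THREADS,
//          shared_mem_size, stream>>>(/* ... */);
//   paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>
//       <<<dim3(num_heads, num_seqs), NUM_THREADS, reduce_shared_mem_size,
//          stream>>>(/* ... */);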