mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-12 10:50:08 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
292 lines
11 KiB
Plaintext
292 lines
11 KiB
Plaintext
/*
|
|
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
Copyright 2025 SGLang Team. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
/*
|
|
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
|
* by Alcanderian JieXin Liang
|
|
*/
|
|
#include "core/registration.h"
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <cutlass/cutlass.h>
|
|
#include <cutlass/kernel_hardware_info.h>
|
|
#include <torch/all.h>
|
|
|
|
#include <cute/tensor.hpp>
|
|
#include <iostream>
|
|
|
|
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
|
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
|
|
|
// clang-format off
|
|
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
|
// Placeholder compiled when CUDA_VERSION is undefined or older than 12.4: the
// SM100 CUTLASS MLA kernel cannot be built there, so every call raises.
void sm100_cutlass_mla_decode(torch::Tensor const& out,
                              torch::Tensor const& lse,
                              torch::Tensor const& q_nope,
                              torch::Tensor const& q_pe,
                              torch::Tensor const& kv_c_and_k_pe_cache,
                              torch::Tensor const& seq_lens,
                              torch::Tensor const& page_table,
                              torch::Tensor const& workspace,
                              double sm_scale,
                              int64_t num_kv_splits) {
  constexpr bool kKernelAvailable = false;
  TORCH_CHECK(kKernelAvailable, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
|
|
// Placeholder compiled when CUDA_VERSION is undefined or older than 12.4: the
// workspace-size query for the SM100 CUTLASS MLA kernel is unavailable and
// always raises.
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
  TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
  // Unreachable because TORCH_CHECK(false, ...) throws. It keeps every control
  // path of this non-void function returning a value, so the compiler neither
  // warns about nor treats as UB a fall-off-the-end path it cannot prove dead.
  return 0;
}
|
|
#else
|
|
|
|
// Evaluates a cutlass::Status-producing expression exactly once. If the result
// is not kSuccess, raises a torch error carrying CUTLASS's status string.
// The do { ... } while (0) wrapper makes `CUTLASS_CHECK(x);` a single
// statement, so it is safe inside an unbraced if/else. The temporary has a
// name unlikely to collide with a variable referenced by `status`; the old
// name `error` could shadow or self-initialize from a caller's `error`.
#define CUTLASS_CHECK(status)                                             \
  do {                                                                    \
    cutlass::Status cutlass_check_status_ = (status);                     \
    TORCH_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess,       \
                cutlassGetStatusString(cutlass_check_status_));           \
  } while (0)
|
|
|
|
using namespace cute;
|
|
using namespace cutlass::fmha::kernel;
|
|
|
|
// Compile-time flag used by MlaSm100 to pick its tile scheduler:
// IsPersistent<true> selects Sm100MlaPersistentTileScheduler and
// IsPersistent<false> selects Sm100MlaIndividualTileScheduler.
template <bool kPersistent>
struct IsPersistent {
  static constexpr bool value = kPersistent;
};
|
|
|
|
template <typename T, typename TOut, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
|
struct MlaSm100 {
|
|
using Element = T;
|
|
using ElementAcc = float;
|
|
using ElementOut = TOut;
|
|
|
|
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
|
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
|
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
|
|
|
// H K (D_latent D_rope) B
|
|
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
|
|
|
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
|
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
|
using StrideO = StrideK; // H D B
|
|
using StrideLSE = cute::tuple<_1, int>; // H B
|
|
|
|
using TileScheduler =
|
|
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
|
|
|
|
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
|
TileShape,
|
|
Element,
|
|
ElementAcc,
|
|
ElementOut,
|
|
ElementAcc,
|
|
TileScheduler,
|
|
/*kIsCpAsync=*/!IsPaged128>;
|
|
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
|
};
|
|
|
|
template <typename T>
|
|
typename T::Fmha::Arguments args_from_options(
|
|
at::Tensor const& out,
|
|
at::Tensor const& lse,
|
|
at::Tensor const& q_nope,
|
|
at::Tensor const& q_pe,
|
|
at::Tensor const& kv_c_and_k_pe_cache,
|
|
at::Tensor const& seq_lens,
|
|
at::Tensor const& page_table,
|
|
double sm_scale,
|
|
int64_t num_kv_splits) {
|
|
cutlass::KernelHardwareInfo hw_info;
|
|
hw_info.device_id = q_nope.device().index();
|
|
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
|
|
|
int batches = q_nope.sizes()[0];
|
|
int page_count_per_seq = page_table.sizes()[1];
|
|
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
|
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
|
int max_seq_len = page_size * page_count_per_seq;
|
|
using TileShapeH = typename T::TileShapeH;
|
|
using TileShapeD = typename T::TileShapeD;
|
|
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
|
|
|
auto [H, K, D, B] = problem_shape;
|
|
auto [D_latent, D_rope] = D;
|
|
|
|
float scale = float(sm_scale);
|
|
|
|
using StrideQ = typename T::StrideQ;
|
|
using StrideK = typename T::StrideK;
|
|
using StrideO = typename T::StrideO;
|
|
using StrideLSE = typename T::StrideLSE;
|
|
|
|
StrideQ stride_Q_nope = cute::make_tuple(
|
|
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
|
StrideQ stride_Q_pe = cute::make_tuple(
|
|
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
|
|
|
StrideK stride_C = cute::make_tuple(
|
|
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
|
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
|
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
|
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
|
|
|
using Element = typename T::Element;
|
|
using ElementOut = typename T::ElementOut;
|
|
using ElementAcc = typename T::ElementAcc;
|
|
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
|
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
|
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
|
typename T::Fmha::Arguments arguments{
|
|
problem_shape,
|
|
{scale,
|
|
Q_nope_ptr,
|
|
stride_Q_nope,
|
|
Q_pe_ptr,
|
|
stride_Q_pe,
|
|
C_ptr,
|
|
stride_C,
|
|
C_ptr + D_latent,
|
|
stride_C,
|
|
static_cast<int*>(seq_lens.data_ptr()),
|
|
static_cast<int*>(page_table.data_ptr()),
|
|
stride_PT,
|
|
page_count_total,
|
|
page_size},
|
|
{static_cast<ElementOut*>(out.data_ptr()),
|
|
stride_O,
|
|
static_cast<ElementAcc*>(lse.defined() ? lse.data_ptr() : nullptr),
|
|
stride_LSE},
|
|
hw_info,
|
|
// TODO(trevor-m): Change split_kv back to -1 when
|
|
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
|
// perform worse with larger context length and smaller batch sizes.
|
|
static_cast<int>(num_kv_splits), // split_kv
|
|
nullptr, // is_var_split_kv
|
|
};
|
|
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
|
// split_kv automatically based on batch size and sequence length to balance
|
|
// workload across available SMs. Consider using var_split_kv for manual
|
|
// control if needed.
|
|
T::Fmha::set_split_kv(arguments);
|
|
return arguments;
|
|
}
|
|
|
|
template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
|
|
void runMla(
|
|
at::Tensor const& out,
|
|
at::Tensor const& lse,
|
|
at::Tensor const& q_nope,
|
|
at::Tensor const& q_pe,
|
|
at::Tensor const& kv_c_and_k_pe_cache,
|
|
at::Tensor const& seq_lens,
|
|
at::Tensor const& page_table,
|
|
at::Tensor const& workspace,
|
|
double sm_scale,
|
|
int64_t num_kv_splits,
|
|
cudaStream_t stream) {
|
|
using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
|
|
typename MlaSm100Type::Fmha fmha;
|
|
auto arguments = args_from_options<MlaSm100Type>(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
|
|
|
CUTLASS_CHECK(fmha.can_implement(arguments));
|
|
|
|
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
|
|
|
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
|
}
|
|
|
|
// Lifts the runtime boolean `expr` into a compile-time constant.
// Inside an immediately-invoked lambda, it declares `constexpr bool
// const_expr` equal to the value of `expr`, then calls the trailing callable
// argument (__VA_ARGS__) with that constant in scope. The callable is
// therefore instantiated once per branch. Its result is returned as bool.
#define DISPATCH_BOOL(expr, const_expr, ...) \
  [&]() -> bool {                            \
    if (expr) {                              \
      constexpr bool const_expr = true;      \
      return __VA_ARGS__();                  \
    }                                        \
    constexpr bool const_expr = false;       \
    return __VA_ARGS__();                    \
  }()
|
|
|
|
void sm100_cutlass_mla_decode(
|
|
torch::Tensor const& out,
|
|
torch::Tensor const& lse,
|
|
torch::Tensor const& q_nope,
|
|
torch::Tensor const& q_pe,
|
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
|
torch::Tensor const& seq_lens,
|
|
torch::Tensor const& page_table,
|
|
torch::Tensor const& workspace,
|
|
double sm_scale,
|
|
int64_t num_kv_splits) {
|
|
auto in_dtype = q_nope.dtype();
|
|
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
|
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
|
|
|
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
|
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
|
|
// Maybe per batch split kv will fix this.
|
|
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
|
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
|
if (in_dtype == at::ScalarType::Half) {
|
|
runMla<cutlass::half_t, cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
|
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
|
} else if (in_dtype == at::ScalarType::BFloat16) {
|
|
runMla<cutlass::bfloat16_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
|
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
|
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
|
runMla<cutlass::float_e4m3_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
|
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
|
} else {
|
|
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
|
}
|
|
return true;
|
|
});
|
|
return true;
|
|
});
|
|
}
|
|
|
|
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
|
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
|
// which are float, so Element type here doesn't matter.
|
|
using MlaSm100Type = MlaSm100<cutlass::half_t, cutlass::half_t, true>;
|
|
|
|
// Get split kv. Requires problem shape and sm_count only.
|
|
typename MlaSm100Type::Fmha::Arguments arguments;
|
|
using TileShapeH = typename MlaSm100Type::TileShapeH;
|
|
using TileShapeD = typename MlaSm100Type::TileShapeD;
|
|
arguments.problem_shape =
|
|
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
|
|
// Assumes device 0 when getting sm_count.
|
|
arguments.hw_info.sm_count =
|
|
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
|
arguments.split_kv = static_cast<int>(num_kv_splits);
|
|
MlaSm100Type::Fmha::set_split_kv(arguments);
|
|
|
|
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
|
}
|
|
|
|
#endif
|
|
|
|
// Bind the decode kernel to the CUDA dispatch key; its tensor arguments are
// expected on the GPU.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, registry) {
  registry.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
}

// The workspace-size query takes only integers, so no tensor can select a
// backend for it; bind it to CatchAll instead.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, registry) {
  registry.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}
|
|
|
|
// clang-format on
|