Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-04-07 22:17:11 +08:00)
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
199 lines · 7.7 KiB · Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import math
|
|
import random
|
|
from typing import Optional
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import vllm._custom_ops as ops
|
|
from vllm.platforms import current_platform
|
|
from vllm.triton_utils import triton
|
|
|
|
|
|
def cal_diff(x: torch.Tensor,
             y: torch.Tensor,
             name: str,
             use_fp8: bool = False,
             diff_threshold: Optional[float] = None) -> None:
    """Assert that two tensors are numerically close.

    The metric is ``1 - 2 * <x, y> / (||x||^2 + ||y||^2)`` computed in
    float64. It is 0 when ``x == y`` and grows as the tensors diverge.

    Args:
        x: Tensor under test.
        y: Reference tensor; must be broadcast-compatible with ``x``.
        name: Label included in the failure message (e.g. ``"out"``).
        use_fp8: Selects the looser default threshold (1e-4 instead of
            1e-5) when ``diff_threshold`` is not given.
        diff_threshold: Explicit threshold; overrides the defaults.

    Raises:
        AssertionError: If the metric is not strictly below the threshold.
    """
    x, y = x.double(), y.double()
    # Clamp the denominator so two all-zero tensors do not divide by zero.
    cos_diff = 1 - 2 * (x * y).sum().item() / max(
        (x * x + y * y).sum().item(), 1e-12)
    if diff_threshold is None:
        # fp8 inputs carry more quantization error, so allow more slack.
        diff_threshold = 1e-4 if use_fp8 else 1e-5
    assert cos_diff < diff_threshold, (
        f"{name}: cos_diff={cos_diff:.3e} is not below threshold "
        f"{diff_threshold:.1e}")
|
|
|
|
|
|
# Skip reason for the CUTLASS MLA test. It is computed with the same
# `has_device_capability(100)` predicate used by the test's skipif marker,
# so the message always agrees with the skip decision.
CUTLASS_MLA_UNSUPPORTED_REASON = (
    "Cutlass MLA Requires compute capability of 10 or above."
    if not current_platform.has_device_capability(100) else
    "Cutlass MLA is supported")
|
|
|
|
|
|
@pytest.mark.skipif(not current_platform.has_device_capability(100),
                    reason=CUTLASS_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize(
    "torch_dtype",
    [
        torch.bfloat16,
        # fp8 can have occasional precision-related failures.
        pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2))
    ])
@torch.inference_mode()
def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
                            causal, varlen, torch_dtype):
    """Compare ``ops.sm100_cutlass_mla_decode`` against a PyTorch reference.

    The test proceeds in four steps:

    1. Build a random query and a paged KV cache.
    2. Run the CUTLASS kernel and a float32 reference attention.
    3. Check both the attention output and the per-head log-sum-exp (LSE).
    4. Benchmark the kernel with ``triton.testing.do_bench``.
    """
    device = torch.device("cuda:0")
    # fp8 tensors are produced below by casting bf16 data, so generate in bf16.
    if torch_dtype == torch.float8_e4m3fn:
        init_dtype = torch.bfloat16
    else:
        init_dtype = torch_dtype
    torch.set_default_dtype(init_dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(42)
    random.seed(42)

    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
          f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")

    use_fp8 = torch_dtype == torch.float8_e4m3fn
    scale = math.sqrt(d)**(-1)
    cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
    if varlen:
        # Random per-sequence lengths, never shorter than the query length.
        for i in range(b):
            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
                                   s_q)
    total_seqlens = cache_seqlens.sum().item()
    max_seqlen = cache_seqlens.max().item()
    # Each sequence gets the same number of cache slots, rounded up to 256.
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256

    q = torch.randn(b, s_q, h_q, d)
    # Row i of the block table lists the contiguous physical blocks owned by
    # sequence i; ref_mla relies on this layout (begin = i * max_seqlen_pad).
    block_table = torch.arange(b * max_seqlen_pad // block_size,
                               dtype=torch.int32).view(
                                   b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    # V is a view of the first dv channels of K.
    blocked_v = blocked_k[..., :dv]

    init_dtype = q.dtype
    if use_fp8:
        fp8_dtype = torch.float8_e4m3fn
        # Unit descales; only the reference path below applies them.
        descale_q = torch.ones((1), dtype=torch.float32)
        descale_k = torch.ones((1), dtype=torch.float32)

        q = q.to(fp8_dtype)
        blocked_k = blocked_k.to(fp8_dtype)
        blocked_v = blocked_v.to(fp8_dtype)
    else:
        descale_q = None
        descale_k = None

    def cutlass_mla():
        """Run the CUTLASS kernel; return (out, lse) for the first h_q heads."""
        MAX_HEADS = 128

        q_reshaped = q.squeeze(1)
        q_nope = q_reshaped[:, :, :dv].clone()
        q_pe = q_reshaped[:, :, dv:].clone()

        # The buffers handed to the kernel always have MAX_HEADS heads; extra
        # (uninitialized) heads are sliced off the result below.
        if h_q < MAX_HEADS:
            q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv))
            q_nope_padded[:, :h_q] = q_nope
            q_nope = q_nope_padded

            q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv))
            q_pe_padded[:, :h_q] = q_pe
            q_pe = q_pe_padded

        kv_cache_flat = blocked_k.squeeze(2)
        device_properties = torch.cuda.get_device_properties(
            torch.device("cuda:0"))
        sm_count = device_properties.multi_processor_count
        workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
            max_seqlen * block_size, b, sm_count, num_kv_splits=1)
        workspace = torch.empty(workspace_size,
                                device="cuda",
                                dtype=torch.uint8)

        out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
        output_lse = torch.empty((b, MAX_HEADS),
                                 dtype=torch.float32,
                                 device=q_nope.device)
        ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe,
                                     kv_cache_flat, cache_seqlens, block_table,
                                     workspace, scale, 1)
        return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous()

    def scaled_dot_product_attention(query, key, value, is_causal=False):
        """Float32 reference attention; returns (output, logsumexp)."""
        query = query.float()
        key = key.float()
        value = value.float()
        # Expand the KV heads so each query head has a matching KV head.
        key = key.repeat_interleave(h_q // h_kv, dim=0)
        value = value.repeat_interleave(h_q // h_kv, dim=0)
        attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
        if is_causal:
            q_len = query.shape[-2]
            k_len = key.shape[-2]
            attn_bias = torch.zeros(q_len, k_len, dtype=query.dtype)
            # Bottom-right aligned causal mask: query row j may attend to
            # keys 0 .. k_len - q_len + j.
            temp_mask = torch.ones(q_len, k_len,
                                   dtype=torch.bool).tril(diagonal=k_len -
                                                          q_len)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_weight += attn_bias
        lse = attn_weight.logsumexp(dim=-1)
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        return attn_weight @ value, lse

    def ref_mla():
        """Reference MLA over the paged cache, one sequence at a time.

        Returns out as [b, s_q, h_q, dv] and lse as [b, h_q, s_q].
        """
        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
        blocked_k_ = (blocked_k.to(torch.float) *
                      descale_k).to(init_dtype) if use_fp8 else blocked_k
        blocked_v_ = (blocked_v.to(torch.float) *
                      descale_k).to(init_dtype) if use_fp8 else blocked_v
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            out_i, lse_i = scaled_dot_product_attention(
                q_[i].transpose(0, 1),
                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                is_causal=causal,
            )
            out[i] = out_i.transpose(0, 1)
            lse[i] = lse_i
        return out, lse

    out_cutlass, lse_cutlass = cutlass_mla()
    out_torch, lse_torch = ref_mla()
    # Take the single decode position (s_q == 1) from each reference tensor.
    # out is [b, s_q, h_q, dv] while lse is [b, h_q, s_q], so the s_q axis is
    # in a different place for each. Slicing lse with [:, 0, :] would pick
    # head 0 only and broadcast it against every kernel head.
    out_torch_slice = out_torch[:, 0, :, :]  # [b, h_q, dv]
    lse_torch_slice = lse_torch[:, :, 0]  # [b, h_q]
    cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)
    # lse has larger numerical error, so use a larger threshold
    cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3)

    t = triton.testing.do_bench(cutlass_mla)
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    # KV cache and query are read in torch_dtype; the output is written in
    # init_dtype.
    total_bytes = (total_seqlens * h_kv * d +
                   b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
                       b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
    # t is in milliseconds, so /1e9/t gives TFLOP/s and /1e6/t gives GB/s.
    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
          f"{total_bytes / 10 ** 6 / t:.0f} GB/s")
|