vllm/tests/kernels/mamba/test_mamba_ssm_ssd.py
HAIAI aee76334d9
[amd_dev] branch rebase (#25753)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: boyuanfeng <boyuan@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: Manoel Marques <manoel.marques@ibm.com>
Signed-off-by: Manoel Marques <manoelmrqs@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: pengdrumli <pengdrumli@tencent.com>
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Huamin Li <3ericli@gmail.com>
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Signed-off-by: Yang <lymailforjob@gmail.com>
Signed-off-by: Debolina Roy <debroy@redhat.com>
Signed-off-by: David Chen <530634352@qq.com>
Signed-off-by: wangzi <3220100013@zju.edu.cn>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com>
Signed-off-by: Csrayz <jover@cmbchina.com>
Signed-off-by: ivyilike <pww123@cmbchina.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Bowen Wang <abmfy@icloud.com>
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: Johnny Yang <johnnyyang@google.com>
Signed-off-by: Alec Solder <alecs@fb.com>
Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Zhikaiiii <1658973216@qq.com>
Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: wuxibin <wuxibin@bytedance.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
Signed-off-by: Peter Pan <peter.pan@daocloud.io>
Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com>
Signed-off-by: Weida Hong <wdhongtw@google.com>
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: rouchenzi <ruochenwen@gmail.com>
Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
Signed-off-by: Andrew Xia <axia@meta.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Corey Lowman <clowman1993@gmail.com>
Signed-off-by: jpvillam <jpvillam@amd.com>
Signed-off-by: dougbtv <dosmith@redhat.com>
Signed-off-by: Chenxi Yang <cxyang@fb.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Yan Lu <luyan@nvidia.com>
Signed-off-by: baxingpiaochong <771405853@qq.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Jackmin801 <ongjackm@gmail.com>
Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com>
Signed-off-by: taohui <taohui3@gmail.com>
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
Signed-off-by: Wei Wei <wwei6@meta.com>
Signed-off-by: Saman Keon <samanamp@outlook.com>
Signed-off-by: yangxurui <yangxurui@meituan.com>
Signed-off-by: nicole-lihui <nicole.li@daocloud.io>
Signed-off-by: courage17340 <courage17340@163.com>
Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com>
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Signed-off-by: zxw <1020938856@qq.com>
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: chenlang <chen.lang5@zte.com.cn>
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
Signed-off-by: Tao Hui <taohui3@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io>
Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com>
Signed-off-by: Iceber Gu <caiwei95@hotmail.com>
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Boyuan Feng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: JartX <sagformas@epdcenter.es>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: xin.li <xin.li@daocloud.io>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com>
Co-authored-by: Manoel Marques <manoelmrqs@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com>
Co-authored-by: Michael Yao <haifeng.yao@daocloud.io>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Huamin Li <3ericli@gmail.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com>
Co-authored-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com>
Co-authored-by: Deboleina <debroy@redhat.com>
Co-authored-by: yinz-aizip <yinz@aizip.ai>
Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
Co-authored-by: wangzi <3220100013@zju.edu.cn>
Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com>
Co-authored-by: Csrayz <jover@cmbchina.com>
Co-authored-by: ivyilike <pww123@cmbchina.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Bowen Wang <abmfy@icloud.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Daisy-Ma-coder <daisy.ma.0117@gmail.com>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com>
Co-authored-by: Alec Solder <alecs@fb.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Chris Bamford <chrisbam4d@gmail.com>
Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Ming Yang <yming@meta.com>
Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com>
Co-authored-by: Andreas Hartel <andreas@hartel.me>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Joel <wuxibin89@163.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Peter Pan <peter.pan@daocloud.io>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Fanli Lin <fanli.lin@intel.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com>
Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Weida Hong <wdhongtw@gmail.com>
Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
Co-authored-by: Amir Samani <samani@ualberta.ca>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Ilya Markov <markovilya197@gmail.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
Co-authored-by: Andrew Xia <axia@meta.com>
Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com>
Co-authored-by: Corey Lowman <clowman1993@gmail.com>
Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com>
Co-authored-by: jpvillam <jpvillam@amd.com>
Co-authored-by: Doug Smith <dosmith@redhat.com>
Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu>
Co-authored-by: Chenxi Yang <cxyang@fb.com>
Co-authored-by: ahao-anyscale <ahao@anyscale.com>
Co-authored-by: 0xNullPath <luyanfcp@foxmail.com>
Co-authored-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Co-authored-by: lhsjohn <huashuoli@tencent.com>
Co-authored-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com>
Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com>
Co-authored-by: Tao Hui <taohui3@gmail.com>
Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Shiyan Deng <dsy842974287@meta.com>
Co-authored-by: Wei Wei <wwei6@meta.com>
Co-authored-by: Saman A. Pour <samanamp@outlook.com>
Co-authored-by: XuruiYang <530534756@qq.com>
Co-authored-by: yangxurui <yangxurui@meituan.com>
Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com>
Co-authored-by: courage17340 <courage17340@users.noreply.github.com>
Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com>
Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io>
Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com>
Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com>
Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: chenlang <chen.lang5@zte.com.cn>
Co-authored-by: chenlang <10346245@zte.com.cn>
Co-authored-by: AlonKejzman <alonkeizman@gmail.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Iceber Gu <caiwei95@hotmail.com>
Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Icey <1790571317@qq.com>
Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com>
Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
2025-09-26 17:14:31 +01:00

561 lines
22 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
_query_start_loc_to_chunk_indices_offsets)
# Added by the IBM Team, 2024
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
# this is the segsum implementation taken from above
def segsum(x):
    """Return pairwise segment sums of ``x`` along its last axis.

    For an input whose last dimension has length ``T`` the result gains one
    extra trailing axis of length ``T``.  Entry ``[..., i, j]`` holds
    ``x[..., j + 1] + ... + x[..., i]`` when ``i >= j`` (so ``0`` on the
    diagonal) and ``-inf`` when ``i < j``.
    """
    size = x.size(-1)
    # Copy every value across a new trailing axis:
    # tiled[..., d, e] == x[..., d].
    tiled = x.unsqueeze(-1).expand(*x.shape, size)
    lower = torch.ones(size, size, device=x.device,
                       dtype=bool).tril(diagonal=0)
    strictly_lower = lower.tril(diagonal=-1)
    # Only entries strictly below the diagonal feed the running sum.
    running = tiled.masked_fill(~strictly_lower, 0).cumsum(dim=-2)
    # Everything above the diagonal is marked as "no segment".
    return running.masked_fill(~lower, -torch.inf)
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Naive chunked SSD scan written with plain torch ops.

    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        block_len: chunk length; must divide ``length`` exactly
        initial_states: optional state placed in front of the per-chunk
            states, (batch, 1, n_heads, d_head, d_state); zeros if omitted
    Return:
        Y: (batch, length, n_heads, d_head)
        final_state: (batch, n_heads, d_head, d_state)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    def into_chunks(t):
        # Split the sequence axis into (chunk, position-inside-chunk).
        return rearrange(t, "b (c l) ... -> b c l ...", l=block_len)

    x_chunks = into_chunks(X)
    b_chunks = into_chunks(B)
    c_chunks = into_chunks(C)
    # A is additionally moved to a heads-first layout.
    a_chunks = rearrange(into_chunks(A), "b c l h -> b h c l")
    a_cumsum = a_chunks.cumsum(dim=-1)

    # Step 1: output produced inside each chunk (diagonal blocks).
    intra_decay = segsum(a_chunks).exp()
    y_intra = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", c_chunks,
                           b_chunks, intra_decay, x_chunks)

    # Step 2: state accumulated by each chunk on its own
    # (B side of the low-rank off-diagonal factorization).
    decay_to_chunk_end = (a_cumsum[:, :, :, -1:] - a_cumsum).exp()
    chunk_states = torch.einsum("bclhn,bhcl,bclhp->bchpn", b_chunks,
                                decay_to_chunk_end, x_chunks)

    # Step 3: recurrence across chunks, giving the state at every chunk
    # boundary (A side of the factorization).
    if initial_states is None:
        initial_states = torch.zeros_like(chunk_states[:, :1])
    chunk_states = torch.cat([initial_states, chunk_states], dim=1)
    chunk_totals = F.pad(a_cumsum[:, :, :, -1], (1, 0))
    inter_chunk_decay = segsum(chunk_totals).exp()
    propagated = torch.einsum("bhzc,bchpn->bzhpn", inter_chunk_decay,
                              chunk_states)
    entering_states = propagated[:, :-1]
    final_state = propagated[:, -1]

    # Step 4: turn the state entering each chunk into output
    # (C side of the factorization).
    y_inter = torch.einsum('bclhn,bchpn,bhcl->bclhp', c_chunks,
                           entering_states, a_cumsum.exp())

    # Sum both contributions and merge the chunk axes back together.
    Y = rearrange(y_intra + y_inter, "b c l h p -> b (c l) h p")
    return Y, final_state
def generate_random_inputs(batch_size,
                           seqlen,
                           n_heads,
                           d_head,
                           itype,
                           device='cuda'):
    """Create seeded random ``(A, dt, X, B, C)`` tensors for the SSD tests.

    The seed is reset to 0 on every call, so identical arguments give
    identical tensors.  ``A`` has shape ``(n_heads,)`` and is strictly
    negative, ``dt`` has shape ``(batch_size, seqlen, n_heads)``, and
    ``X``, ``B`` and ``C`` all share the shape
    ``(batch_size, seqlen, n_heads, d_head)`` (the last axis of ``B`` and
    ``C`` is sized with ``d_head`` as well).
    """
    current_platform.seed_everything(0)
    tensor_opts = dict(dtype=itype, device=device)

    # NOTE: the draw order below (A, dt, X, B, C) fixes the random stream.
    A = torch.rand(n_heads, **tensor_opts).exp().neg()
    dt = F.softplus(
        torch.randn(batch_size, seqlen, n_heads, **tensor_opts) - 4)

    feature_shape = (batch_size, seqlen, n_heads, d_head)
    X = torch.randn(feature_shape, **tensor_opts)
    B = torch.randn(feature_shape, **tensor_opts)
    C = torch.randn(feature_shape, **tensor_opts)
    return A, dt, X, B, C
def generate_continuous_batched_examples(example_lens_by_batch,
                                         num_examples,
                                         full_length,
                                         last_taken,
                                         exhausted,
                                         n_heads,
                                         d_head,
                                         itype,
                                         device='cuda',
                                         return_naive_ref=True):
    """Yield continuous (varlen) batches cut out of full-length examples.

    ``num_examples`` random examples of length ``full_length`` are generated
    once, then sliced according to ``example_lens_by_batch`` and fed out as
    one flattened batch per entry.

    Arguments:
        example_lens_by_batch: sequence of specs; each spec lists how many
            tokens to take from example 0, 1, ... in that batch.
        num_examples: number of full-length examples to generate.
        full_length: length of every generated example.
        last_taken: dict, updated in place; maps example index to the
            position where the next slice of that example starts.
        exhausted: dict, updated in place; maps example index to ``True``
            when the example was consumed exactly up to ``full_length``.
        n_heads, d_head, itype: forwarded to ``generate_random_inputs``.
        device: device for the generated inputs and the metadata tensors.
        return_naive_ref: when true, ``ssd_minimal_discrete`` is run on the
            full-length examples (with ``block_len=full_length // 4``) and
            the matching reference slices are yielded.

    Yields:
        ``(Y_ref, cu_seqlens, seq_idx, (A, dt, X, B, C))`` where ``Y_ref`` is
        a list with one reference slice per example (or ``None`` when
        ``return_naive_ref`` is false), and ``dt``/``X``/``B``/``C`` are the
        concatenated slices without a batch dimension.
    """
    # Generate the full-length examples.  The device is forwarded so the
    # inputs live on the same device as the metadata built below.
    A, dt, X, B, C = generate_random_inputs(num_examples,
                                            full_length,
                                            n_heads,
                                            d_head,
                                            itype,
                                            device=device)

    Y_min = None
    if return_naive_ref:
        Y_min, _ = ssd_minimal_discrete(X * dt.unsqueeze(-1),
                                        A * dt,
                                        B,
                                        C,
                                        block_len=full_length // 4)

    def get_continuous_batch(example_lens):
        # Cut the next ``example_lens[i]`` tokens out of example ``i`` and
        # concatenate the pieces along the token axis.
        # e.g., example_lens=(8, 4) takes 8 tokens from the first example
        # and 4 tokens from the second one.
        bounds = []
        for i, take in enumerate(example_lens):
            start = last_taken.get(i, 0)
            bounds.append((start, start + take))
            # Wrap the cursor; landing exactly on 0 means the example has
            # been fully consumed.
            last_taken[i] = (start + take) % full_length
            exhausted[i] = last_taken[i] == 0
        # varlen has an implicit batch of 1, so no batch axis is added.
        return tuple(
            torch.concat([t[i, s:e] for i, (s, e) in enumerate(bounds)])
            for t in (dt, X, B, C))

    def end_boundary(n: int):
        # Map "n" to the right boundary inside an example of length
        # "full_length":
        # - when n > full_length, returns n % full_length
        # - when n == full_length, returns full_length
        return n - ((n - 1) // full_length) * full_length

    IND_E = None
    for spec in example_lens_by_batch:
        # The (maybe partial) examples seen in this continuous batch.
        dt2, X2, B2, C2 = get_continuous_batch(spec)

        # Metadata: cumulative sequence lengths and a per-token sequence id.
        cu_seqlens = torch.tensor((0, *spec), device=device).cumsum(dim=0)
        seq_idx = torch.zeros(cu_seqlens[-1],
                              dtype=torch.int32,
                              device=cu_seqlens.device)
        for i, (srt, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
            seq_idx[srt:end] = i

        # Window of every example covered by this batch, used to slice the
        # reference output.
        if IND_E is None:
            IND_S = [0 for _ in range(len(spec))]
        else:
            IND_S = [x % full_length for x in IND_E]
        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

        yield ([Y_min[s, IND_S[s]:IND_E[s]]
                for s in range(num_examples)] if return_naive_ref else None,
               cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))
@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                         itype):
    """Check the varlen kernel against the naive reference on one sequence.

    Only the output of the last token and the final state of the last head
    are compared.
    """
    # TODO: the bfloat16 case requires higher thresholds. To be investigated
    if itype == torch.bfloat16:
        atol, rtol = 5e-2, 5e-2
    else:
        atol, rtol = 8e-3, 5e-3

    batch_size = 1

    # ssd_minimal_discrete requires chunk_size divide seqlen
    # - this is only required for generating the reference seqs,
    #   it is not an operational limitation.
    seqlen, chunk_size = seq_len_chunk_size

    A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
                                            d_head, itype)

    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                  B, C, chunk_size)

    # Build the metadata on the same device as the inputs.
    cu_seqlens = torch.tensor((0, seqlen), device=X.device).cumsum(dim=0)
    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)

    chunk_indices, chunk_offsets = \
        _query_start_loc_to_chunk_indices_offsets(
            cu_seqlens, chunk_size, cu_seqlens[-1])

    # varlen has implicit batch=1, so drop the batch axis of the per-token
    # inputs.  A is generated with shape (n_heads,) and has no batch axis,
    # so it is passed through untouched.
    X = X.squeeze(0)
    dt = dt.squeeze(0)
    B = B.squeeze(0)
    C = C.squeeze(0)

    Y = torch.empty_like(X)
    final_state = mamba_chunk_scan_combined_varlen(X,
                                                   dt,
                                                   A,
                                                   B,
                                                   C,
                                                   chunk_size,
                                                   D=None,
                                                   cu_seqlens=cu_seqlens,
                                                   seq_idx=seq_idx,
                                                   chunk_indices=chunk_indices,
                                                   chunk_offsets=chunk_offsets,
                                                   out=Y)

    # just test the last in sequence
    torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)

    # just test the last head
    # NOTE, in the kernel we always cast states to fp32
    torch.testing.assert_close(final_state[:, -1].to(torch.float32),
                               final_state_min[:, -1].to(torch.float32),
                               atol=atol,
                               rtol=rtol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
    "seq_len_chunk_size_cases",
    [
        # small-ish chunk_size (8)
        (64, 8, 2, [(64, 32), (64, 32)]),
        (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
        (64, 8, 2, [(8, 8), (8, 8), (8, 8)]),  # chunk size boundary
        (64, 8, 2, [(4, 4), (4, 4), (4, 4),
                    (4, 4)]),  # chunk_size larger than cont batches
        (64, 8, 5, [
            (64, 32, 16, 8, 8),
            (8, 16, 32, 16, 8),
            (8, 8, 16, 32, 16),
        ]),  # more examples with varied lengths

        # large-ish chunk_size (256)
        (64, 256, 1, [(5, ), (1, ), (1, ),
                      (1, )]),  # irregular sizes with small sequences
        (64, 256, 2, [(5, 30), (1, 2), (1, 2),
                      (1, 2)]),  # irregular sizes with small sequences

        # we also need to test some large seqlen
        # to catch errors with init states decay
        (768, 128, 2, [(138, 225), (138, 225)]),
    ])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                     itype):
    """Run several examples through the kernel as continuous batches.

    Each case feeds successive slices of the same examples (i.e. chunked
    prefill), carrying the kernel state from one batch to the next, and
    compares against the naive reference.
    """
    seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

    # Longer sequences are allowed a slightly larger absolute error.
    atol, rtol = (1e-2, 5e-3) if seqlen > 256 else (5e-3, 5e-3)

    # Bookkeeping shared with the generator: where each example was last
    # cut, and whether it has been fully consumed and must start over.
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

    batches = generate_continuous_batched_examples(cases, num_examples,
                                                   seqlen, last_taken,
                                                   exhausted, n_heads, d_head,
                                                   itype)

    states = None
    for Y_min, cu_seqlens, seq_idx, (A, dt, X, B, C) in batches:
        chunk_indices, chunk_offsets = \
            _query_start_loc_to_chunk_indices_offsets(
                cu_seqlens, chunk_size, cu_seqlens[-1])

        Y = torch.empty_like(X)
        new_states = mamba_chunk_scan_combined_varlen(
            X,
            dt,
            A,
            B,
            C,
            chunk_size,
            D=None,
            cu_seqlens=cu_seqlens,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            initial_states=states,
            out=Y,
        )

        # Compare every token of each example, but only for the first head
        # and the first entry of the head dimension.
        for i in range(num_examples):
            start, stop = cu_seqlens[i], cu_seqlens[i + 1]
            torch.testing.assert_close(Y[start:stop, 0, 0],
                                       Y_min[i][:, 0, 0],
                                       atol=atol,
                                       rtol=rtol)

        # Carry the states over to the next batch, zeroing those of
        # examples that were fully consumed.
        states = new_states
        finished = [eg for eg, done in exhausted.items() if done]
        for eg in finished:
            states[eg].fill_(0.)
            exhausted[eg] = False
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize("seqlens", [
    (16, 2, 8, 13),
    (270, 88, 212, 203),
    (16, 20),
])
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
    """Chunked prefill must reproduce the full-sequence kernel result.

    Every sequence is split in two halves.  The kernel is run on the first
    halves, then on the remaining halves seeded with the states of the first
    run, and the stitched outputs and final states are compared with a
    single run over the complete sequences.

    Unlike test_mamba_chunk_scan_cont_batch this test
      1. does not use the naive torch implementation (ssd_minimal_discrete)
         as reference; kernel output is compared with kernel output, and
      2. focuses on sequences that change in the middle of mamba chunks,
         not necessarily on chunk boundaries.
    """
    longest = max(seqlens)
    # Longer sequences are allowed a slightly larger absolute error.
    atol, rtol = (1e-2, 5e-3) if longest > 256 else (5e-3, 5e-3)

    num_sequences = len(seqlens)
    n_heads, d_head, itype = 16, 64, torch.float32

    # Bookkeeping required by the generator (not inspected afterwards).
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
    batches = generate_continuous_batched_examples([seqlens],
                                                   num_sequences,
                                                   longest,
                                                   last_taken,
                                                   exhausted,
                                                   n_heads,
                                                   d_head,
                                                   itype,
                                                   return_naive_ref=False)
    _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(batches)
    device = X.device
    seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device)

    def run_kernel(x, step, b, c, cu, idx, init, out):
        # One kernel invocation over a varlen batch described by cu/idx.
        chunk_indices, chunk_offsets = \
            _query_start_loc_to_chunk_indices_offsets(cu, chunk_size, cu[-1])
        return mamba_chunk_scan_combined_varlen(
            x,
            step,
            A,
            b,
            c,
            chunk_size,
            D=None,
            cu_seqlens=cu,
            seq_idx=idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            initial_states=init,
            out=out,
        )

    def varlen_layout(lengths):
        # Cumulative lengths and per-token sequence ids for "lengths".
        cu = torch.cat(
            [torch.tensor([0], device=device),
             torch.cumsum(lengths, dim=0)],
            dim=0)
        idx = torch.repeat_interleave(torch.arange(len(lengths),
                                                   device=device),
                                      lengths,
                                      output_size=cu[-1]).to(torch.int32)
        return cu, idx

    def pack(src, dst_cu, src_begin, src_end):
        # Copy src[src_begin[i]:src_end[i]] of every sequence into one
        # tightly packed tensor laid out according to dst_cu.
        packed = torch.zeros_like(src)[:dst_cu[-1], ...]
        for i in range(num_sequences):
            packed[dst_cu[i]:dst_cu[i + 1], ...] = \
                src[src_begin[i]:src_end[i], ...]
        return packed

    ## full seqlen computation
    Y_ref = torch.empty_like(X)
    state_ref = run_kernel(X, dt, B, C, cu_seqlens, seq_idx, None, Y_ref)

    ## chunked seqlen computation
    # first chunk: the leading half of every sequence
    chunked_seqlens = seqlens // 2
    first_cu, first_idx = varlen_layout(chunked_seqlens)
    # position inside the full batch where each sequence is split
    split_points = cu_seqlens[:-1] + chunked_seqlens
    X_first, dt_first, B_first, C_first = (pack(t, first_cu, cu_seqlens,
                                                split_points)
                                           for t in (X, dt, B, C))
    Y_partial = torch.empty_like(X_first)
    partial_state = run_kernel(X_first, dt_first, B_first, C_first, first_cu,
                               first_idx, None, Y_partial)

    # remaining chunk: whatever is left of every sequence
    rest_cu, rest_idx = varlen_layout(seqlens - chunked_seqlens)
    X_rest, dt_rest, B_rest, C_rest = (pack(t, rest_cu, split_points,
                                            cu_seqlens[1:])
                                       for t in (X, dt, B, C))

    def stitch(part1, part2):
        # Interleave both halves sequence by sequence, restoring the token
        # order of the full batch.
        pieces = []
        for i in range(num_sequences):
            pieces.append(part1[first_cu[i]:first_cu[i + 1], ...])
            pieces.append(part2[rest_cu[i]:rest_cu[i + 1], ...])
        return torch.cat(pieces, dim=0)

    # assert input chunking is correct
    assert stitch(X_first, X_rest).equal(X)
    assert stitch(dt_first, dt_rest).equal(dt)
    assert stitch(B_first, B_rest).equal(B)
    assert stitch(C_first, C_rest).equal(C)

    # The second pass starts from the states left by the first one.
    Y_chunked = torch.empty_like(X_rest)
    state_chunked = run_kernel(X_rest, dt_rest, B_rest, C_rest, rest_cu,
                               rest_idx, partial_state, Y_chunked)
    Y = stitch(Y_partial, Y_chunked)

    # kernel chunked is same as kernel overall
    for i in range(num_sequences):
        Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
        Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
        half = chunked_seqlens[i]
        torch.testing.assert_close(
            Y_seq[:half, ...],
            Y_ref_seq[:half, ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} output part1 " + x)  # noqa: B023
        torch.testing.assert_close(
            Y_seq[half:, ...],
            Y_ref_seq[half:, ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} output part2 " + x)  # noqa: B023

        torch.testing.assert_close(
            state_chunked[i],
            state_ref[i],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} state " + x)  # noqa: B023