mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-14 07:07:02 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
302 lines
11 KiB
Python
302 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import math
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
|
|
create_vllm_config,
|
|
get_attention_backend)
|
|
from vllm.config import ParallelConfig, SpeculativeConfig
|
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
|
|
|
|
|
class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer used by the tests in this file.

    An instance is handed to the backend implementation's ``forward`` as the
    ``layer`` argument. It only carries three unit scale tensors
    (``_q_scale``, ``_k_scale``, ``_v_scale``) and an identity ``forward``.
    """

    # Declared here for readers and type checkers only; the tensors themselves
    # are created per instance in ``__init__``.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor

    def __init__(self):
        super().__init__()
        # Allocate the scales at construction time rather than in the class
        # body, so that merely importing/collecting this module neither
        # requires a GPU nor initializes the CUDA context.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
        self._k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
        self._v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
|
|
|
|
|
|
def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: _Backend,
    spec_token_tree: Optional[str] = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single attention forward pass through the selected backend.

    Builds the common and backend-specific attention metadata for a batch in
    which every request has the same query length and the same total
    sequence length, instantiates the backend implementation, and calls its
    ``forward``.

    Args:
        q: Query tensor; unpacked as
            ``(batch_size, q_len, num_heads, dim_per_head)``.
        k: Key tensor; ``k.shape[-2]`` is taken as the number of KV heads.
        v: Value tensor, flattened the same way as ``k``.
        kv_cache: KV cache tensor. A clone is passed to the backend, so the
            caller's tensor is not modified by this call.
        block_table: Block table forwarded as ``block_table_tensor``.
        slot_mapping: Slot mapping forwarded to the metadata unchanged.
        seqlen_k: Total sequence length assigned to every request.
        backend: Backend selector passed to ``get_attention_backend``.
        spec_token_tree: When not ``None``, a ``SpeculativeConfig`` with
            ``method="eagle"`` and this token tree is attached to the config.
        num_spec_tokens: ``num_speculative_tokens`` for that config; unused
            when ``spec_token_tree`` is ``None``.

    Returns:
        Whatever the backend's ``forward`` returns (the callers in this file
        reshape it to ``(batch_size, -1, num_heads, dim_per_head)``).
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths. Every request contributes
    # exactly q_len query tokens and has seqlen_k tokens in total.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32)
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size, ),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    # Same value as query_start_loc[-1], but a plain Python int instead of a
    # 0-dim device tensor.
    num_actual_tokens = batch_size * q_len

    softmax_scale = dim_per_head**(-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = get_attention_backend(backend)
    # Pass the already-computed int; the built-in max() over a tensor would
    # iterate it element by element and hand back a 0-dim tensor.
    vllm_config = create_vllm_config(model_name=model_name,
                                     max_model_len=max_seq_len)
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree)
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.cpu(),
        num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output. The batch and token dimensions are
    # flattened into a single leading token dimension.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        # Clone so repeated calls all start from the same cache contents.
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )
|
|
|
|
|
|
def test_tree_attn_correctness() -> None:
    """Check tree attention against per-branch attention from FLASH_ATTN.

    For every combination of batch size, head count, sequence position and
    token tree, the tree is run once through the TREE_ATTN backend. Each row
    of the tree mask selects one root-to-node branch; that branch is then run
    as a plain chain through the FLASH_ATTN backend and must match the
    corresponding rows of the tree output.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"

    def _mask(rows):
        # Row i has ones at the tree positions that query i attends to.
        return torch.tensor(rows, device=device, dtype=torch.int32)

    def _randn_bf16(shape, dev):
        return torch.randn(shape, device=dev, dtype=torch.bfloat16)

    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]":
        _mask([
            [1, 0, 0, 0],
            [1, 1, 0, 0],
            [1, 1, 1, 0],
            [1, 1, 1, 1],
        ]),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
        _mask([
            [1, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0],
            [1, 0, 1, 0, 0, 0, 0],
            [1, 1, 0, 1, 0, 0, 0],
            [1, 1, 0, 0, 1, 0, 0],
            [1, 0, 1, 0, 0, 1, 0],
            [1, 0, 1, 0, 0, 0, 1],
        ]),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True

    def _check_case(batch_size, num_heads, sequence_position,
                    spec_token_tree, tree_attn_mask):
        # The number of heads must be divisible by the number of KV heads.
        assert num_heads % num_kv_heads == 0

        # Initialize q, k, and v (drawn in this order from the seeded RNG).
        tree_size_q = tree_attn_mask.shape[0]
        seqlen_k = sequence_position + tree_size_q
        kv_shape = (batch_size, tree_size_q, num_kv_heads, dim_per_head)
        q = _randn_bf16((batch_size, tree_size_q, num_heads, dim_per_head),
                        device)
        k = _randn_bf16(kv_shape, device)
        v = _randn_bf16(kv_shape, device)
        dev = q.device

        # Set up the block table and KV cache for paged KV.
        assert max_sequence_length % block_size == 0
        max_blocks_per_batch = max_sequence_length // block_size
        kv_cache = _randn_bf16(
            (
                2,
                batch_size * max_blocks_per_batch,
                block_size,
                num_kv_heads,
                dim_per_head,
            ),
            dev,
        )
        num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
        block_table = torch.zeros(
            (batch_size, max_blocks_per_batch),
            device=dev,
            dtype=torch.int32,
        )
        block_ids = torch.arange(
            0,
            batch_size * num_alloc_blocks_per_batch,
            device=dev,
            dtype=torch.int32,
        )
        if randomize_blocks:
            # Randomize the block ids.
            shuffle = torch.randperm(block_ids.numel())
            block_ids = block_ids[shuffle]
        block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
            -1, num_alloc_blocks_per_batch)

        def _slots_for(num_tokens):
            # Slot mapping for num_tokens consecutive positions starting at
            # sequence_position, repeated for every request in the batch.
            positions = sequence_position + torch.arange(
                0,
                num_tokens,
                device=dev,
                dtype=torch.int64,
            ).repeat(batch_size, 1)
            return _gen_slot_mapping(positions, block_table, block_size)

        # Compute attention for the whole tree in one pass.
        tree_attn_output = forward_attention(
            q=q,
            k=k,
            v=v,
            kv_cache=kv_cache,
            block_table=block_table,
            slot_mapping=_slots_for(tree_size_q),
            seqlen_k=seqlen_k,
            backend=_Backend.TREE_ATTN,
            spec_token_tree=spec_token_tree,
            num_spec_tokens=tree_size_q - 1,
        ).view(batch_size, -1, num_heads, dim_per_head)

        # Each mask row describes one branch; running that branch as a chain
        # through FLASH_ATTN must reproduce the tree output for its tokens.
        for q_index in range(tree_size_q):
            branch_indices = torch.nonzero(tree_attn_mask[q_index, :],
                                           as_tuple=True)[0]
            q_len = branch_indices.shape[0]

            # Compute flash attention for the branch; its tokens occupy
            # consecutive positions right after the shared context.
            flash_attn_output = forward_attention(
                q=q[:, branch_indices],
                k=k[:, branch_indices],
                v=v[:, branch_indices],
                kv_cache=kv_cache,
                block_table=block_table,
                slot_mapping=_slots_for(q_len),
                seqlen_k=sequence_position + q_len,
                backend=_Backend.FLASH_ATTN,
            ).view(batch_size, -1, num_heads, dim_per_head)

            # Compare the outputs (atol is roughly 2**-7).
            assert torch.allclose(
                tree_attn_output[:, branch_indices],
                flash_attn_output,
                atol=7.81e-3,
            ), (f"outputs are not close for "
                f"batch_size: {batch_size}, "
                f"num_heads: {num_heads}, "
                f"sequence_position: {sequence_position}, "
                f"tree_attn_mask: {tree_attn_mask}, "
                f"q_index: {q_index}.")

    # Enumerate cases in nested-loop order: batch size outermost, tree
    # innermost, so random tensors are drawn in a fixed sequence.
    cases = ((batch_size, num_heads, sequence_position, tree, mask)
             for batch_size in [1, 16, 32]
             for num_heads in [2, 4]
             for sequence_position in [16, 1024, 2048]
             for tree, mask in tree_attn_masks.items())
    for case in cases:
        _check_case(*case)
|
|
|
|
|
|
def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
|
|
block_size: int):
|
|
block_indices = positions // block_size
|
|
blocks = block_table.gather(dim=1, index=block_indices)
|
|
return (blocks * block_size + positions % block_size).view(-1)
|