# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.attention.layer import Attention
from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig,
                         set_current_vllm_config)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
                                         get_kv_cache_configs)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                       SchedulerOutput)
from vllm.v1.worker.tpu_model_runner import (
    TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
    _get_padded_token_len, _get_req_paddings, _get_token_paddings)


def get_vllm_config():
    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
    )
    model_config = ModelConfig(
        model="facebook/opt-125m",
        dtype="bfloat16",  # TPUs typically use bfloat16
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
    )
    return vllm_config


def get_model_runner(vllm_config):
    device = "xla:0"  # Mocking TPU device
    return TPUModelRunner(vllm_config, device)


@pytest.fixture
def model_runner():
    # Patchers have already been started at module level.
    vllm_config = get_vllm_config()
    return get_model_runner(vllm_config)


def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
    new_reqs = []
    num_scheduled_tokens = {}
    total_num_scheduled_tokens = 0
    for req_id in req_ids:
        new_reqs.append(
            NewRequestData(
                req_id=req_id,
                prompt_token_ids=[1, 2, 3],
                mm_features=[],
                sampling_params=SamplingParams(),
                pooling_params=PoolingParams(),
                block_ids=([0], ),  # block_ids should be tuple[list[int]]
                num_computed_tokens=0,
                lora_request=None,
            ))
        num_scheduled_tokens[req_id] = 3
        total_num_scheduled_tokens += num_scheduled_tokens[req_id]

    return SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=total_num_scheduled_tokens,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )


def _is_req_scheduled(model_runner, req_id: str) -> bool:
    return req_id in model_runner.input_batch.req_id_to_index


def _is_req_added(model_runner, req_id: str) -> bool:
    return req_id in model_runner.requests


def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
    """Check if the request state block IDs match the block table.

    This function handles both legacy BlockTable and new MultiGroupBlockTable
    structures for backward compatibility.
    """
    req_index = model_runner.input_batch.req_id_to_index[req_id]
    multi_group_block_table = model_runner.input_batch.block_table
    req_state = model_runner.requests[req_id]

    # Access the first block table from MultiGroupBlockTable.
    # This is safe since we currently only use single KV cache groups.
    block_table = multi_group_block_table[0]

    # req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
    # Extract the first group's block IDs.
    if isinstance(req_state.block_ids[0], list):
        # New format: tuple[list[int], ...] - extract first group
        req_block_ids = req_state.block_ids[0]
    else:
        # Legacy format: list[int] - use directly
        req_block_ids = req_state.block_ids

    if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
        return False

    num_blocks = block_table.num_blocks_per_row[req_index]
    block_table_values = block_table.block_table.np[req_index, :num_blocks]
    return (block_table_values == req_block_ids).all()


def test_update_states_new_request(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    model_runner._update_states(scheduler_output)

    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_request_finished(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # finish req
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids={req_id},
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(scheduler_output)
    assert not _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)


def test_update_states_request_resumed(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # unschedule req
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)

    # resume req
    cached_req_data = CachedRequestData(
        req_ids=[req_id],
        resumed_from_preemption=[False],
        new_token_ids=[[]],
        new_block_ids=[([], )],
        num_computed_tokens=[0],
    )

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=cached_req_data,
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_no_changes(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # schedule req
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_request_unscheduled(model_runner):
    req_ids = ("req_0", "req_1")

    # new reqs
    scheduler_output = _schedule_new_request(*req_ids)

    model_runner._update_states(scheduler_output)

    assert _is_req_added(model_runner, req_ids[0])
    assert _is_req_scheduled(model_runner, req_ids[0])

    assert _is_req_added(model_runner, req_ids[1])
    assert _is_req_scheduled(model_runner, req_ids[1])

    # unschedule req_1
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={req_ids[0]: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(scheduler_output)

    assert _is_req_added(model_runner, req_ids[0])
    assert _is_req_scheduled(model_runner, req_ids[0])

    assert _is_req_added(model_runner, req_ids[1])
    assert not _is_req_scheduled(model_runner, req_ids[1])


def test_get_paddings():
    # Bucketed padding.
    min_token_size, max_token_size, padding_gap = 16, 512, 64
    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings

    # Bucketed padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding.
    max_token_size, padding_gap = 1024, 0
    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 256, 512]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings


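# A minimal sketch of the padding rule the assertions above encode (the real
# logic lives in _get_token_paddings; this helper is illustrative only):
# with padding_gap == 0, sizes grow exponentially from min_token_size until
# they cover max_token_size; otherwise they double up to 2 * padding_gap and
# then step linearly by padding_gap.
def _token_paddings_sketch(min_size: int, max_size: int,
                           gap: int) -> list[int]:
    paddings, num = [], min_size
    if gap == 0:
        # Exponential-only bucketing.
        while num < max_size:
            paddings.append(num)
            num *= 2
    else:
        # Double until reaching 2 * gap ...
        while num < 2 * gap and num < max_size:
            paddings.append(num)
            num *= 2
        # ... then step linearly by gap.
        while num < max_size:
            paddings.append(num)
            num += gap
    # Include one final size that covers max_size.
    paddings.append(num)
    return paddings

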
def test_get_padded_token_len():
    min_token_size, max_token_size, padding_gap = 16, 512, 64
    paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
    assert _get_padded_token_len(paddings, 1) == 16
    assert _get_padded_token_len(paddings, 16) == 16
    assert _get_padded_token_len(paddings, 20) == 32
    assert _get_padded_token_len(paddings, 300) == 320
    assert _get_padded_token_len(paddings, 512) == 512


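# Sketch of the lookup exercised above: return the smallest precompiled
# padding that covers the requested length (assumes `paddings` is sorted and
# its last entry covers the maximum).
def _padded_token_len_sketch(paddings: list[int], x: int) -> int:
    import bisect
    return paddings[bisect.bisect_left(paddings, x)]

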
def test_get_padded_num_reqs_with_upper_limit():
    assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
    assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
    assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
    assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28


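# Sketch of the rounding rule checked above: round the request count up to
# the next power of two, with a floor of 8, capped at the upper limit (so
# 17 rounds to 32, but a limit of 28 clamps it to 28).
def _padded_num_reqs_sketch(x: int, upper_limit: int) -> int:
    res = 8 if x <= 8 else 1 << (x - 1).bit_length()
    return min(res, upper_limit)

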
def test_get_req_paddings():
    assert _get_req_paddings(1, 32) == [8, 16, 32]
    assert _get_req_paddings(8, 32) == [8, 16, 32]
    assert _get_req_paddings(8, 36) == [8, 16, 32, 36]


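# Sketch of how the request paddings above could be generated: powers of two
# from max(8, min_req_size) upward, with max_req_size itself appended when it
# is not already covered (hence the trailing 36 in the last case). Assumes a
# power-of-two min_req_size, as in the assertions above.
def _req_paddings_sketch(min_req_size: int, max_req_size: int) -> list[int]:
    num = max(8, min_req_size)
    paddings = []
    while num < max_req_size:
        paddings.append(num)
        num *= 2
    paddings.append(max_req_size)
    return paddings

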
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
        model_runner):
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    error_msg = f"{layer_1} must come before the current layer"
    vllm_config = model_runner.vllm_config
    with pytest.raises(ValueError, match=error_msg), \
            set_current_vllm_config(vllm_config):
        fwd_context = {
            # initialization below will fail because the target layer is
            # invalid; the target layer needs to come before layer 1
            layer_0:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_0,
                kv_sharing_target_layer_name=layer_1,
            ),
            layer_1:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_1,
            )
        }
        # suppress var not used error
        assert fwd_context is not None


def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    invalid_layer = "model.layers.0.cross_attn.attn"
    error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
    vllm_config = model_runner.vllm_config
    with pytest.raises(ValueError, match=error_msg), \
            set_current_vllm_config(vllm_config):
        fwd_context = {
            layer_0:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_0,
            ),
            layer_1:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_1,
                # invalid layer: model.layers.0.cross_attn.attn doesn't exist!
                kv_sharing_target_layer_name=invalid_layer,
            )
        }
        # suppress var not used error
        assert fwd_context is not None


def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    error_msg = f"{layer_1} cannot be the same as the current layer"
    vllm_config = model_runner.vllm_config
    with pytest.raises(ValueError, match=error_msg), \
            set_current_vllm_config(vllm_config):
        fwd_context = {
            # initialization below will fail because the target layer is
            # invalid; a layer cannot share KV cache with itself
            layer_0:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_0,
            ),
            layer_1:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_1,
                kv_sharing_target_layer_name=layer_1,
            )
        }
        # suppress var not used error
        assert fwd_context is not None


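# The two allocation tests below both assume a 512 KiB KV-cache page per
# layer. A worked version of that arithmetic (assuming a non-MLA K/V layout
# and the TPU runner padding the block size up to 128 tokens):
# 2 (K and V) * 8 (num_heads) * 128 (head_dim) * 2 (bfloat16) * 128 (tokens)
_EXPECTED_PAGE_SIZE_BYTES = 2 * 8 * 128 * 2 * 128  # = 524288 bytes = 512 KiB

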
def test_init_kv_cache_without_kv_sharing():
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    vllm_config = get_vllm_config()
    with set_current_vllm_config(vllm_config):
        fwd_context = {
            layer_0:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_0,
            ),
            layer_1:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_1,
            )
        }
        # suppress var not used error
        assert fwd_context is not None
    # Set a high context length to test max context length estimation.
    vllm_config.model_config.max_model_len = 1_000_000
    vllm_ctx = vllm_config.compilation_config.static_forward_context
    model_runner = get_model_runner(vllm_config)
    kv_cache_spec = model_runner.get_kv_cache_spec()
    assert len(kv_cache_spec) == 2
    assert len(model_runner.shared_kv_cache_layers) == 0

    available_memory = 20 * GiB_bytes
    # The page size for each layer's KV cache can be calculated as
    # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
    # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
    num_expected_blocks = 20480  # 20GB / 512KB / 2 (num layers)
    kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
                                           [available_memory])[0]
    assert kv_cache_config.num_blocks == num_expected_blocks
    assert len(kv_cache_config.kv_cache_tensors) == 2
    assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
    assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2

    max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                             5 * GiB_bytes)
    # Without KV sharing, each of the two layers needs its own cache:
    # max_context_len = available_memory / (page_size / block_size) / num_caches
    #                 = 5GB / (512KB / 128) / 2 = 655360
    assert max_context_len == 655360

    # Important: override the tensor sizes to prevent a large memory
    # allocation during the test. This only allocates 2 blocks' worth of
    # memory (2 * 512KB).
    kv_cache_config.num_blocks = 1
    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
        kv_cache_tensor.size = (
            kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)

    model_runner.initialize_kv_cache(kv_cache_config)

    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
    # Check that layer 1's KV cache does NOT share memory with layer 0's.
    assert id(layer_1_kv) != id(layer_0_kv)

    # Check that both layers are in the KV cache group's layer names.
    assert len(kv_cache_config.kv_cache_groups) == 1
    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


def test_init_kv_cache_with_kv_sharing_valid():
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    vllm_config = get_vllm_config()
    with set_current_vllm_config(vllm_config):
        fwd_context = {
            layer_0:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_0,
            ),
            layer_1:
            Attention(
                num_heads=8,
                head_size=128,
                scale=1.0,
                prefix=layer_1,
                kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
            )
        }
        # suppress var not used error
        assert fwd_context is not None
    # Set a high context length to test max context length estimation.
    vllm_config.model_config.max_model_len = 3_000_000
    vllm_ctx = vllm_config.compilation_config.static_forward_context
    model_runner = get_model_runner(vllm_config)
    kv_cache_spec = model_runner.get_kv_cache_spec()
    assert len(kv_cache_spec) == 1
    assert layer_0 in kv_cache_spec
    assert model_runner.shared_kv_cache_layers[layer_1] == layer_0

    available_memory = 20 * GiB_bytes
    # The page size for layer 0's kv_cache_spec is 512KB.
    # With KV sharing, we can allocate (available_memory // page_size) blocks,
    # twice as many as without KV sharing.
    num_expected_blocks = 2 * 20480  # 20GB / 512KB
    kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
                                           [available_memory])[0]
    assert kv_cache_config.num_blocks == num_expected_blocks
    assert len(kv_cache_config.kv_cache_tensors) == 1
    # Each layer now has twice the available memory for KV cache
    # compared to no KV sharing.
    assert kv_cache_config.kv_cache_tensors[0].size == available_memory

    max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                             5 * GiB_bytes)
    # The max context len with KV sharing should be 2x as large as without.
    assert max_context_len == (2 * 655360)

    # Important: override the tensor size to prevent a large memory allocation
    # during the test. This only allocates 1 block's worth of memory (512KB).
    kv_cache_config.num_blocks = 1
    kv_cache_config.kv_cache_tensors[0].size = (
        kv_cache_spec[layer_0].page_size_bytes)

    model_runner.initialize_kv_cache(kv_cache_config)

    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
    # Check that layer 1's KV cache shares memory with layer 0's.
    assert id(layer_1_kv) == id(layer_0_kv)

    # Check that layer 1 was added to the KV cache group's layer names.
    assert len(kv_cache_config.kv_cache_groups) == 1
    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
    vllm_config = get_vllm_config()
    vllm_config.model_config.max_model_len = 32000
    vllm_config.scheduler_config.max_num_seqs = 1200
    model_runner = get_model_runner(vllm_config)

    # Verify that the model runner adjusts num_reqs to avoid SMEM OOM.
    # At most-model-len (2048), the SMEM bound is not binding, so the limit
    # stays at max_num_seqs.
    assert model_runner.num_reqs_most_model_len == 1200
    # At max-model-len (32000):
    # num_page_per_req = 32k // 128
    # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
    assert model_runner.num_reqs_max_model_len == 524


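# Sketch of the SMEM-driven cap asserted above, using the same numbers as the
# comment (a 1024 ** 2 // 2 byte budget for page indices, 4 bytes per index,
# and a padded block size of 128 tokens); the helper name is illustrative.
def _smem_request_cap_sketch(model_len: int) -> int:
    num_page_per_req = model_len // 128
    return 1024**2 // 2 // num_page_per_req // 4


# _smem_request_cap_sketch(32000) == 524, matching num_reqs_max_model_len,
# while _smem_request_cap_sketch(2048) == 8192 exceeds max_num_seqs, so
# num_reqs_most_model_len stays capped at 1200.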