# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
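#
# Tests for the V1 EngineCore: request add/step/abort lifecycle, extra
# sampling parameters, concurrent (pipelined) batches, tensor-parallel
# worker initialization, and request_id type validation.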
import copy
import time
import uuid
from concurrent.futures import Future, ThreadPoolExecutor

import pytest
from transformers import AutoTokenizer

from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput

from ...utils import create_new_process_for_each_test, multi_gpu_test

if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids

def make_request() -> EngineCoreRequest:
    return EngineCoreRequest(
        request_id=str(uuid.uuid4()),
        prompt_token_ids=PROMPT_TOKENS,
        mm_features=None,
        sampling_params=SamplingParams(),
        pooling_params=None,
        eos_token_id=None,
        arrival_time=time.time(),
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
    )

@create_new_process_for_each_test()
def test_engine_core(monkeypatch: pytest.MonkeyPatch):

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        """Setup the EngineCore."""
        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            engine_core = EngineCore(vllm_config=vllm_config,
                                     executor_class=executor_class,
                                     log_stats=True)
        """Test basic request lifecycle."""

        # First request.
        engine_core.add_request(
            *engine_core.preprocess_add_request(make_request()))
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 1

        # Second request.
        engine_core.add_request(
            *engine_core.preprocess_add_request(make_request()))
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 1

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        # Add two requests in a row.
        engine_core.add_request(
            *engine_core.preprocess_add_request(make_request()))
        engine_core.add_request(
            *engine_core.preprocess_add_request(make_request()))
        assert len(engine_core.scheduler.waiting) == 2
        assert len(engine_core.scheduler.running) == 2

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 4

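        # step()[0] maps client index to EngineCoreOutputs; keep stepping
        # until client 0 reports no more outputs.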
        # Loop through until they are all done.
        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
            pass

        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0
        """Test abort cycle."""

        # Basic abort.
        req = make_request()
        request_id = req.request_id

        engine_core.add_request(*engine_core.preprocess_add_request(req))
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0
        assert engine_core.scheduler.has_unfinished_requests()
        assert not engine_core.scheduler.has_finished_requests()

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 1
        assert engine_core.scheduler.has_unfinished_requests()
        assert not engine_core.scheduler.has_finished_requests()

        engine_core.abort_requests([request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0
        assert not engine_core.scheduler.has_unfinished_requests()
        assert engine_core.scheduler.has_finished_requests()

        _ = engine_core.step()
        assert not engine_core.scheduler.has_unfinished_requests()
        assert not engine_core.scheduler.has_finished_requests()

        # Add, step, abort 1 of the 3.
        req0 = make_request()
        req1 = make_request()
        req2 = make_request()

        engine_core.add_request(*engine_core.preprocess_add_request(req0))
        engine_core.add_request(*engine_core.preprocess_add_request(req1))
        assert len(engine_core.scheduler.waiting) == 2
        assert len(engine_core.scheduler.running) == 0

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        engine_core.add_request(*engine_core.preprocess_add_request(req2))
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 2

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 3

        # Abort just one.
        engine_core.abort_requests([req1.request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        # Abort the other requests at the same time.
        engine_core.abort_requests([req2.request_id, req0.request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0

        # Sending duplicate requests with the same request_id: once the
        # first request has finished, the same id can be reused for a new
        # request.
        req0 = make_request()
        req1 = make_request()
        req0.request_id = req1.request_id = "test"
        engine_core.add_request(*engine_core.preprocess_add_request(req0))

        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
            pass

        engine_core.add_request(*engine_core.preprocess_add_request(req1))
        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
            pass

        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0

@create_new_process_for_each_test()
def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
    """
    A basic end-to-end test to verify that the engine functions correctly
    when additional sampling parameters, such as top_p, min_tokens, and
    presence_penalty, are set.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        """Setup the EngineCore."""
        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            engine_core = EngineCore(vllm_config=vllm_config,
                                     executor_class=executor_class,
                                     log_stats=True)
        """Test basic request lifecycle."""
        # First request.
        request: EngineCoreRequest = make_request()
        request.sampling_params = SamplingParams(
            min_tokens=4,
            presence_penalty=1.0,
            frequency_penalty=1.0,
            repetition_penalty=0.1,
            stop_token_ids=[1001, 1002],
        )
        engine_core.add_request(*engine_core.preprocess_add_request(request))

        def _check_engine_state():
            assert len(engine_core.scheduler.waiting) == 1
            assert len(engine_core.scheduler.running) == 0
            # Loop through until they are all done.
            while (outs := engine_core.step()[0].get(0)) and outs.outputs:
                pass
            assert len(engine_core.scheduler.waiting) == 0
            assert len(engine_core.scheduler.running) == 0

        _check_engine_state()

        # Second request.
        request2 = make_request()
        request2.sampling_params = SamplingParams(
            top_p=0.99,
            top_k=50,
        )
        engine_core.add_request(*engine_core.preprocess_add_request(request2))
        _check_engine_state()

@create_new_process_for_each_test()
def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
    """
    Test that the engine can handle multiple concurrent batches.
    """

    def make_request_with_max_tokens(req_id: str,
                                     max_tokens: int) -> EngineCoreRequest:
        request = make_request()
        request.request_id = req_id
        request.sampling_params.max_tokens = max_tokens
        return request

    class DummyExecutor(UniProcExecutor):

        def initialize_from_config(
                self, kv_cache_configs: list[KVCacheConfig]) -> None:
            super().initialize_from_config(kv_cache_configs)

            # Create a thread pool with a single worker
            self.thread_pool = ThreadPoolExecutor(max_workers=1)

        def execute_model(
            self,
            scheduler_output,
            non_block=False,
        ) -> Future[ModelRunnerOutput]:
            """Make execute_model non-blocking."""

            # DummyExecutor is only used for testing the async case.
            assert non_block

            def _execute():
                output = self.collective_rpc("execute_model",
                                             args=(scheduler_output, ))
                # Make a copy because output[0] may be reused
                # by the next batch.
                return copy.deepcopy(output[0])

            # Use the thread pool instead of creating a new thread
            return self.thread_pool.submit(_execute)

        @property
        def max_concurrent_batches(self) -> int:
            return 2

        def shutdown(self):
            if hasattr(self, 'thread_pool'):
                self.thread_pool.shutdown(wait=False)

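    # With max_concurrent_batches == 2, EngineCore keeps a batch queue and is
    # driven via step_with_batch_queue(), so a new batch can be scheduled
    # while the previous one is still running in the DummyExecutor's worker
    # thread.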
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(
            model=MODEL_NAME,
            # To test concurrent batches.
            max_num_seqs=2,
            # Avoid all requests being scheduled at once.
            enable_prefix_caching=False,
            max_num_batched_tokens=10,
            # Reduce startup time.
            enforce_eager=True,
        )
        vllm_config = engine_args.create_engine_config()
        with set_default_torch_num_threads(1):
            engine_core = EngineCore(vllm_config=vllm_config,
                                     log_stats=False,
                                     executor_class=DummyExecutor)
        assert engine_core.batch_queue is not None

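        # Each 12-token prompt exceeds the 10-token batch budget
        # (max_num_batched_tokens=10), so a request's prefill is split
        # across two consecutive batches.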
        # Add two requests in a row. Each request has 12 prompt tokens.
        req0 = make_request_with_max_tokens("0", 5)
        engine_core.add_request(*engine_core.preprocess_add_request(req0))
        req1 = make_request_with_max_tokens("1", 5)
        engine_core.add_request(*engine_core.preprocess_add_request(req1))

        # Schedule Batch 1: (10, req0)
        assert engine_core.step_with_batch_queue()[0] is None
        assert len(engine_core.batch_queue) == 1
        scheduler_output = engine_core.batch_queue[-1][1]
        assert scheduler_output.num_scheduled_tokens["0"] == 10
        # num_computed_tokens should have been updated immediately.
        assert engine_core.scheduler.requests[
            req0.request_id].num_computed_tokens == 10

        # Schedule Batch 2: (2, req0), (8, req1)
        assert engine_core.step_with_batch_queue()[0] == {}
        assert len(engine_core.batch_queue) == 1
        scheduler_output = engine_core.batch_queue[-1][1]
        assert scheduler_output.num_scheduled_tokens["0"] == 2
        assert scheduler_output.num_scheduled_tokens["1"] == 8
        # num_computed_tokens should have been updated immediately.
        assert engine_core.scheduler.requests["0"].num_computed_tokens == 12
        assert engine_core.scheduler.requests["1"].num_computed_tokens == 8

        assert engine_core.scheduler.get_num_unfinished_requests() == 2

        # Finish Batch 1 and schedule Batch 3: (4, req1).
        # Note that req0 cannot be scheduled
        # because it is in the decoding stage now.
        engine_core.step_with_batch_queue()
        assert len(engine_core.batch_queue) == 1
        scheduler_output = engine_core.batch_queue[-1][1]
        assert scheduler_output.num_scheduled_tokens["1"] == 4

        # Finish Batch 2. Get first token of req0.
        # Schedule Batch 4: (1, req0).
        output = engine_core.step_with_batch_queue()[0].get(0)
        assert output is not None
        assert len(output.outputs) == 1
        assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
        scheduler_output = engine_core.batch_queue[-1][1]
        assert scheduler_output.num_scheduled_tokens["0"] == 1

        # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
        output = engine_core.step_with_batch_queue()[0].get(0)
        assert output is not None
        assert len(output.outputs) == 1
        assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
        scheduler_output = engine_core.batch_queue[-1][1]
        assert scheduler_output.num_scheduled_tokens["1"] == 1

        # Loop until req0 is finished.
        req_id = 0
        expected_num_tokens = [
            engine_core.scheduler.requests["0"].num_tokens + 1,
            engine_core.scheduler.requests["1"].num_tokens + 1,
        ]
        while engine_core.scheduler.get_num_unfinished_requests() == 2:
            output = engine_core.step_with_batch_queue()[0]
            # Every step consumes an output.
            assert output is not None
            assert len(output[0].outputs) == 1
            # Request ids are the strings "0" and "1", so convert before the
            # dict lookup; the request may already be finished and removed.
            if str(req_id) in engine_core.scheduler.requests:
                assert engine_core.scheduler.requests[str(
                    req_id)].num_tokens == expected_num_tokens[req_id]
            expected_num_tokens[req_id] += 1
            req_id = (req_id + 1) % 2

@multi_gpu_test(num_gpus=2)
def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
    """
    Test that the engine can initialize its workers properly with
    tensor parallelism.
    """

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        """Setup the EngineCore."""
        engine_args = EngineArgs(
            model=MODEL_NAME,
            tensor_parallel_size=2,
            # Reduce startup time.
            enforce_eager=True,
        )
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            engine_core = EngineCore(vllm_config=vllm_config,
                                     executor_class=executor_class,
                                     log_stats=True)

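        # Every TP worker should now have num_gpu_blocks / num_cpu_blocks set
        # in its cache_config; query both ranks via collective_rpc to confirm.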
        def get_worker_cache_config_field(worker, key: str):
            return getattr(worker.cache_config, key)

        num_gpu_blocks = engine_core.collective_rpc(
            get_worker_cache_config_field, args=("num_gpu_blocks", ))
        num_cpu_blocks = engine_core.collective_rpc(
            get_worker_cache_config_field, args=("num_cpu_blocks", ))
        assert all(x is not None for x in num_gpu_blocks)
        assert all(x is not None for x in num_cpu_blocks)

@create_new_process_for_each_test()
def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
    """Test that engine raises TypeError for non-string request_id."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            engine_core = EngineCore(vllm_config=vllm_config,
                                     executor_class=executor_class,
                                     log_stats=True)

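        # The TypeError is expected to be raised while the request is being
        # added, before it reaches the scheduler, so the engine stays usable.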
        # Test with UUID object (common mistake)
        uuid_request = make_request()
        uuid_request.request_id = uuid.uuid4()  # UUID object instead of string

        with pytest.raises(TypeError,
                           match="request_id must be a string, got.*UUID"):
            engine_core.add_request(
                *engine_core.preprocess_add_request(uuid_request))

        # Test with integer
        int_request = make_request()
        int_request.request_id = 12345

        with pytest.raises(TypeError,
                           match="request_id must be a string, got.*int"):
            engine_core.add_request(
                *engine_core.preprocess_add_request(int_request))

        # Test with None
        none_request = make_request()
        none_request.request_id = None

        with pytest.raises(TypeError,
                           match="request_id must be a string, got.*NoneType"):
            engine_core.add_request(
                *engine_core.preprocess_add_request(none_request))

        # Verify engine is still functional after errors
        valid_request = make_request()
        engine_core.add_request(
            *engine_core.preprocess_add_request(valid_request))
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0