mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 21:37:11 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
368 lines
14 KiB
Python
368 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import inspect
|
|
from collections.abc import Sequence
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm.platforms import current_platform
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
|
from vllm.v1.pool.metadata import PoolingMetadata
|
|
from vllm.v1.sample.logits_processor import LogitsProcessors
|
|
from vllm.v1.sample.metadata import SamplingMetadata
|
|
from vllm.v1.utils import CpuGpuBuffer
|
|
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
|
|
|
# Vocabulary size: random token ids are drawn from [0, VOCAB_SIZE) and
# VOCAB_SIZE is also the pad value for expected prompt-token tensors.
VOCAB_SIZE = 1024
# Exclusive upper bound on the number of output tokens per synthetic request.
NUM_OUTPUT_TOKENS = 20
# Exclusive upper bound on the number of prompt tokens per synthetic request.
MAX_PROMPT_SIZE = 100
# Devices to parametrize over: at most the first two accelerators, named
# "<device_type>:<index>".
CUDA_DEVICES = [
    f"{current_platform.device_type}:{device_index}"
    for device_index in range(min(2, current_platform.device_count()))
]
# NOTE(review): not referenced in the visible part of this file.
MAX_NUM_PROMPT_TOKENS = 64
|
|
|
|
|
|
def _compare_objs(obj1,
                  obj2,
                  skip: Sequence = ("logitsprocs", "batch_update_builder")):
    """Assert that every non-dunder data attribute of two objects matches.

    Comparison rules per attribute type (taken from ``obj1``):

    * ``torch.Tensor``: ``torch.allclose``; two empty tensors are equal and
      an empty tensor never equals a non-empty one.
    * ``np.ndarray``: ``np.allclose``.
    * ``MultiGroupBlockTable``: same number of groups, then each block table
      is compared recursively.
    * ``BlockTable`` / ``SamplingMetadata`` / ``PoolingMetadata``: compared
      recursively.
    * ``CpuGpuBuffer``: both the ``.np`` and ``.gpu`` views must be close.
    * anything else: ``==``.

    Args:
        obj1: Object under test.
        obj2: Reference object exposing the same attributes as ``obj1``.
        skip: Attribute names excluded from the comparison.

    Raises:
        AssertionError: If any compared attribute differs.
    """
    attrs = inspect.getmembers(obj1, lambda a: not inspect.isroutine(a))
    attr_names = {
        name
        for name, _ in attrs
        if not (name.startswith('__') and name.endswith('__'))
    }
    for attr_name in attr_names:
        if attr_name in skip:
            continue

        a = getattr(obj1, attr_name)
        b = getattr(obj2, attr_name)

        is_same = False
        if isinstance(a, torch.Tensor):
            if a.numel() == 0 or b.numel() == 0:
                is_same = (a.numel() == 0 and b.numel() == 0)
            else:
                is_same = torch.allclose(a, b)
        elif isinstance(a, np.ndarray):
            is_same = np.allclose(a, b)
        elif isinstance(a, MultiGroupBlockTable):
            a_tables = list(a.block_tables)
            b_tables = list(b.block_tables)
            # zip() stops at the shorter list; check the group counts first so
            # extra block tables on either side are not silently ignored.
            assert len(a_tables) == len(b_tables), (
                f"Attribute {attr_name} has {len(a_tables)} block tables in"
                f" {obj1} but {len(b_tables)} in {obj2}")
            for a_i, b_i in zip(a_tables, b_tables):
                _compare_objs(a_i, b_i)
            is_same = True
        elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
            _compare_objs(a, b)
            is_same = True  # the recursive assertions above all passed
        elif isinstance(a, CpuGpuBuffer):
            # Checked before the generic ``==`` fallback so buffer contents
            # are always compared explicitly.
            is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
        elif a == b:
            is_same = True
        assert is_same, (f"Attribute {attr_name} is different"
                         f" in {obj1} and {obj2}: {a} != {b}")
|
|
|
|
|
|
def _remove_requests(input_batch: InputBatch, batch_size: int,
                     reqs: list[CachedRequestState]) -> set[str]:
    """Randomly remove a subset of requests from ``input_batch``.

    A draw count is taken from ``[0, batch_size)``, then that many indices
    are drawn from ``[0, batch_size)``. Repeated indices collapse, so fewer
    requests than the draw count may actually be removed.

    Args:
        input_batch: Batch to remove requests from (mutated in place).
        batch_size: Number of leading entries of ``reqs`` eligible for removal.
        reqs: Requests previously added to ``input_batch``.

    Returns:
        The ids of the removed requests.
    """
    draw_count = np.random.randint(0, batch_size)
    picked_indices: set[int] = {
        np.random.randint(0, batch_size) for _ in range(draw_count)
    }

    removed_ids: set[str] = set()
    for picked in picked_indices:
        victim_id = reqs[picked].req_id
        input_batch.remove_request(victim_id)
        removed_ids.add(victim_id)
    return removed_ids
|
|
|
|
|
|
def _construct_expected_sampling_metadata(
    reqs: list[CachedRequestState],
    req_ids_retained: set[str],
    req_id_index_in_input_batch: dict[str, int],
    device: torch.device,
) -> SamplingMetadata:
    """
    Constructs and returns the expected SamplingMetadata for this
    batch.

    Only requests whose ``req_id`` is in ``req_ids_retained`` contribute.
    Each one is written at the row given by ``req_id_index_in_input_batch``,
    so that mapping must place the retained requests at rows
    ``0 .. len(req_ids_retained) - 1``.

    Fixed fields: ``all_greedy=False``, ``all_random=True``, empty
    ``generators``, ``max_num_logprobs=0`` and empty ``LogitsProcessors``.
    ``top_p`` / ``top_k`` are ``None`` when every request uses the neutral
    value (1.0 / 0). Prompt tokens are right-padded with ``VOCAB_SIZE``.

    Args:
        reqs: All requests that were ever added to the batch.
        req_ids_retained: Ids (strings) of requests still in the batch.
        req_id_index_in_input_batch: Row index of each retained request.
        device: Device on which the expected tensors are allocated.
    """
    num_reqs = len(req_ids_retained)
    output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
    prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
    presence_penalties = [0.0 for _ in range(num_reqs)]
    frequency_penalties = [0.0 for _ in range(num_reqs)]
    repetition_penalties = [1.0 for _ in range(num_reqs)]
    top_k = [0 for _ in range(num_reqs)]
    top_p = [0.0 for _ in range(num_reqs)]
    temperature = [0.0 for _ in range(num_reqs)]
    allowed_token_ids_mask = torch.zeros(num_reqs,
                                         VOCAB_SIZE,
                                         dtype=torch.bool,
                                         device=device)
    bad_words_token_ids = {}
    for req in reqs:
        if req.req_id not in req_ids_retained:
            continue
        index_in_input_batch = req_id_index_in_input_batch[req.req_id]
        params = req.sampling_params
        output_token_ids[index_in_input_batch] = req.output_token_ids
        prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
        presence_penalties[index_in_input_batch] = params.presence_penalty
        frequency_penalties[index_in_input_batch] = params.frequency_penalty
        repetition_penalties[index_in_input_batch] = (
            params.repetition_penalty)
        top_k[index_in_input_batch] = params.top_k
        top_p[index_in_input_batch] = params.top_p
        temperature[index_in_input_batch] = params.temperature
        if params.allowed_token_ids:
            allowed_token_ids_mask[index_in_input_batch][
                params.allowed_token_ids] = True
        if params.bad_words_token_ids:
            bad_words_token_ids[
                index_in_input_batch] = params.bad_words_token_ids

    return SamplingMetadata(
        temperature=torch.tensor(temperature, dtype=torch.float,
                                 device=device),
        all_greedy=False,
        all_random=True,
        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
            top_p, dtype=torch.float, device=device),
        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
            top_k, dtype=torch.int, device=device),
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=make_tensor_with_pad(
            prompt_token_ids,
            pad=VOCAB_SIZE,
            device=torch.device(device),
            dtype=torch.int64,
        ),
        frequency_penalties=torch.tensor(frequency_penalties,
                                         dtype=torch.float,
                                         device=device),
        presence_penalties=torch.tensor(presence_penalties,
                                        dtype=torch.float,
                                        device=device),
        repetition_penalties=torch.tensor(repetition_penalties,
                                          dtype=torch.float,
                                          device=device),
        output_token_ids=output_token_ids,
        no_penalties=(all(x == 0 for x in presence_penalties)
                      and all(x == 0 for x in frequency_penalties)
                      and all(x == 1 for x in repetition_penalties)),
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=bad_words_token_ids,
        logitsprocs=LogitsProcessors(),
    )
|
|
|
|
|
|
def _create_sampling_params():
    """Build a SamplingParams with randomized sampling and penalty settings.

    Values are drawn from numpy's global RNG in this order: top_k, top_p,
    presence penalty, repetition penalty, frequency penalty, min_tokens,
    stop-token count and ids, then the logit bias for token 0.
    """
    rng = np.random
    top_k = rng.randint(1, 10)
    top_p = rng.uniform(0.0, 1.0)
    presence_penalty = rng.uniform(-2.0, 2.0)
    repetition_penalty = rng.uniform(0.0, 2.0)
    frequency_penalty = rng.uniform(-2.0, 2.0)
    min_tokens = rng.randint(1, 10)
    stop_count = rng.randint(10)
    stop_token_ids = [rng.randint(0, VOCAB_SIZE) for _ in range(stop_count)]
    bias_for_token_zero = rng.uniform(-3.0, 3.0)

    return SamplingParams(
        top_k=top_k,
        top_p=top_p,
        presence_penalty=presence_penalty,
        repetition_penalty=repetition_penalty,
        frequency_penalty=frequency_penalty,
        min_tokens=min_tokens,
        stop_token_ids=stop_token_ids,
        logit_bias={0: bias_for_token_zero},
    )
|
|
|
|
|
|
def _construct_cached_request_state(req_id_suffix: int):
    """Create a CachedRequestState with random tokens and sampling params.

    Args:
        req_id_suffix: Appended to ``"req_id_"`` to form the request id.

    Returns:
        A CachedRequestState with random prompt/output tokens drawn from
        ``[0, VOCAB_SIZE)``, random SamplingParams, no pooling params, no
        multimodal features, a single empty block-id group, and
        ``num_computed_tokens`` equal to the number of output tokens.
    """

    def _random_tokens(max_len: int) -> list[int]:
        # Length is drawn first, then each token id.
        length = np.random.randint(0, max_len)
        return [np.random.randint(0, VOCAB_SIZE) for _ in range(length)]

    prompt = _random_tokens(MAX_PROMPT_SIZE)
    outputs = _random_tokens(NUM_OUTPUT_TOKENS)
    params = _create_sampling_params()
    return CachedRequestState(
        req_id=f"req_id_{req_id_suffix}",
        prompt_token_ids=prompt,
        sampling_params=params,
        pooling_params=None,
        mm_features=[],
        block_ids=([], ),
        generator=None,
        num_computed_tokens=len(outputs),
        output_token_ids=outputs,
    )
|
|
|
|
|
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
    """
    Tests the logic for managing sampling metadata in the InputBatch.

    This test involves adding a set of requests to the InputBatch,
    followed by removing a subset of them. Afterward, the batch is compacted,
    and the `make_sampling_metadata` method is invoked on the batch. The
    output of `make_sampling_metadata` is then compared against the expected
    results to ensure correctness.

    Note: Ignore logits processor logic, which is tested separately
    """
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=VOCAB_SIZE,
        block_sizes=[1],
    )
    reqs: list[CachedRequestState] = []

    # Add requests; each must land at the next free index.
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert req_index == assigned_req_index
        reqs.append(req)

    # Remove some requests
    req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
    req_ids_retained = {req.req_id for req in reqs} - req_ids_to_remove

    # Compact the input batch
    input_batch.condense()

    # Generate the sampling metadata
    sampling_metadata = input_batch._make_sampling_metadata()

    # Create expected output.
    expected_sampling_metadata = _construct_expected_sampling_metadata(
        reqs,
        req_ids_retained,
        input_batch.req_id_to_index,
        device=torch.device(device))

    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
        # Both None, or both tensors that are allclose.
        return (t1 is None
                and t2 is None) or (t1 is not None and t2 is not None
                                    and torch.allclose(t1, t2))

    # Assert the actual and expected output.
    assert torch.allclose(expected_sampling_metadata.temperature,
                          sampling_metadata.temperature)
    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
    assert torch.allclose(
        expected_sampling_metadata.frequency_penalties,
        sampling_metadata.frequency_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.presence_penalties,
        sampling_metadata.presence_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.repetition_penalties,
        sampling_metadata.repetition_penalties,
    )
    assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
                          sampling_metadata.prompt_token_ids)
    assert (expected_sampling_metadata.output_token_ids ==
            sampling_metadata.output_token_ids)
    assert expected_sampling_metadata.no_penalties == \
           sampling_metadata.no_penalties
    # The mask may be None. A multi-element tensor must not be used in a
    # boolean context (bool(tensor) raises), so test for None explicitly.
    if sampling_metadata.allowed_token_ids_mask is not None:
        assert torch.allclose(
            expected_sampling_metadata.allowed_token_ids_mask,
            sampling_metadata.allowed_token_ids_mask)
    assert expected_sampling_metadata.bad_words_token_ids == \
           sampling_metadata.bad_words_token_ids
|
|
|
|
|
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
                                    swap_list: Sequence[tuple[int, int]]):
    """
    Tests that `InputBatch.swap_states` keeps per-request state consistent.

    Requests are added to an InputBatch, and each index pair in `swap_list`
    is swapped with `swap_states`. A reference InputBatch then receives the
    same requests added directly in the swapped order. After
    `refresh_metadata` on both, the two batches are compared attribute by
    attribute.

    Note: Ignore logits processor logic, which is tested separately
    """

    def _new_input_batch() -> InputBatch:
        # Both batches must be configured identically for the comparison.
        return InputBatch(
            max_num_reqs=batch_size,
            max_model_len=1024,
            max_num_batched_tokens=1024,
            device=torch.device(device),
            pin_memory=is_pin_memory_available(),
            vocab_size=VOCAB_SIZE,
            block_sizes=[1],
        )

    input_batch = _new_input_batch()
    ref_input_batch = _new_input_batch()

    reqs: list[CachedRequestState] = []
    # Add requests; each must land at the next free index.
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert assigned_req_index == req_index
        reqs.append(req)

    # Apply the swaps to the batch and mirror them in a plain list.
    reordered_reqs = reqs.copy()
    for i, j in swap_list:
        reordered_reqs[i], reordered_reqs[j] = \
            reordered_reqs[j], reordered_reqs[i]
        input_batch.swap_states(i, j)

    # Build the reference batch directly in the swapped order.
    for req_index, req in enumerate(reordered_reqs):
        assigned_req_index = ref_input_batch.add_request(req)
        assert assigned_req_index == req_index

    input_batch.refresh_metadata()
    ref_input_batch.refresh_metadata()

    _compare_objs(input_batch, ref_input_batch)
|