mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 20:07:07 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
1226 lines
49 KiB
Python
1226 lines
49 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import random
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from vllm.config.lora import LoRAConfig
|
|
# yapf conflicts with isort for this block
|
|
# yapf: disable
|
|
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
|
|
ColumnParallelLinearWithShardedLoRA,
|
|
LogitsProcessorWithLoRA, LoRAMapping,
|
|
MergedColumnParallelLinearWithLoRA,
|
|
MergedColumnParallelLinearWithShardedLoRA,
|
|
MergedQKVParallelLinearWithLoRA,
|
|
MergedQKVParallelLinearWithShardedLoRA,
|
|
QKVParallelLinearWithLoRA,
|
|
QKVParallelLinearWithShardedLoRA,
|
|
ReplicatedLinearWithLoRA,
|
|
RowParallelLinearWithLoRA,
|
|
RowParallelLinearWithShardedLoRA,
|
|
VocabParallelEmbeddingWithLoRA)
|
|
# yapf: enable
|
|
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
|
|
from vllm.lora.punica_wrapper import get_punica_wrapper
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
|
|
from vllm.model_executor.utils import set_random_seed
|
|
from vllm.platforms import current_platform
|
|
|
|
from .utils import DummyLoRAManager
|
|
|
|
# Per-dtype (rtol, atol) tolerances used when comparing LoRA layer outputs
# against the manually-computed reference results.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Skip the entire module on platforms without a supported Punica backend.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported")

# Use up to two CUDA devices when available (multi-GPU Triton coverage),
# otherwise fall back to CPU.
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])

# prefill stage(True) or decode stage(False)
STAGES = [True, False]

# Number of random seeds each test repeats over.
NUM_RANDOM_SEEDS = 2

# Separate (smaller) seed count for the vocab-parallel embedding tests.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Autouse fixture: clear the cached Triton LoRA pointer tables before
    each test so memory is released between tests (CI OOMs otherwise)."""
    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
                                                _LORA_B_PTR_DICT)
    for ptr_cache in (_LORA_B_PTR_DICT, _LORA_A_PTR_DICT):
        ptr_cache.clear()

    yield
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """
    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.
    """
    if current_platform.is_cuda_alike():
        try:
            callspec = getattr(request.node, "callspec", None)
            params = getattr(callspec, "params", None)
            if params is not None and params.get("stage") is False:
                pytest.skip("Skip test when stage=False")
        except Exception:
            # Best effort: tests without parametrization just run normally.
            pass
    yield
|
|
|
|
|
|
def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> list[Optional[int]]:
    """Create a random slot-index -> LoRA-id mapping.

    Slots are chosen uniformly at random without replacement; LoRA ids
    1..num_loras are assigned to the chosen slots, all other slots are None.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be at least
            num_loras.
        log: Whether to print the resulting mapping.
    """
    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    slots: list[Optional[int]] = [None] * num_slots
    chosen_slots = torch.randperm(num_slots)[:num_loras].tolist()
    for lora_id, chosen_slot in enumerate(chosen_slots, start=1):
        slots[chosen_slot] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots
|
|
|
|
|
|
def populate_loras(
    id_to_index: list[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """Populate the occupied slots of a LoRA layer with random weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: whether to generate an
            embeddings tensor for each LoRA.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        A pair of dicts keyed by lora id: the (possibly packed) lora
        installed in each slot, and the list of its individual subloras.
    """
    # lora id -> the (packed) weights actually installed in the layer.
    lora_dict: dict[int, LoRALayerWeights] = {}

    # lora id -> the individual subloras composing the packed lora.
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue  # free slot

        sublora_len = layer_weights.shape[0] // repeats
        subloras: list[LoRALayerWeights] = []
        for rep in range(repeats):
            manager = DummyLoRAManager(layer_weights.device)
            sublora = manager.init_random_lora(
                module_name=f"fake_{rep}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            # Each sublora owns a contiguous row-slice of the B matrix.
            start = sublora_len * rep
            sublora.lora_b = sublora.lora_b[start:start + sublora_len, :]
            sublora.optimize()
            subloras.append(sublora)

        if repeats > 1:
            lora = PackedLoRALayerWeights.pack(subloras)
        else:
            lora = subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict
|
|
|
|
|
|
def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda"
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
        device: device on which the input tensors are allocated.

    Returns:
        A tuple (inputs, index_mapping, prompt_mapping): the input tensors,
        a per-token lora-id mapping (one id per row of each input), and a
        per-input lora-id mapping.
    """

    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device=device))
        else:
            # Scale/shift uniform [0, 1) samples into [low, high).
            # (Previously this computed `rand * high + low`, which yields
            # values in [low, low + high) and violates the documented
            # input_range contract.)
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device) *
                (high - low) + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping
|
|
|
|
|
|
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True iff the wrapper is exactly the Punica wrapper class
    expected for the current platform (GPU or CPU)."""
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
        expected_cls = PunicaWrapperGPU
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
        expected_cls = PunicaWrapperCPU
    else:
        return False
    return type(punica_wrapper) is expected_cls
|
|
|
|
|
|
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """VocabParallelEmbeddingWithLoRA must match the reference computation
    (base embedding output plus the per-prompt LoRA delta computed with
    F.embedding and matmul), and must reduce to the plain base embedding
    after all lora slots are reset."""
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        # Base layer with random weights; rows at or beyond vocab_size
        # are zeroed out.
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))

        # Reference: base embedding plus (F.embedding(x, A^T) @ B^T) for
        # each prompt's assigned lora.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a.T,
            )
            result += (after_a @ lora.lora_b.T)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        # With every slot cleared, lora id 0 (no lora) must reproduce the
        # base embedding output within tolerance.
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
|
|
|
|
|
|
@torch.inference_mode()
# @pytest.mark.skip(
#     reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    """Like test_embeddings, but also exercises the extended (extra-vocab)
    token range: some input ids are forced past vocab_size so lookups hit
    each LoRA's embeddings_tensor rather than the base weight matrix."""

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        # The expanded layer reserves lora_extra_vocab_size rows per slot
        # beyond the base vocabulary.
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )

        lora_embedding.set_mapping(punica_wrapper)
        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        # Keep a copy with ids expressed relative to a single lora's extra
        # vocab range; `inputs` gets ids in the globally-expanded range.
        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested.
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + (
                (embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1

        # Install every lora's embeddings tensor into the expanded base
        # weight so the reference lookup below can resolve extended ids.
        expanded_embedding.weight[vocab_size:vocab_size +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results: list[torch.Tensor] = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a.T,
            )
            result += (after_a @ lora.lora_b.T)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        original_inputs = deepcopy(inputs)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
|
|
|
|
|
|
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    """Check LogitsProcessorWithLoRA against a manual LoRA reference.

    Phase 1: populate random LoRAs (sharing one embeddings tensor) and
    compare the LoRA logits against the plain LogitsProcessor applied to an
    lm_head whose extra-vocab rows were filled with the embeddings tensor,
    plus the explicit ``x @ A^T @ B^T * scaling`` correction.
    Phase 2: reset every LoRA slot and verify the wrapper now matches the
    unmodified lm_head exactly (up to dtype tolerance).
    """

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        # Fresh lm_head + logits processor pair; the extra-vocab columns of
        # the head are zeroed so only LoRA can produce logits there.
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,  # * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        # Snapshot the head before splicing in the embeddings tensor, so the
        # reset-phase below can run against the untouched weights.
        original_lm_head = deepcopy(linear)

        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        # Temporarily widen the reference processor so it emits logits for
        # the extra-vocab rows as well.
        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            # Rows past the populated embeddings are masked off, matching
            # the LoRA layer's behavior for unused extra-vocab slots.
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # BUGFIX: the phase-1 comparison was missing entirely — lora_result
        # and expected_result were computed and then silently overwritten by
        # the reset-phase below, so the main property was never checked.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        # With all slots reset, only the original-vocab logits are
        # meaningful; slice them out before comparing to the base processor.
        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
|
|
|
|
|
|
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """Check ReplicatedLinearWithLoRA against a manual LoRA reference.

    For each random seed: wrap a fresh ReplicatedLinear with LoRA, populate
    random adapters into a subset of slots, and verify the wrapper's output
    equals the base layer plus an explicit ``x @ A^T @ B^T * scaling``
    correction per input.  Then reset every slot and verify the wrapper
    degenerates to the plain base layer.
    """

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )

    def create_random_linear_replicated_layer():
        # 4096x4096 fp16 base layer with random weights, wrapped by the
        # single-slice replicated LoRA layer.

        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        # Replicated linear is a single-slice layer.
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):

            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            # Reference LoRA correction applied on top of the base output.
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        # lora_id 0 means "no adapter"; after the reset the wrapper must
        # behave exactly like the plain base layer.
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
|
|
|
|
|
|
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device, stage) -> None:
    """Check row/column-parallel LoRA linear layers against a reference.

    Same structure as test_linear_replicated, but the base layer is a
    RowParallelLinear or ColumnParallelLinear (chosen by ``orientation``)
    and the LoRA wrapper is either the regular or the fully-sharded
    variant (chosen by ``fully_shard``).
    """

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    def create_random_linear_parallel_layer():
        # Pick the parallel base layer and the matching LoRA wrapper
        # according to the test parametrization.
        if orientation == "row":
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)
        # Both orientations are single-slice layers.
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            # Reference LoRA correction applied on top of the base output.
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        # lora_id 0 means "no adapter"; after the reset the wrapper must
        # behave exactly like the plain base layer.
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
|
|
|
|
|
|
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage) -> None:
    """Check packed column-parallel LoRA layers against a manual reference.

    ``repeats`` selects the packing: 2 -> merged column-parallel (two gate
    slices), 3 -> merged QKV (three slices), 1 -> plain QKV.  Each packed
    sub-LoRA is applied to its own output slice; the expected result is the
    base layer output with each slice corrected by
    ``x @ A^T @ B^T * scaling`` of the corresponding sub-LoRA.  Afterwards
    all slots are reset and the wrapper must match the base layer.
    """

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    def create_column_parallel_packed_layer():
        # Pick the packed base layer + LoRA wrapper per the parametrization.
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedQKVParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLoRA(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = QKVParallelLinearWithLoRA(
                linear
            ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear)

        # Minimal stand-in for the model config consulted by
        # create_lora_weights (QKV head bookkeeping).
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == n_slices)

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # BUGFIX: the inner loop variable was named ``i``, shadowing the
            # outer random-seed loop variable; renamed to avoid the trap.
            for sublora_idx, sublora in enumerate(subloras):
                # Each sub-LoRA corrects its own contiguous output slice.
                out_start = sublora.lora_b.shape[0] * sublora_idx
                out_end = sublora.lora_b.shape[0] * (sublora_idx + 1)
                result[:, out_start:out_end] += (
                    input_ @ sublora.lora_a.T @ sublora.lora_b.T *
                    sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds: with every slot
        # cleared the wrapper must match the plain base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
|
|
|
|
|
|
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Validate VocabParallelEmbedding shard indices across TP ranks.

    Builds the embedding once per simulated rank (patching the TP rank and
    world size) with a random vocab/added-vocab split, then checks that the
    per-rank shard ranges are contiguous, non-overlapping, sum to the padded
    vocab size, and that the sharded->full reindex mapping reconstructs the
    full vocabulary (with -1 in the padding positions).
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    # End index of the previous rank's shard; used to check contiguity.
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    # Accumulated element counts across all ranks.
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    # Token ids laid out in sharded order (with -1 for padding slots).
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        # Patch the TP rank/world size only while the layer computes its
        # shard indices during construction.
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert (shard_indices.added_vocab_start_index ==
                last_added_vocab_end_index)

        # Ensure that we are not exceeding the vocab size
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        # Ensure that the ranges are not overlapping
        all_org_tokens.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        all_added_tokens.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))

        # Sharded layout per rank: org tokens, org padding, added tokens,
        # added padding (padding slots marked with -1).
        token_ids.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                 shard_indices.num_org_elements))
        token_ids.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                 shard_indices.num_added_elements))

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    # Only the tp_size == 1 case may legitimately have no reindex mapping.
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.tensor(list(range(0, vocab_size)))
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)
|
|
|
|
|
|
def test_get_masked_input_and_mask():
    """Exercise get_masked_input_and_mask across shard layouts and padding.

    The fixed input covers an original vocab of [0, 8) and an added vocab of
    [8, 12).  For each simulated TP layout (1, 2 and 4 shards, with and
    without per-shard padding) the masked output must map in-shard tokens to
    local indices (added-vocab tokens offset past the padded org range) and
    zero out everything owned by other shards.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def _masked(org_start, org_end, added_start, added_end, padding):
        # Run one shard configuration and discard the boolean mask.
        masked_x, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding)
        return masked_x

    # base tp 1 case, no padding
    assert torch.equal(_masked(0, 8, 8, 12, 0), x)

    # tp 2 case, no padding
    assert torch.equal(
        _masked(0, 4, 8, 10, 0),
        torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
    assert torch.equal(
        _masked(4, 8, 10, 12, 0),
        torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

    # tp 4 case, no padding
    assert torch.equal(
        _masked(0, 2, 8, 9, 0),
        torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
    assert torch.equal(
        _masked(2, 4, 9, 10, 0),
        torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
    assert torch.equal(
        _masked(4, 6, 10, 11, 0),
        torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
    assert torch.equal(
        _masked(6, 8, 11, 12, 0),
        torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

    # base tp 1 case, with padding
    assert torch.equal(
        _masked(0, 8, 8, 12, 2),
        torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

    # tp 2 case, with padding
    assert torch.equal(
        _masked(0, 4, 8, 10, 2),
        torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
    assert torch.equal(
        _masked(4, 8, 10, 12, 2),
        torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

    # tp 4 case, with padding
    assert torch.equal(
        _masked(0, 2, 8, 9, 2),
        torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
    assert torch.equal(
        _masked(2, 4, 9, 10, 2),
        torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
    assert torch.equal(
        _masked(4, 6, 10, 11, 2),
        torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
    assert torch.equal(
        _masked(6, 8, 11, 12, 2),
        torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
|