mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 14:37:02 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
249 lines
9.6 KiB
Python
249 lines
9.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import tempfile
|
|
from collections.abc import Iterable
|
|
from contextlib import contextmanager
|
|
from functools import partial
|
|
from typing import Any, Union
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch.nn as nn
|
|
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
|
UserMessage)
|
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|
from PIL import Image
|
|
|
|
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
|
from vllm.distributed import (cleanup_dist_env_and_memory,
|
|
init_distributed_environment,
|
|
initialize_model_parallel)
|
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
|
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
|
|
supports_multimodal)
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|
InputProcessingContext)
|
|
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
|
from vllm.utils import is_list_of
|
|
|
|
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
|
from ...utils import dummy_hf_overrides
|
|
|
|
# Architectures excluded from the schema test, mapped to the reason that is
# reported in the pytest skip message.
ARCH_TO_SKIP = dict(MolmoForCausalLM="incompatible requirements")

# Architectures for which every "extras" checkpoint listed in the example
# model registry is tested in addition to the default checkpoint.
ARCH_NEEDS_EXTRAS = [
    "InternVLChatModel",
    "Idefics3ForConditionalGeneration",
    "LlavaForConditionalGeneration",
    "MiniCPMV",
    "PaliGemmaForConditionalGeneration",
]

# Individual checkpoints (repo ids) excluded from the schema test, mapped to
# the reason that is reported in the pytest skip message.
REPO_ID_TO_SKIP = {
    "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test",
}
|
|
|
|
# Per-modality containers of dummy data accepted by ``resize_mm_data``.
# Images: a list of PIL images.
ImageInput = list[Image.Image]
# Videos: lists of PIL frames, lists of numpy arrays, or lists of
# (array, metadata dict) pairs whose metadata is carried along unchanged.
VideoInput = Union[ImageInput, list[np.ndarray],
                   list[tuple[np.ndarray, dict[str, Any]]]]
# Audio: (array, int) pairs -- presumably (waveform, sampling rate); confirm
# against the dummy-input builders.
AudioInput = list[tuple[np.ndarray, int]]
|
|
|
|
|
|
def _resize_data(_data: Union[Image.Image, np.ndarray],
                 size_factor: float) -> Union[Image.Image, np.ndarray]:
    """Shrink a single multimodal item by ``size_factor``.

    Supported items:

    * ``PIL.Image.Image`` (image): width and height are scaled and the image
      is resized.
    * list of ``PIL.Image.Image`` (video frames): the frame count, width and
      height are scaled; frames beyond the scaled count are dropped.
    * ``np.ndarray`` with ``ndim >= 4`` (video): the last four axes are read
      as ``(T, H, W, C)`` and ``T``, ``H`` and ``W`` are truncated.
    * 1-D ``np.ndarray`` (audio): the samples are truncated.

    Every scaled dimension is clamped to at least 1, so a small input or a
    small factor never collapses an item to zero size.

    Args:
        _data: The item to shrink.
        size_factor: Scale factor; must be at most 1.

    Returns:
        The shrunk item, of the same kind as ``_data``. Numpy inputs are
        returned as slices of the original array.

    Raises:
        AssertionError: If ``size_factor`` exceeds 1, or if ``_data`` is not
            one of the supported kinds.
    """
    assert size_factor <= 1, "Size factor must be less than or equal to 1"

    def scale(dim: int) -> int:
        # Clamp to 1 so no dimension collapses to zero; an empty image,
        # clip or waveform is not a meaningful test input.
        return max(int(dim * size_factor), 1)

    if isinstance(_data, np.ndarray):
        # Video input with numpy arrays: trailing axes are (T, H, W, C).
        if _data.ndim >= 4:
            T, H, W = (scale(dim) for dim in _data.shape[-4:-1])
            return _data[..., :T, :H, :W, :]
        # Audio input: 1-D array of samples.
        if _data.ndim == 1:
            return _data[:scale(len(_data))]
    elif isinstance(_data, Image.Image):
        # Image input
        return _data.resize((scale(_data.width), scale(_data.height)))
    elif is_list_of(_data, Image.Image):
        # Video input with PIL Images. ``is_list_of`` also accepts an empty
        # list, which would otherwise fail when reading the first frame.
        if not _data:
            return []
        first = _data[0]
        T = scale(len(_data))
        size = (scale(first.width), scale(first.height))
        return [frame.resize(size) for frame in _data[:T]]

    detail = (f" with ndim={_data.ndim}"
              if isinstance(_data, np.ndarray) else "")
    raise AssertionError(
        f"Unsupported item for resizing: {type(_data).__name__}{detail}")
|
|
|
|
|
|
def resize_mm_data(
    data: Union[ImageInput, VideoInput, AudioInput],
    size_factors: tuple[float, ...],
) -> Union[ImageInput, VideoInput, AudioInput]:
    """Shrink every item of one modality's dummy data.

    The ``i``-th item is shrunk by ``size_factors[i]``. When there are more
    items than factors, the factors are reused cyclically, so every item is
    kept. The previous ``zip`` truncation silently dropped the surplus
    items, which presumably no longer matched the dummy prompt built for
    the full item count.

    Args:
        data: Items of a single modality: PIL images, numpy arrays, lists of
            frames, or ``(item, metadata)`` tuples whose metadata is passed
            through untouched.
        size_factors: Per-item scale factors, each at most 1. If empty, the
            result is empty.

    Returns:
        A list of the shrunk items, in the same order as ``data``.

    Raises:
        ValueError: If the items are not of a supported container type.
    """
    from itertools import cycle

    factors = cycle(size_factors)
    if is_list_of(data, (Image.Image, np.ndarray, list)):
        return [_resize_data(d, s) for d, s in zip(data, factors)]
    elif is_list_of(data, tuple):
        return [(_resize_data(d, s), meta)
                for (d, meta), s in zip(data, factors)]

    raise ValueError("Unsupported multimodal data type.")
|
|
|
|
|
|
def create_batched_mm_kwargs(
    model_cls: type[SupportsMultiModal],
    model_config: ModelConfig,
    processor: BaseMultiModalProcessor,
    size_factors: tuple[float, ...] = (1.0, 0.5, 0.25),
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Produce batched dummy multimodal kwargs for ``model_cls``.

    The processor's dummy-input builder is asked for data for every
    supported modality. The requested count is 3 for modalities without a
    limit, otherwise the limit itself. Each item is shrunk by the matching
    entry of ``size_factors``, the result is run through
    ``processor.apply``, and the processed items are grouped by modality.

    Args:
        model_cls: Multimodal model class; its ``merge_by_field_config``
            attribute is forwarded to ``group_mm_kwargs_by_modality``.
        model_config: Supplies ``max_model_len`` (dummy sequence length) and
            ``tokenizer_mode``.
        processor: Provides the dummy inputs and performs the processing.
        size_factors: Per-item scale factors passed to ``resize_mm_data``.

    Returns:
        The result of ``group_mm_kwargs_by_modality``, annotated as an
        iterable of ``(modality, int, batched kwargs)`` tuples.
    """
    info = processor.info
    dummy_builder = processor.dummy_inputs

    supported_limits = info.get_supported_mm_limits()
    # Modalities without an explicit limit are requested three times.
    counts = {
        modality: (3 if limit is None else limit)
        for modality, limit in supported_limits.items()
    }

    dummy = dummy_builder.get_dummy_processor_inputs(
        seq_len=model_config.max_model_len,
        mm_counts=counts,
    )
    shrunk_data = {
        modality: resize_mm_data(items, size_factors)
        for modality, items in dummy.mm_data.items()
    }

    if model_config.tokenizer_mode == "mistral":
        # Mistral tokenizers yield token ids directly rather than a text
        # prompt, so the prompt is obtained by encoding a chat request.
        content = [TextChunk(text="")]
        content += [
            ImageChunk(image=img) for img in shrunk_data.get("image", [])
        ]
        chat_request = ChatCompletionRequest(
            messages=[UserMessage(content=content)])
        mistral_tokenizer = info.get_tokenizer().mistral
        prompt = mistral_tokenizer.encode_chat_completion(chat_request).tokens
    else:
        prompt = dummy.prompt

    outputs = processor.apply(
        prompt=prompt,
        mm_data=shrunk_data,
        hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs,
        tokenization_kwargs=dummy.tokenization_kwargs,
    )
    mm_items = outputs["mm_kwargs"].require_data()

    # Flatten the items, following the order of the supported modalities.
    flat_items = []
    for modality in supported_limits:
        flat_items.extend(mm_items[modality])

    return group_mm_kwargs_by_modality(
        flat_items,
        merge_by_field_config=model_cls.merge_by_field_config,
    )
|
|
|
|
|
|
@contextmanager
def initialize_dummy_model(
    model_cls: type[nn.Module],
    model_config: ModelConfig,
):
    """Construct ``model_cls`` on a single-rank distributed setup.

    Initializes a world-size-1 distributed environment (NCCL backend,
    ``file://`` rendezvous through a fresh temporary file), sets up model
    parallelism with tensor-parallel size 1, and instantiates the model. The
    model is built and yielded while the current vLLM config and the
    config's dtype (as torch default dtype) are in effect.

    Teardown runs in ``finally``: the model reference is dropped, the
    distributed environment is cleaned up and the temporary file is
    removed. This happens even if initialization, model construction or the
    ``with`` body raises, so a failing parametrized case does not leave
    distributed state behind for later tests.

    Args:
        model_cls: Model class; it is called as ``model_cls(vllm_config=...)``.
        model_config: Configuration wrapped into a ``VllmConfig``.

    Yields:
        The constructed model instance.
    """
    import os

    fd, temp_file = tempfile.mkstemp()
    # Only the path is needed (for the ``file://`` init method); close the
    # OS-level descriptor returned by mkstemp so it is not leaked.
    os.close(fd)
    model = None
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(tensor_model_parallel_size=1)

        vllm_config = VllmConfig(model_config=model_config)
        with set_current_vllm_config(vllm_config=vllm_config):
            with set_default_torch_dtype(model_config.dtype):
                model = model_cls(vllm_config=vllm_config)
                yield model
    finally:
        # Drop this frame's reference before cleanup so the memory held by
        # the model can actually be released.
        del model
        cleanup_dist_env_and_memory()
        try:
            os.remove(temp_file)
        except FileNotFoundError:
            # Already gone (e.g. removed during distributed teardown).
            pass
|
|
|
|
|
|
def get_model_id_to_test(
        model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
    """Expand architectures into ``(architecture, repo_id)`` test parameters.

    Every architecture contributes its default checkpoint. Architectures
    listed in ``ARCH_NEEDS_EXTRAS`` that have extra checkpoints registered
    also contribute each of those, following the default.

    Args:
        model_arch_list: Architecture names known to ``HF_EXAMPLE_MODELS``.

    Returns:
        The ``(architecture, repo_id)`` pairs, in input order.
    """
    params: list[tuple[str, str]] = []
    for arch in model_arch_list:
        info = HF_EXAMPLE_MODELS.get_hf_info(arch)
        repo_ids = [info.default]
        if info.extras and arch in ARCH_NEEDS_EXTRAS:
            repo_ids.extend(info.extras.values())
        params.extend((arch, repo_id) for repo_id in repo_ids)
    return params
|
|
|
|
|
|
def _find_input_parse_methods(model_cls: type) -> list[str]:
    """Return names of ``model_cls`` attributes whose return annotation
    mentions ``Input``.

    This is a naming heuristic for locating the model's multimodal
    input-parsing methods; only the textual form of the ``return``
    annotation is inspected.
    """
    method_names: list[str] = []
    for attr_name in dir(model_cls):
        attr = getattr(model_cls, attr_name)
        if not hasattr(attr, "__annotations__"):
            continue
        return_type = attr.__annotations__.get("return", None)
        if return_type is not None and "Input" in str(return_type):
            method_names.append(attr_name)
    return method_names


@pytest.mark.parametrize(
    "model_arch, model_id",
    get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, model_id: str):
    """Feed batched dummy multimodal kwargs to a model's parsing methods.

    For one ``(architecture, checkpoint)`` pair this builds a model with
    dummy HF overrides and generates batched dummy multimodal kwargs at
    several sizes. Every input-parsing method found on the model class is
    then called with them. The test fails if any call raises (presumably a
    tensor-schema validation error). It is skipped when the architecture or
    checkpoint is listed as skipped, is unavailable, or exposes no such
    method.
    """
    if model_arch in ARCH_TO_SKIP:
        pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
    if model_id in REPO_ID_TO_SKIP:
        pytest.skip(f"Skipping {model_id} due to {REPO_ID_TO_SKIP[model_id]}")

    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip",
                                          check_max_version=False)

    hf_overrides_fn = partial(dummy_hf_overrides,
                              model_arch=model_arch,
                              exist_overrides=model_info.hf_overrides)

    model_config = ModelConfig(
        model_id,
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=hf_overrides_fn,
        skip_tokenizer_init=model_info.skip_tokenizer_init,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype,
    )

    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    assert supports_multimodal(model_cls)

    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

    inputs_parse_methods = _find_input_parse_methods(model_cls)
    if not inputs_parse_methods:
        pytest.skip(f"{model_arch} does not support tensor schema validation.")

    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    processing_info = factories.info(ctx)
    supported_mm_limits = processing_info.get_supported_mm_limits()
    # Unlimited modalities are allowed three items per prompt, matching the
    # counts requested by ``create_batched_mm_kwargs``.
    limit_mm_per_prompt = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }
    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
    processor = factories.build_processor(ctx, cache=None)

    with initialize_dummy_model(model_cls, model_config) as model:
        for modality, _, mm_kwargs in create_batched_mm_kwargs(
                model_cls, model_config, processor):
            for method_name in inputs_parse_methods:
                print(f"Testing `{method_name}` with modality={modality} "
                      f"and mm_kwargs={list(mm_kwargs.keys())}")
                getattr(model, method_name)(modality=modality, **mm_kwargs)
|