mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 05:07:02 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. <shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder <daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
262 lines
11 KiB
Markdown
262 lines
11 KiB
Markdown
# Pooling Models
|
|
|
|
vLLM also supports pooling models, such as embedding, classification and reward models.
|
|
|
|
In vLLM, pooling models implement the [VllmModelForPooling][vllm.model_executor.models.VllmModelForPooling] interface.
|
|
These models use a [Pooler][vllm.model_executor.layers.pooler.Pooler] to extract the final hidden states of the input
|
|
before returning them.
|
|
|
|
!!! note
|
|
We currently support pooling models primarily as a matter of convenience. This is not guaranteed to have any performance improvement over using HF Transformers / Sentence Transformers directly.
|
|
|
|
We are now planning to optimize pooling models in vLLM. Please comment on <gh-issue:21796> if you have any suggestions!
|
|
|
|
## Configuration
|
|
|
|
### Model Runner
|
|
|
|
Run a model in pooling mode via the option `--runner pooling`.
|
|
|
|
!!! tip
|
|
There is no need to set this option in the vast majority of cases as vLLM can automatically
|
|
detect the model runner to use via `--runner auto`.
|
|
|
|
### Model Conversion
|
|
|
|
vLLM can adapt models for various pooling tasks via the option `--convert <type>`.
|
|
|
|
If `--runner pooling` has been set (manually or automatically) but the model does not implement the
|
|
[VllmModelForPooling][vllm.model_executor.models.VllmModelForPooling] interface,
|
|
vLLM will attempt to automatically convert the model according to the architecture names
|
|
shown in the table below.
|
|
|
|
| Architecture | `--convert` | Supported pooling tasks |
|
|
|-------------------------------------------------|-------------|-------------------------------|
|
|
| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `encode`, `embed` |
|
|
| `*For*Classification`, `*ClassificationModel` | `classify` | `encode`, `classify`, `score` |
|
|
| `*ForRewardModeling`, `*RewardModel` | `reward` | `encode` |
|
|
|
|
!!! tip
|
|
You can explicitly set `--convert <type>` to specify how to convert the model.
|
|
|
|
### Pooling Tasks
|
|
|
|
Each pooling model in vLLM supports one or more of these tasks according to
|
|
[Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
|
|
enabling the corresponding APIs:
|
|
|
|
| Task | APIs |
|
|
|------------|--------------------------------------|
|
|
| `encode` | `LLM.reward(...)` |
|
|
| `embed` | `LLM.embed(...)`, `LLM.score(...)`\* |
|
|
| `classify` | `LLM.classify(...)` |
|
|
| `score` | `LLM.score(...)` |
|
|
|
|
\* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task.
|
|
|
|
### Pooler Configuration
|
|
|
|
#### Predefined models
|
|
|
|
If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`,
|
|
you can override some of its attributes via the `--pooler-config` option.
|
|
|
|
#### Converted models
|
|
|
|
If the model has been converted via `--convert` (see above),
|
|
the pooler assigned to each task has the following attributes by default:
|
|
|
|
| Task | Pooling Type | Normalization | Softmax |
|
|
|------------|--------------|---------------|---------|
|
|
| `reward` | `ALL` | ❌ | ❌ |
|
|
| `embed` | `LAST` | ✅︎ | ❌ |
|
|
| `classify` | `LAST` | ❌ | ✅︎ |
|
|
|
|
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
|
|
its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
|
|
|
|
You can further customize this via the `--pooler-config` option,
|
|
which takes priority over both the model's and Sentence Transformers's defaults.
|
|
|
|
## Offline Inference
|
|
|
|
The [LLM][vllm.LLM] class provides various methods for offline inference.
|
|
See [configuration](../api/README.md#configuration) for a list of options when initializing the model.
|
|
|
|
### `LLM.embed`
|
|
|
|
The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt.
|
|
It is primarily designed for embedding models.
|
|
|
|
```python
|
|
from vllm import LLM
|
|
|
|
llm = LLM(model="intfloat/e5-small", runner="pooling")
|
|
(output,) = llm.embed("Hello, my name is")
|
|
|
|
embeds = output.outputs.embedding
|
|
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
|
|
```
|
|
|
|
A code example can be found here: <gh-file:examples/offline_inference/basic/embed.py>
|
|
|
|
### `LLM.classify`
|
|
|
|
The [classify][vllm.LLM.classify] method outputs a probability vector for each prompt.
|
|
It is primarily designed for classification models.
|
|
|
|
```python
|
|
from vllm import LLM
|
|
|
|
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
|
|
(output,) = llm.classify("Hello, my name is")
|
|
|
|
probs = output.outputs.probs
|
|
print(f"Class Probabilities: {probs!r} (size={len(probs)})")
|
|
```
|
|
|
|
A code example can be found here: <gh-file:examples/offline_inference/basic/classify.py>
|
|
|
|
### `LLM.score`
|
|
|
|
The [score][vllm.LLM.score] method outputs similarity scores between sentence pairs.
|
|
It is designed for embedding models and cross-encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
|
|
|
|
!!! note
|
|
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
|
|
To handle RAG at a higher level, you should use integration frameworks such as [LangChain](https://github.com/langchain-ai/langchain).
|
|
|
|
```python
|
|
from vllm import LLM
|
|
|
|
llm = LLM(model="BAAI/bge-reranker-v2-m3", runner="pooling")
|
|
(output,) = llm.score("What is the capital of France?",
|
|
"The capital of Brazil is Brasilia.")
|
|
|
|
score = output.outputs.score
|
|
print(f"Score: {score}")
|
|
```
|
|
|
|
A code example can be found here: <gh-file:examples/offline_inference/basic/score.py>
|
|
|
|
### `LLM.reward`
|
|
|
|
The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.
|
|
It returns the extracted hidden states directly.
|
|
|
|
```python
|
|
from vllm import LLM
|
|
|
|
llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
|
|
(output,) = llm.reward("Hello, my name is")
|
|
|
|
data = output.outputs.data
|
|
print(f"Data: {data!r}")
|
|
```
|
|
|
|
A code example can be found here: <gh-file:examples/offline_inference/basic/reward.py>
|
|
|
|
### `LLM.encode`
|
|
|
|
The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
|
|
It returns the extracted hidden states directly.
|
|
|
|
!!! note
|
|
Please use one of the more specific methods or set the task directly when using `LLM.encode`:
|
|
|
|
- For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
|
|
- For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
|
|
- For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
|
|
- For similarity scores, use `LLM.score(...)`.
|
|
|
|
```python
|
|
from vllm import LLM
|
|
|
|
llm = LLM(model="intfloat/e5-small", runner="pooling")
|
|
(output,) = llm.encode("Hello, my name is", pooling_task="embed")
|
|
|
|
data = output.outputs.data
|
|
print(f"Data: {data!r}")
|
|
```
|
|
|
|
## Online Serving
|
|
|
|
Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:
|
|
|
|
- [Pooling API][pooling-api] is similar to `LLM.encode`, being applicable to all types of pooling models.
|
|
- [Embeddings API][embeddings-api] is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models.
|
|
- [Classification API][classification-api] is similar to `LLM.classify` and is applicable to sequence classification models.
|
|
- [Score API][score-api] is similar to `LLM.score` for cross-encoder models.
|
|
|
|
## Matryoshka Embeddings
|
|
|
|
[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows user to trade off between performance and cost.
|
|
|
|
!!! warning
|
|
Not all embedding models are trained using Matryoshka Representation Learning. To avoid misuse of the `dimensions` parameter, vLLM returns an error for requests that attempt to change the output dimension of models that do not support Matryoshka Embeddings.
|
|
|
|
For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model will result in the following error.
|
|
|
|
```json
|
|
{"object":"error","message":"Model \"BAAI/bge-m3\" does not support matryoshka representation, changing output dimensions will lead to poor results.","type":"BadRequestError","param":null,"code":400}
|
|
```
|
|
|
|
### Manually enable Matryoshka Embeddings
|
|
|
|
There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json,` it is allowed to change the output to arbitrary dimensions. Using `matryoshka_dimensions` can control the allowed output dimensions.
|
|
|
|
For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf-overrides '{"is_matryoshka": true}'`, `--hf-overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online).
|
|
|
|
Here is an example to serve a model with Matryoshka Embeddings enabled.
|
|
|
|
```text
|
|
vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}'
|
|
```
|
|
|
|
### Offline Inference
|
|
|
|
You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter in [PoolingParams][vllm.PoolingParams].
|
|
|
|
```python
|
|
from vllm import LLM, PoolingParams
|
|
|
|
llm = LLM(model="jinaai/jina-embeddings-v3",
|
|
runner="pooling",
|
|
trust_remote_code=True)
|
|
outputs = llm.embed(["Follow the white rabbit."],
|
|
pooling_params=PoolingParams(dimensions=32))
|
|
print(outputs[0].outputs)
|
|
```
|
|
|
|
A code example can be found here: <gh-file:examples/offline_inference/pooling/embed_matryoshka_fy.py>
|
|
|
|
### Online Inference
|
|
|
|
Use the following command to start vllm server.
|
|
|
|
```text
|
|
vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
|
|
```
|
|
|
|
You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter.
|
|
|
|
```text
|
|
curl http://127.0.0.1:8000/v1/embeddings \
|
|
-H 'accept: application/json' \
|
|
-H 'Content-Type: application/json' \
|
|
-d '{
|
|
"input": "Follow the white rabbit.",
|
|
"model": "jinaai/jina-embeddings-v3",
|
|
"encoding_format": "float",
|
|
"dimensions": 32
|
|
}'
|
|
```
|
|
|
|
Expected output:
|
|
|
|
```json
|
|
{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
|
|
```
|
|
|
|
An OpenAI client example can be found here: <gh-file:examples/online_serving/pooling/openai_embedding_matryoshka_fy.py>
|