mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 20:04:26 +08:00
[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
16fb668b61
commit
84cf78acee
@ -65,3 +65,9 @@ def test_pooling_params(llm: LLM):
|
|||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
softmax(wo_activation), w_activation, atol=1e-2
|
softmax(wo_activation), w_activation, atol=1e-2
|
||||||
), "w_activation should be close to activation(wo_activation)."
|
), "w_activation should be close to activation(wo_activation)."
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_api(llm: LLM):
|
||||||
|
err_msg = "pooling_task must be one of.+"
|
||||||
|
with pytest.raises(ValueError, match=err_msg):
|
||||||
|
llm.encode(prompts, use_tqdm=False)
|
||||||
|
|||||||
@ -211,3 +211,18 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
|
|||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2
|
F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2
|
||||||
), "w_activation should be close to activation(wo_activation)."
|
), "w_activation should be close to activation(wo_activation)."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
|
def test_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||||
|
# pooling api uses ALL pooling, which does not support chunked prefill.
|
||||||
|
response = requests.post(
|
||||||
|
server.url_for("pooling"),
|
||||||
|
json={
|
||||||
|
"model": model_name,
|
||||||
|
"input": "test",
|
||||||
|
"encoding_format": "float"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.json()["error"]["type"] == "BadRequestError"
|
||||||
|
|||||||
@ -177,9 +177,12 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
**vllm_extra_kwargs) as vllm_model:
|
**vllm_extra_kwargs) as vllm_model:
|
||||||
|
|
||||||
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
|
|
||||||
if model_info.architecture:
|
if model_info.architecture:
|
||||||
assert (model_info.architecture
|
assert model_info.architecture in model_config.architectures
|
||||||
in vllm_model.llm.llm_engine.model_config.architectures)
|
assert (model_config._model_info.default_pooling_type ==
|
||||||
|
model_info.default_pooling_type)
|
||||||
|
|
||||||
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
||||||
MTEB_EMBED_TASKS)
|
MTEB_EMBED_TASKS)
|
||||||
@ -286,7 +289,12 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
**vllm_extra_kwargs) as vllm_model:
|
**vllm_extra_kwargs) as vllm_model:
|
||||||
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
|
|
||||||
|
if model_info.architecture:
|
||||||
|
assert (model_info.architecture in model_config.architectures)
|
||||||
assert model_config.hf_config.num_labels == 1
|
assert model_config.hf_config.num_labels == 1
|
||||||
|
assert (model_config._model_info.default_pooling_type ==
|
||||||
|
model_info.default_pooling_type)
|
||||||
|
|
||||||
vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
|
vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
|
||||||
tasks=MTEB_RERANK_TASKS,
|
tasks=MTEB_RERANK_TASKS,
|
||||||
|
|||||||
@ -0,0 +1,93 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
from tests.models.language.pooling.embed_utils import (
|
||||||
|
run_embedding_correctness_test)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
["jason9693/Qwen2.5-1.5B-apeach"],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
def test_classify_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
example_prompts = example_prompts * 2
|
||||||
|
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_model_len=512,
|
||||||
|
dtype=dtype,
|
||||||
|
enable_prefix_caching=True) as vllm_model:
|
||||||
|
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||||
|
assert cache_config.enable_prefix_caching
|
||||||
|
vllm_outputs = vllm_model.classify(example_prompts)
|
||||||
|
|
||||||
|
with hf_runner(model,
|
||||||
|
dtype=dtype,
|
||||||
|
auto_cls=AutoModelForSequenceClassification) as hf_model:
|
||||||
|
hf_outputs = hf_model.classify(example_prompts)
|
||||||
|
|
||||||
|
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||||
|
hf_output = torch.tensor(hf_output)
|
||||||
|
vllm_output = torch.tensor(vllm_output)
|
||||||
|
|
||||||
|
assert torch.allclose(hf_output, vllm_output,
|
||||||
|
1e-3 if dtype == "float" else 1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
["Qwen/Qwen3-Embedding-0.6B"],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
def test_embed_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
):
|
||||||
|
example_prompts = [str(s).strip() for s in example_prompts] * 2
|
||||||
|
|
||||||
|
with vllm_runner(
|
||||||
|
model,
|
||||||
|
runner="pooling",
|
||||||
|
max_model_len=None,
|
||||||
|
enable_prefix_caching=True,
|
||||||
|
) as vllm_model:
|
||||||
|
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||||
|
assert cache_config.enable_prefix_caching
|
||||||
|
vllm_outputs = vllm_model.embed(example_prompts)
|
||||||
|
|
||||||
|
with hf_runner(
|
||||||
|
model,
|
||||||
|
is_sentence_transformer=True,
|
||||||
|
) as hf_model:
|
||||||
|
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"intfloat/e5-small",
|
||||||
|
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", # is_causal == False
|
||||||
|
"papluca/xlm-roberta-base-language-detection",
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||||
|
dtype: str) -> None:
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_model_len=512,
|
||||||
|
dtype=dtype,
|
||||||
|
enable_prefix_caching=True) as vllm_model:
|
||||||
|
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||||
|
assert not cache_config.enable_prefix_caching
|
||||||
@ -2,73 +2,78 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ...utils import EmbedModelInfo, RerankModelInfo
|
from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
|
||||||
|
EmbedModelInfo, LASTPoolingEmbedModelInfo,
|
||||||
|
RerankModelInfo)
|
||||||
from .embed_utils import correctness_test_embed_models
|
from .embed_utils import correctness_test_embed_models
|
||||||
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
########## BertModel
|
########## BertModel
|
||||||
EmbedModelInfo("BAAI/bge-base-en",
|
CLSPoolingEmbedModelInfo("BAAI/bge-base-en",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("BAAI/bge-base-zh",
|
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-small-en",
|
CLSPoolingEmbedModelInfo("BAAI/bge-small-en",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-small-zh",
|
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-large-en",
|
CLSPoolingEmbedModelInfo("BAAI/bge-large-en",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-large-zh",
|
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
|
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-base-en-v1.5",
|
CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-base-zh-v1.5",
|
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-small-en-v1.5",
|
CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-small-zh-v1.5",
|
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-large-en-v1.5",
|
CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("BAAI/bge-large-zh-v1.5",
|
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
########## XLMRobertaModel
|
########## XLMRobertaModel
|
||||||
EmbedModelInfo("BAAI/bge-m3",
|
CLSPoolingEmbedModelInfo("BAAI/bge-m3",
|
||||||
architecture="XLMRobertaModel",
|
architecture="XLMRobertaModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
########## Qwen2Model
|
########## Qwen2Model
|
||||||
EmbedModelInfo("BAAI/bge-code-v1",
|
LASTPoolingEmbedModelInfo("BAAI/bge-code-v1",
|
||||||
architecture="Qwen2Model",
|
architecture="Qwen2Model",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
########## XLMRobertaForSequenceClassification
|
########## XLMRobertaForSequenceClassification
|
||||||
RerankModelInfo("BAAI/bge-reranker-base",
|
CLSPoolingRerankModelInfo(
|
||||||
architecture="XLMRobertaForSequenceClassification",
|
"BAAI/bge-reranker-base",
|
||||||
enable_test=True),
|
architecture="XLMRobertaForSequenceClassification",
|
||||||
RerankModelInfo("BAAI/bge-reranker-large",
|
enable_test=True),
|
||||||
architecture="XLMRobertaForSequenceClassification",
|
CLSPoolingRerankModelInfo(
|
||||||
enable_test=False),
|
"BAAI/bge-reranker-large",
|
||||||
RerankModelInfo("BAAI/bge-reranker-v2-m3",
|
architecture="XLMRobertaForSequenceClassification",
|
||||||
architecture="XLMRobertaForSequenceClassification",
|
enable_test=False),
|
||||||
enable_test=False)
|
CLSPoolingRerankModelInfo(
|
||||||
|
"BAAI/bge-reranker-v2-m3",
|
||||||
|
architecture="XLMRobertaForSequenceClassification",
|
||||||
|
enable_test=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,12 +8,12 @@ import torch
|
|||||||
|
|
||||||
from tests.conftest import HfRunner
|
from tests.conftest import HfRunner
|
||||||
|
|
||||||
from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
|
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||||
mteb_test_rerank_models)
|
from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
RerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
||||||
architecture="GemmaForSequenceClassification"),
|
architecture="GemmaForSequenceClassification"),
|
||||||
]
|
]
|
||||||
|
|
||||||
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
|
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
|
||||||
|
|||||||
@ -2,13 +2,15 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo,
|
||||||
|
RerankModelInfo)
|
||||||
|
from .mteb_utils import mteb_test_rerank_models
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
||||||
architecture="BertForSequenceClassification"),
|
architecture="BertForSequenceClassification"),
|
||||||
RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
|
LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
|
||||||
architecture="Qwen3ForSequenceClassification")
|
architecture="Qwen3ForSequenceClassification")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,57 +4,58 @@ from typing import Any
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ...utils import check_transformers_version
|
from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
|
||||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
LASTPoolingEmbedModelInfo, check_transformers_version)
|
||||||
|
from .embed_utils import correctness_test_embed_models
|
||||||
from .mteb_utils import mteb_test_embed_models
|
from .mteb_utils import mteb_test_embed_models
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
########## BertModel
|
########## BertModel
|
||||||
EmbedModelInfo("thenlper/gte-large",
|
CLSPoolingEmbedModelInfo("thenlper/gte-large",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("thenlper/gte-base",
|
CLSPoolingEmbedModelInfo("thenlper/gte-base",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("thenlper/gte-small",
|
CLSPoolingEmbedModelInfo("thenlper/gte-small",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("thenlper/gte-large-zh",
|
CLSPoolingEmbedModelInfo("thenlper/gte-large-zh",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("thenlper/gte-base-zh",
|
CLSPoolingEmbedModelInfo("thenlper/gte-base-zh",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("thenlper/gte-small-zh",
|
CLSPoolingEmbedModelInfo("thenlper/gte-small-zh",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
########### NewModel
|
########### NewModel
|
||||||
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
|
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
|
||||||
architecture="GteNewModel",
|
architecture="GteNewModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
|
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
|
||||||
architecture="GteNewModel",
|
architecture="GteNewModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
|
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
|
||||||
architecture="GteNewModel",
|
architecture="GteNewModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
########### Qwen2ForCausalLM
|
########### Qwen2ForCausalLM
|
||||||
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||||
architecture="Qwen2ForCausalLM",
|
architecture="Qwen2ForCausalLM",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
########## ModernBertModel
|
########## ModernBertModel
|
||||||
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
|
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
|
||||||
architecture="ModernBertModel",
|
architecture="ModernBertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
########## Qwen3ForCausalLM
|
########## Qwen3ForCausalLM
|
||||||
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
|
LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
|
||||||
architecture="Qwen3ForCausalLM",
|
architecture="Qwen3ForCausalLM",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
|
LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B",
|
||||||
architecture="Qwen3ForCausalLM",
|
architecture="Qwen3ForCausalLM",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,34 +2,34 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ...utils import EmbedModelInfo
|
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
|
||||||
from .embed_utils import correctness_test_embed_models
|
from .embed_utils import correctness_test_embed_models
|
||||||
from .mteb_utils import mteb_test_embed_models
|
from .mteb_utils import mteb_test_embed_models
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
########## BertModel
|
########## BertModel
|
||||||
EmbedModelInfo("intfloat/e5-small",
|
CLSPoolingEmbedModelInfo("intfloat/e5-small",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("intfloat/e5-base",
|
CLSPoolingEmbedModelInfo("intfloat/e5-base",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("intfloat/e5-large",
|
CLSPoolingEmbedModelInfo("intfloat/e5-large",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("intfloat/multilingual-e5-small",
|
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small",
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
########## XLMRobertaModel
|
########## XLMRobertaModel
|
||||||
EmbedModelInfo("intfloat/multilingual-e5-base",
|
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base",
|
||||||
architecture="XLMRobertaModel",
|
architecture="XLMRobertaModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("intfloat/multilingual-e5-large",
|
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large",
|
||||||
architecture="XLMRobertaModel",
|
architecture="XLMRobertaModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("intfloat/multilingual-e5-large-instruct",
|
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct",
|
||||||
architecture="XLMRobertaModel",
|
architecture="XLMRobertaModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,20 +6,22 @@ import pytest
|
|||||||
|
|
||||||
from vllm import PoolingParams
|
from vllm import PoolingParams
|
||||||
|
|
||||||
from ...utils import EmbedModelInfo, RerankModelInfo
|
from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
|
||||||
|
EmbedModelInfo, RerankModelInfo)
|
||||||
from .embed_utils import (check_embeddings_close,
|
from .embed_utils import (check_embeddings_close,
|
||||||
correctness_test_embed_models, matryoshka_fy)
|
correctness_test_embed_models, matryoshka_fy)
|
||||||
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
||||||
|
|
||||||
EMBEDDING_MODELS = [
|
EMBEDDING_MODELS = [
|
||||||
EmbedModelInfo("jinaai/jina-embeddings-v3",
|
CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3",
|
||||||
architecture="XLMRobertaModel",
|
architecture="XLMRobertaModel",
|
||||||
is_matryoshka=True)
|
is_matryoshka=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual",
|
CLSPoolingRerankModelInfo(
|
||||||
architecture="XLMRobertaForSequenceClassification")
|
"jinaai/jina-reranker-v2-base-multilingual",
|
||||||
|
architecture="XLMRobertaForSequenceClassification")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,15 +7,16 @@ import torch
|
|||||||
|
|
||||||
from tests.conftest import HfRunner
|
from tests.conftest import HfRunner
|
||||||
|
|
||||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||||
|
from .mteb_utils import mteb_test_rerank_models
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
|
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
|
||||||
architecture="Qwen2ForSequenceClassification",
|
architecture="Qwen2ForSequenceClassification",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
|
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
|
||||||
architecture="Qwen2ForSequenceClassification",
|
architecture="Qwen2ForSequenceClassification",
|
||||||
enable_test=False)
|
enable_test=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,22 +3,23 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
|
||||||
|
from .embed_utils import correctness_test_embed_models
|
||||||
from .mteb_utils import mteb_test_embed_models
|
from .mteb_utils import mteb_test_embed_models
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
|
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1",
|
||||||
architecture="NomicBertModel",
|
architecture="NomicBertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
|
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
|
||||||
architecture="NomicBertModel",
|
architecture="NomicBertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("nomic-ai/CodeRankEmbed",
|
CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed",
|
||||||
architecture="NomicBertModel",
|
architecture="NomicBertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
|
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
|
||||||
architecture="NomicBertModel",
|
architecture="NomicBertModel",
|
||||||
enable_test=True)
|
enable_test=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,15 +8,16 @@ import torch
|
|||||||
from tests.conftest import HfRunner
|
from tests.conftest import HfRunner
|
||||||
from tests.utils import multi_gpu_test
|
from tests.utils import multi_gpu_test
|
||||||
|
|
||||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
|
||||||
|
from .mteb_utils import mteb_test_rerank_models
|
||||||
|
|
||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
RerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
||||||
architecture="Qwen3ForSequenceClassification",
|
architecture="Qwen3ForSequenceClassification",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
RerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
||||||
architecture="Qwen3ForSequenceClassification",
|
architecture="Qwen3ForSequenceClassification",
|
||||||
enable_test=False)
|
enable_test=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,42 +3,43 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
|
||||||
|
from .embed_utils import correctness_test_embed_models
|
||||||
from .mteb_utils import mteb_test_embed_models
|
from .mteb_utils import mteb_test_embed_models
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
|
||||||
is_matryoshka=False,
|
is_matryoshka=False,
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
|
||||||
is_matryoshka=False,
|
is_matryoshka=False,
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
|
||||||
is_matryoshka=False,
|
is_matryoshka=False,
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
|
||||||
is_matryoshka=False,
|
is_matryoshka=False,
|
||||||
architecture="NomicBertModel",
|
architecture="NomicBertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
|
||||||
is_matryoshka=False,
|
is_matryoshka=False,
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=False),
|
enable_test=False),
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
|
||||||
is_matryoshka=True,
|
is_matryoshka=True,
|
||||||
architecture="BertModel",
|
architecture="BertModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
|
||||||
is_matryoshka=True,
|
is_matryoshka=True,
|
||||||
architecture="XLMRobertaModel",
|
architecture="XLMRobertaModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
||||||
is_matryoshka=True,
|
is_matryoshka=True,
|
||||||
architecture="GteModel",
|
architecture="GteModel",
|
||||||
enable_test=True),
|
enable_test=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -345,16 +345,34 @@ class EmbedModelInfo(NamedTuple):
|
|||||||
matryoshka_dimensions: Optional[list[int]] = None
|
matryoshka_dimensions: Optional[list[int]] = None
|
||||||
architecture: str = ""
|
architecture: str = ""
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
|
default_pooling_type: str = ""
|
||||||
enable_test: bool = True
|
enable_test: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
|
||||||
|
default_pooling_type: str = "CLS"
|
||||||
|
|
||||||
|
|
||||||
|
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
|
||||||
|
default_pooling_type: str = "LAST"
|
||||||
|
|
||||||
|
|
||||||
class RerankModelInfo(NamedTuple):
|
class RerankModelInfo(NamedTuple):
|
||||||
name: str
|
name: str
|
||||||
architecture: str = ""
|
architecture: str = ""
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
|
default_pooling_type: str = ""
|
||||||
enable_test: bool = True
|
enable_test: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class CLSPoolingRerankModelInfo(RerankModelInfo):
|
||||||
|
default_pooling_type: str = "CLS"
|
||||||
|
|
||||||
|
|
||||||
|
class LASTPoolingRerankModelInfo(RerankModelInfo):
|
||||||
|
default_pooling_type: str = "LAST"
|
||||||
|
|
||||||
|
|
||||||
def dummy_hf_overrides(
|
def dummy_hf_overrides(
|
||||||
hf_config: PretrainedConfig,
|
hf_config: PretrainedConfig,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@ -227,6 +227,20 @@ def test_get_pooling_config_from_args():
|
|||||||
assert asdict(pooling_config) == asdict(override_pooler_config)
|
assert asdict(pooling_config) == asdict(override_pooler_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("model_id", "default_pooling_type", "pooling_type"),
|
||||||
|
[
|
||||||
|
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
|
||||||
|
("intfloat/e5-small", "CLS", "MEAN"), # BertModel
|
||||||
|
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
|
||||||
|
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP") # step reward
|
||||||
|
])
|
||||||
|
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
|
||||||
|
model_config = ModelConfig(model_id)
|
||||||
|
assert model_config._model_info.default_pooling_type == default_pooling_type
|
||||||
|
assert model_config.pooler_config.pooling_type == pooling_type
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||||
reason="Xformers backend is not supported on ROCm.")
|
reason="Xformers backend is not supported on ROCm.")
|
||||||
def test_get_bert_tokenization_sentence_transformer_config():
|
def test_get_bert_tokenization_sentence_transformer_config():
|
||||||
|
|||||||
@ -871,6 +871,10 @@ class ModelConfig:
|
|||||||
if getattr(pooler_config, k) is None:
|
if getattr(pooler_config, k) is None:
|
||||||
setattr(pooler_config, k, v)
|
setattr(pooler_config, k, v)
|
||||||
|
|
||||||
|
default_pooling_type = self._model_info.default_pooling_type
|
||||||
|
if pooler_config.pooling_type is None:
|
||||||
|
pooler_config.pooling_type = default_pooling_type
|
||||||
|
|
||||||
return pooler_config
|
return pooler_config
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@ -3844,6 +3848,10 @@ class VllmConfig:
|
|||||||
disable_chunked_prefill_reasons.append(
|
disable_chunked_prefill_reasons.append(
|
||||||
"Only \"last\" pooling supports chunked "
|
"Only \"last\" pooling supports chunked "
|
||||||
"prefill and prefix caching; disabling both.")
|
"prefill and prefix caching; disabling both.")
|
||||||
|
elif not getattr(self.model_config.hf_config, "is_causal", True):
|
||||||
|
disable_chunked_prefill_reasons.append(
|
||||||
|
"Only models using causal attention supports chunked "
|
||||||
|
"prefill and prefix caching; disabling both.")
|
||||||
|
|
||||||
if disable_chunked_prefill_reasons:
|
if disable_chunked_prefill_reasons:
|
||||||
for reason in disable_chunked_prefill_reasons:
|
for reason in disable_chunked_prefill_reasons:
|
||||||
|
|||||||
@ -1600,11 +1600,10 @@ class EngineArgs:
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
pooling_type = model_config.pooler_config.pooling_type
|
pooling_type = model_config.pooler_config.pooling_type
|
||||||
|
is_causal = getattr(model_config.hf_config, "is_causal", True)
|
||||||
# TODO: when encoder models are supported we'll have to
|
incremental_prefill_supported = (pooling_type is not None
|
||||||
# check for causal attention here.
|
and pooling_type.lower() == "last"
|
||||||
incremental_prefill_supported = (pooling_type is not None and
|
and is_causal)
|
||||||
pooling_type.lower() == "last")
|
|
||||||
|
|
||||||
action = "Enabling" if \
|
action = "Enabling" if \
|
||||||
incremental_prefill_supported else "Disabling"
|
incremental_prefill_supported else "Disabling"
|
||||||
|
|||||||
@ -1100,6 +1100,10 @@ class LLM:
|
|||||||
"Try passing `--runner pooling` to use the model as a "
|
"Try passing `--runner pooling` to use the model as a "
|
||||||
"pooling model.")
|
"pooling model.")
|
||||||
|
|
||||||
|
if pooling_task not in self.supported_tasks:
|
||||||
|
raise ValueError(
|
||||||
|
f"pooling_task must be one of {self.supported_tasks}.")
|
||||||
|
|
||||||
if prompt_token_ids is not None:
|
if prompt_token_ids is not None:
|
||||||
parsed_prompts = self._convert_v1_inputs(
|
parsed_prompts = self._convert_v1_inputs(
|
||||||
prompts=cast(Optional[Union[str, list[str]]], prompts),
|
prompts=cast(Optional[Union[str, list[str]]], prompts),
|
||||||
|
|||||||
@ -44,15 +44,14 @@ class ResolvedPoolingConfig:
|
|||||||
task: PoolingTask
|
task: PoolingTask
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config_with_defaults(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
task: PoolingTask,
|
task: PoolingTask,
|
||||||
pooler_config: PoolerConfig,
|
pooler_config: PoolerConfig,
|
||||||
pooling_type: PoolingType,
|
|
||||||
) -> "ResolvedPoolingConfig":
|
) -> "ResolvedPoolingConfig":
|
||||||
|
assert pooler_config.pooling_type is not None
|
||||||
return cls(task=task,
|
return cls(task=task,
|
||||||
pooling_type=PoolingType[pooler_config.pooling_type]
|
pooling_type=PoolingType[pooler_config.pooling_type])
|
||||||
if pooler_config.pooling_type is not None else pooling_type)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -68,32 +67,20 @@ class Pooler(nn.Module, ABC):
|
|||||||
"""The interface required for all poolers used in pooling models in vLLM."""
|
"""The interface required for all poolers used in pooling models in vLLM."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def for_encode(
|
def for_encode(pooler_config: PoolerConfig):
|
||||||
pooler_config: PoolerConfig,
|
if pooler_config.pooling_type == "STEP":
|
||||||
*,
|
|
||||||
default_pooling_type: PoolingType = PoolingType.ALL,
|
|
||||||
):
|
|
||||||
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
|
||||||
task="encode",
|
|
||||||
pooler_config=pooler_config,
|
|
||||||
pooling_type=default_pooling_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
if resolved_config.pooling_type == PoolingType.STEP:
|
|
||||||
return StepPooler()
|
return StepPooler()
|
||||||
|
|
||||||
|
resolved_config = ResolvedPoolingConfig(task="encode",
|
||||||
|
pooling_type=PoolingType.ALL)
|
||||||
|
|
||||||
return SimplePooler.from_config(resolved_config)
|
return SimplePooler.from_config(resolved_config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def for_embed(
|
def for_embed(pooler_config: PoolerConfig):
|
||||||
pooler_config: PoolerConfig,
|
resolved_config = ResolvedPoolingConfig.from_config(
|
||||||
*,
|
|
||||||
default_pooling_type: PoolingType = PoolingType.LAST,
|
|
||||||
):
|
|
||||||
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
|
||||||
task="embed",
|
task="embed",
|
||||||
pooler_config=pooler_config,
|
pooler_config=pooler_config,
|
||||||
pooling_type=default_pooling_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return SimplePooler.from_config(resolved_config)
|
return SimplePooler.from_config(resolved_config)
|
||||||
@ -102,13 +89,10 @@ class Pooler(nn.Module, ABC):
|
|||||||
def for_classify(
|
def for_classify(
|
||||||
pooler_config: PoolerConfig,
|
pooler_config: PoolerConfig,
|
||||||
classifier: Optional[ClassifierFn],
|
classifier: Optional[ClassifierFn],
|
||||||
*,
|
|
||||||
default_pooling_type: PoolingType = PoolingType.LAST,
|
|
||||||
):
|
):
|
||||||
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
resolved_config = ResolvedPoolingConfig.from_config(
|
||||||
task="classify",
|
task="classify",
|
||||||
pooler_config=pooler_config,
|
pooler_config=pooler_config,
|
||||||
pooling_type=default_pooling_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
|
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
|
||||||
|
|||||||
@ -182,8 +182,8 @@ def as_seq_cls_model(cls: _T) -> _T:
|
|||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
pooling_type_str = pooler_config.pooling_type
|
pooling_type_str = pooler_config.pooling_type
|
||||||
pooling_type = (PoolingType.LAST if pooling_type_str is None else
|
assert pooling_type_str is not None
|
||||||
PoolingType[pooling_type_str])
|
pooling_type = PoolingType[pooling_type_str]
|
||||||
|
|
||||||
self.pooler = DispatchPooler({
|
self.pooler = DispatchPooler({
|
||||||
"encode":
|
"encode":
|
||||||
|
|||||||
@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
from .interfaces import (SupportsCrossEncoding, SupportsQuant,
|
||||||
|
default_pooling_type)
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
@ -327,6 +328,7 @@ class BertOutput(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class BertModel(nn.Module, SupportsQuant):
|
class BertModel(nn.Module, SupportsQuant):
|
||||||
|
|
||||||
is_pooling_model = True
|
is_pooling_model = True
|
||||||
@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("ALL")
|
||||||
class BertPoolingModel(BertModel):
|
class BertPoolingModel(BertModel):
|
||||||
|
|
||||||
is_pooling_model = True
|
is_pooling_model = True
|
||||||
@ -431,6 +434,7 @@ class BertPoolingModel(BertModel):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class BertEmbeddingModel(nn.Module, SupportsQuant):
|
class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||||
"""A model that uses Bert to provide embedding functionalities.
|
"""A model that uses Bert to provide embedding functionalities.
|
||||||
|
|
||||||
@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
|||||||
|
|
||||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||||
return DispatchPooler({
|
return DispatchPooler({
|
||||||
"encode":
|
"encode": Pooler.for_encode(pooler_config),
|
||||||
Pooler.for_encode(pooler_config),
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
"embed":
|
|
||||||
Pooler.for_embed(
|
|
||||||
pooler_config,
|
|
||||||
default_pooling_type=PoolingType.CLS,
|
|
||||||
),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
|
|||||||
return token_type_ids
|
return token_type_ids
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||||
SupportsQuant):
|
SupportsQuant):
|
||||||
"""A model that uses Bert to provide embedding functionalities.
|
"""A model that uses Bert to provide embedding functionalities.
|
||||||
|
|||||||
@ -27,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import SupportsQuant
|
from vllm.model_executor.models.interfaces import (SupportsQuant,
|
||||||
|
default_pooling_type)
|
||||||
from vllm.model_executor.models.utils import WeightsMapper
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -401,6 +402,7 @@ class BertWithRopeEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class BertWithRope(nn.Module, SupportsQuant):
|
class BertWithRope(nn.Module, SupportsQuant):
|
||||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||||
|
|
||||||
|
|||||||
@ -641,6 +641,20 @@ def supports_cross_encoding(
|
|||||||
return is_pooling_model(model) and _supports_cross_encoding(model)
|
return is_pooling_model(model) and _supports_cross_encoding(model)
|
||||||
|
|
||||||
|
|
||||||
|
def default_pooling_type(pooling_type: str) -> object:
|
||||||
|
"""Set default_pooling_type decorator. """
|
||||||
|
|
||||||
|
def func(model: object):
|
||||||
|
model.default_pooling_type = pooling_type
|
||||||
|
return model
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_pooling_type(model: Union[type[object], object]) -> str:
|
||||||
|
return getattr(model, "default_pooling_type", "LAST")
|
||||||
|
|
||||||
|
|
||||||
class SupportsQuant:
|
class SupportsQuant:
|
||||||
"""The interface required for all models that support quantization."""
|
"""The interface required for all models that support quantization."""
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
@ -401,6 +401,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("ALL")
|
||||||
class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||||
|
|
||||||
is_pooling_model = True
|
is_pooling_model = True
|
||||||
|
|||||||
@ -22,8 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|||||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateShapeCalculator)
|
MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
PoolingType)
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
@ -604,6 +603,5 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
|||||||
Pooler.for_classify(
|
Pooler.for_classify(
|
||||||
pooler_config,
|
pooler_config,
|
||||||
classifier=self.score,
|
classifier=self.score,
|
||||||
default_pooling_type=PoolingType.LAST,
|
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
|||||||
@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
|
||||||
|
default_pooling_type)
|
||||||
from .utils import WeightsMapper, maybe_prefix
|
from .utils import WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
@ -201,6 +202,7 @@ class ModernBertEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class ModernBertModel(nn.Module):
|
class ModernBertModel(nn.Module):
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={"layers.": "encoder_layer.layers."})
|
orig_to_new_prefix={"layers.": "encoder_layer.layers."})
|
||||||
@ -264,7 +266,6 @@ class ModernBertPooler(Pooler):
|
|||||||
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
|
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
|
||||||
config.classifier_bias)
|
config.classifier_bias)
|
||||||
self.pooling_type = config.classifier_pooling
|
|
||||||
self.act = nn.GELU()
|
self.act = nn.GELU()
|
||||||
self.norm = nn.LayerNorm(config.hidden_size,
|
self.norm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.norm_eps,
|
eps=config.norm_eps,
|
||||||
@ -294,6 +295,7 @@ class ModernBertPooler(Pooler):
|
|||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||||
SupportsCrossEncoding):
|
SupportsCrossEncoding):
|
||||||
|
|
||||||
|
|||||||
@ -15,11 +15,10 @@ from torch import nn
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
PoolingType)
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
|
||||||
from .qwen2 import Qwen2Model
|
from .qwen2 import Qwen2Model
|
||||||
from .utils import AutoWeightsLoader, maybe_prefix
|
from .utils import AutoWeightsLoader, maybe_prefix
|
||||||
|
|
||||||
@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
return loader.load_weights(weights)
|
return loader.load_weights(weights)
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("ALL")
|
||||||
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
|||||||
{"encode": Pooler.for_encode(pooler_config)}, )
|
{"encode": Pooler.for_encode(pooler_config)}, )
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("STEP")
|
||||||
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
|||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.pooler = DispatchPooler({
|
self.pooler = DispatchPooler(
|
||||||
"encode":
|
{"encode": Pooler.for_encode(pooler_config)})
|
||||||
Pooler.for_encode(
|
|
||||||
pooler_config,
|
|
||||||
default_pooling_type=PoolingType.STEP,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|||||||
@ -25,8 +25,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.transformers_utils.dynamic_module import (
|
from vllm.transformers_utils.dynamic_module import (
|
||||||
try_get_class_from_dynamic_module)
|
try_get_class_from_dynamic_module)
|
||||||
|
|
||||||
from .interfaces import (has_inner_state, has_noops, is_attention_free,
|
from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
|
||||||
is_hybrid, supports_cross_encoding,
|
is_attention_free, is_hybrid, supports_cross_encoding,
|
||||||
supports_multimodal, supports_multimodal_raw_input,
|
supports_multimodal, supports_multimodal_raw_input,
|
||||||
supports_pp, supports_transcription, supports_v0_only)
|
supports_pp, supports_transcription, supports_v0_only)
|
||||||
from .interfaces_base import is_pooling_model, is_text_generation_model
|
from .interfaces_base import is_pooling_model, is_text_generation_model
|
||||||
@ -305,6 +305,7 @@ class _ModelInfo:
|
|||||||
architecture: str
|
architecture: str
|
||||||
is_text_generation_model: bool
|
is_text_generation_model: bool
|
||||||
is_pooling_model: bool
|
is_pooling_model: bool
|
||||||
|
default_pooling_type: str
|
||||||
supports_cross_encoding: bool
|
supports_cross_encoding: bool
|
||||||
supports_multimodal: bool
|
supports_multimodal: bool
|
||||||
supports_multimodal_raw_input: bool
|
supports_multimodal_raw_input: bool
|
||||||
@ -323,6 +324,7 @@ class _ModelInfo:
|
|||||||
architecture=model.__name__,
|
architecture=model.__name__,
|
||||||
is_text_generation_model=is_text_generation_model(model),
|
is_text_generation_model=is_text_generation_model(model),
|
||||||
is_pooling_model=is_pooling_model(model),
|
is_pooling_model=is_pooling_model(model),
|
||||||
|
default_pooling_type=get_default_pooling_type(model),
|
||||||
supports_cross_encoding=supports_cross_encoding(model),
|
supports_cross_encoding=supports_cross_encoding(model),
|
||||||
supports_multimodal=supports_multimodal(model),
|
supports_multimodal=supports_multimodal(model),
|
||||||
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
|
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
||||||
from .interfaces import SupportsCrossEncoding
|
from .interfaces import SupportsCrossEncoding, default_pooling_type
|
||||||
|
|
||||||
|
|
||||||
class RobertaEmbedding(nn.Module):
|
class RobertaEmbedding(nn.Module):
|
||||||
@ -86,6 +86,7 @@ class RobertaClassificationHead(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class RobertaEmbeddingModel(BertEmbeddingModel):
|
class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||||
"""A model that uses Roberta to provide embedding functionalities.
|
"""A model that uses Roberta to provide embedding functionalities.
|
||||||
|
|
||||||
@ -149,6 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
|||||||
return loader.load_weights(weights_list, mapper=mapper)
|
return loader.load_weights(weights_list, mapper=mapper)
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type("CLS")
|
||||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||||
"""A model that uses Roberta to provide embedding functionalities.
|
"""A model that uses Roberta to provide embedding functionalities.
|
||||||
|
|
||||||
|
|||||||
@ -1272,7 +1272,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if not is_pooling_model(model):
|
if not is_pooling_model(model):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return list(model.pooler.get_supported_tasks())
|
supported_tasks = list(model.pooler.get_supported_tasks())
|
||||||
|
|
||||||
|
if (self.scheduler_config.chunked_prefill_enabled
|
||||||
|
and "encode" in supported_tasks):
|
||||||
|
supported_tasks.remove("encode")
|
||||||
|
|
||||||
|
logger.info_once("Chunked prefill is not supported with "
|
||||||
|
"encode task which using ALL pooling. "
|
||||||
|
"Please turn off chunked prefill by "
|
||||||
|
"`--no-enable-chunked-prefill` before using it.")
|
||||||
|
|
||||||
|
return supported_tasks
|
||||||
|
|
||||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||||
tasks = list[SupportedTask]()
|
tasks = list[SupportedTask]()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user