[Model][0/N] Improve all pooling task | clean up (#25817)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi 2025-10-13 16:44:50 +08:00 committed by GitHub
parent 4f207c7174
commit 767c3ab869
19 changed files with 198 additions and 189 deletions

View File

@@ -581,7 +581,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
 | `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |

 !!! note
-    Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.
+    Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner_client.py>.

 [](){ #supported-mm-models }

View File

@@ -15,7 +15,7 @@ python examples/online_serving/pooling/jinaai_rerank_client.py
 ## Named Entity Recognition (NER) usage

 ```bash
-python examples/online_serving/pooling/ner.py
+python examples/online_serving/pooling/ner_client.py
 ```

 ## Openai chat embedding for multimodal usage

View File

@@ -8,6 +8,8 @@ import os
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any

+from vllm.envs import maybe_convert_bool
+
 if TYPE_CHECKING:
     VLLM_CI_NO_SKIP: bool = False
     VLLM_CI_DTYPE: str | None = None
@@ -25,6 +27,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
     # Allow changing the head dtype used by transformers in tests
     "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
+    # Allow control over whether tests use enforce_eager
+    "VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool(
+        os.getenv("VLLM_CI_ENFORCE_EAGER", None)
+    ),
 }
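Note on the new switch: `VLLM_CI_ENFORCE_EAGER` is tri-state, so leaving it unset keeps each test's own `enforce_eager` default, while an explicit value forces it on or off for the whole CI run. A minimal sketch of that behaviour, with a hypothetical stand-in for `vllm.envs.maybe_convert_bool` (the real parsing lives in `vllm.envs`):

```python
import os


# Hypothetical stand-in for vllm.envs.maybe_convert_bool; assumption: the real
# helper keeps None as None and parses the usual truthy/falsy strings.
def maybe_convert_bool(value: str | None) -> bool | None:
    if value is None:
        return None
    return value.strip().lower() in ("1", "true", "yes")


# Unset -> None: the test keeps its own enforce_eager default.
print(maybe_convert_bool(os.getenv("VLLM_CI_ENFORCE_EAGER")))
# Explicit values force the setting for the whole CI run.
print(maybe_convert_bool("1"), maybe_convert_bool("0"))  # True False
```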

View File

@@ -58,7 +58,9 @@ def test_pooling_params(llm: LLM):
     )


+@pytest.mark.skip_global_cleanup
 def test_encode_api(llm: LLM):
+    # chunked prefill does not support all pooling
     err_msg = "pooling_task must be one of.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompts, use_tqdm=False)

View File

@@ -35,7 +35,6 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(normalize):
         outputs = llm.embed(

View File

@@ -74,7 +74,6 @@ def test_multiple_pooling_params(llm: LLM):
     assert len(PROMPTS) == len(outputs)


-@pytest.mark.skip_global_cleanup
 def test_right_side_truncation(llm: LLM):
     # Embeddings models should truncate the end of the prompt
     tokenizer = llm.get_tokenizer()

View File

@@ -33,7 +33,6 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(activation):
         text_1 = "What is the capital of France?"

View File

@@ -3,12 +3,15 @@
 # Adapted from https://huggingface.co/docs/transformers/perplexity
 from typing import cast

-import pytest
 import torch
 from datasets import load_dataset

 import tests.ci_envs as ci_envs
-from tests.models.utils import GenerateModelInfo, TokensTextLogprobsPromptLogprobs
+from tests.models.utils import (
+    GenerateModelInfo,
+    TokensTextLogprobsPromptLogprobs,
+    get_vllm_extra_kwargs,
+)
 from vllm.logprobs import Logprob

 # See #24485
@@ -25,27 +28,10 @@ def wikitext_ppl_test(
     vllm_extra_kwargs=None,
     atol=PPL_TOL,
 ):
-    # A model family has many models with the same architecture,
-    # and we don't need to test each one.
-    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
-        pytest.skip("Skipping test.")
+    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

     dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

-    # Allow vllm to test using the given dtype, such as float32
-    vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
-
-    # Allow vllm to test using hf_overrides
-    if model_info.hf_overrides is not None:
-        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
-
-    # Allow changing the head dtype used by vllm in tests
-    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
-        if "hf_overrides" not in vllm_extra_kwargs:
-            vllm_extra_kwargs["hf_overrides"] = {}
-        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
-
     with vllm_runner(
         model_info.name,
         gpu_memory_utilization=0.7,

View File

@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["nie3e/sentiment-polish-gpt2-small"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
+    ) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for head_dtype_str in ["float32", "model"]:
+        with vllm_runner(
+            model,
+            max_model_len=512,
+            dtype=dtype,
+            hf_overrides={"head_dtype": head_dtype_str},
+        ) as vllm_model:
+            model_config = vllm_model.llm.llm_engine.model_config
+            model_dtype = model_config.dtype
+            head_dtype = model_config.head_dtype
+
+            if head_dtype_str == "float32":
+                assert head_dtype == torch.float32
+            elif head_dtype_str == "model":
+                assert head_dtype == model_dtype
+
+            vllm_outputs = vllm_model.classify(example_prompts)
+
+        for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+            hf_output = torch.tensor(hf_output).float()
+            vllm_output = torch.tensor(vllm_output).float()
+
+            assert torch.allclose(hf_output, vllm_output, atol=1e-2)

View File

@@ -3,7 +3,6 @@

 import types

-import numpy as np
 import pytest
 import torch
 import torch.nn as nn
@@ -14,11 +13,12 @@ from vllm.model_executor.models.bert import (
 )

 # ---------------------------------------------------------------------
-# 1) Functional test: SPLADE formula correctness (no HF download needed)
+# Functional test: SPLADE formula correctness (no HF download needed)
 # ---------------------------------------------------------------------
 @pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
+@torch.inference_mode
 def test_splade_pooler_matches_reference_formula(B, T, H, V):
     """Ensure SPLADESparsePooler forward() matches the mathematical formula:
     log1p(relu(logits)) -> max over sequence length (after masking)."""
@@ -26,9 +26,11 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
     # Prepare [B] sequences of shape [T, H]
     hs_list = [torch.randn(T, H) for _ in range(B)]
+    hs_tenser = torch.cat(hs_list)

     # Simulate PoolingMetadata (only required fields)
     prompt_lens = [T, T - 1]
+    prompt_lens_tenser = torch.tensor(prompt_lens, dtype=torch.int32)
     token_ids = torch.tensor(
         [
             [101, 5, 102],  # Batch 0: [CLS], token, [SEP]
@@ -36,7 +38,9 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
         ],
         dtype=torch.long,
     )
-    meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)
+    meta = types.SimpleNamespace(
+        prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids
+    )

     # MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
     try:
@@ -46,10 +50,10 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):

     # Forward pass through SPLADE pooler
     pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
-    pooled = pooler(hidden_states=hs_list, pooling_metadata=meta)  # list of [V]
+    pooled = pooler(hidden_states=hs_tenser, pooling_metadata=meta)  # list of [V]

     # Basic output checks
-    assert isinstance(pooled, list) and len(pooled) == B
+    assert isinstance(pooled, torch.Tensor) and len(pooled) == B
     for vec in pooled:
         assert vec.shape == (V,)
         assert torch.isfinite(vec).all()
@@ -83,40 +87,3 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
         rtol=1e-4,
         atol=1e-4,
     )
-
-
-# ---------------------------------------------------------------------
-# 2) Integration smoke test: end-to-end embedding path wiring
-# ---------------------------------------------------------------------
-@pytest.mark.cpu_model
-def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
-    """Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
-    from transformers import AutoTokenizer
-
-    MODEL_ID = "hf-internal-testing/tiny-random-bert"
-    hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}
-
-    # Enforce CPU-only execution (optional)
-    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
-    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
-
-    tok = AutoTokenizer.from_pretrained(MODEL_ID)
-    vocab_size = tok.vocab_size
-
-    # The embed path should route through SPLADESparsePooler
-    with vllm_runner(
-        MODEL_ID,
-        runner="pooling",
-        max_model_len=64,
-        hf_overrides=hf_overrides,
-    ) as vm:
-        outs = vm.embed(["hello world", "splade sparse test"])
-
-    # Basic sanity checks
-    assert len(outs) == 2
-    assert outs[0].shape[0] == vocab_size
-    assert outs[1].shape[0] == vocab_size
-    assert np.isfinite(outs[0]).all() and (outs[0] >= 0).all()
-    assert np.isfinite(outs[1]).all() and (outs[1] >= 0).all()

View File

@@ -6,12 +6,16 @@ from collections.abc import Sequence

 import mteb
 import numpy as np
-import pytest
 import requests
 import torch

 import tests.ci_envs as ci_envs
-from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close
+from tests.models.utils import (
+    EmbedModelInfo,
+    RerankModelInfo,
+    check_embeddings_close,
+    get_vllm_extra_kwargs,
+)

 # Most embedding models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -165,28 +169,11 @@ def mteb_test_embed_models(
     hf_model_callback=None,
     atol=MTEB_EMBED_TOL,
 ):
-    # A model family has many models with the same architecture,
-    # and we don't need to test each one.
-    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
-        pytest.skip("Skipping test.")
+    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

     # Test embed_dims, isnan and whether to use normalize
     example_prompts = ["The chef prepared a delicious meal." * 1000]

-    # Allow vllm to test using the given dtype, such as float32
-    vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
-
-    # Allow vllm to test using hf_overrides
-    if model_info.hf_overrides is not None:
-        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
-
-    # Allow changing the head dtype used by vllm in tests
-    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
-        if "hf_overrides" not in vllm_extra_kwargs:
-            vllm_extra_kwargs["hf_overrides"] = {}
-        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
-
     with vllm_runner(
         model_info.name,
         runner="pooling",
@@ -212,9 +199,12 @@ def mteb_test_embed_models(
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
         head_dtype = model_config.head_dtype

-        # Test embed_dims, isnan and whether to use normalize
+        # Test embedding_size, isnan and whether to use normalize
         vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1)
-        assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
+        outputs_tensor = torch.tensor(vllm_outputs)
+        assert not torch.any(torch.isnan(outputs_tensor))
+        embedding_size = model_config.embedding_size
+        assert torch.tensor(vllm_outputs).shape[-1] == embedding_size

         # Accelerate mteb test by setting
         # SentenceTransformers mteb score to a constant
@@ -231,7 +221,7 @@ def mteb_test_embed_models(
             st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
             st_dtype = next(hf_model.model.parameters()).dtype

-            # Test embed_dims and whether to use normalize
+            # Check embeddings close to hf outputs
             hf_outputs = hf_model.encode(example_prompts)
             check_embeddings_close(
                 embeddings_0_lst=hf_outputs,
@@ -323,24 +313,7 @@ def mteb_test_rerank_models(
     vllm_mteb_encoder=VllmMtebEncoder,
     atol=MTEB_RERANK_TOL,
 ):
-    # A model family has many models with the same architecture,
-    # and we don't need to test each one.
-    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
-        pytest.skip("Skipping test.")
-
-    # Allow vllm to test using the given dtype, such as float32
-    vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
-
-    # Allow vllm to test using hf_overrides
-    if model_info.hf_overrides is not None:
-        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
-
-    # Allow changing the head dtype used by vllm in tests
-    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
-        if "hf_overrides" not in vllm_extra_kwargs:
-            vllm_extra_kwargs["hf_overrides"] = {}
-        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

     with vllm_runner(
         model_info.name,

View File

@@ -15,6 +15,7 @@ from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.multimodal.processing import InputProcessingContext
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

+from .. import ci_envs
 from .registry import HF_EXAMPLE_MODELS

 TokensText = tuple[list[int], str]
@@ -414,6 +415,35 @@ class GenerateModelInfo(ModelInfo):
     hf_ppl: float | None = None


+def get_vllm_extra_kwargs(model_info: ModelInfo, vllm_extra_kwargs):
+    # A model family has many models with the same architecture,
+    # and we don't need to test each one.
+    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
+        import pytest
+
+        pytest.skip("Skipping test.")
+
+    # Allow vllm to test using the given dtype, such as float32
+    vllm_extra_kwargs = vllm_extra_kwargs or {}
+    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
+
+    # Allow vllm to test using hf_overrides
+    if model_info.hf_overrides is not None:
+        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
+
+    # Allow changing the head dtype used by vllm in tests
+    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
+        if "hf_overrides" not in vllm_extra_kwargs:
+            vllm_extra_kwargs["hf_overrides"] = {}
+        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+
+    # Allow control over whether tests use enforce_eager
+    if ci_envs.VLLM_CI_ENFORCE_EAGER is not None:
+        vllm_extra_kwargs["enforce_eager"] = ci_envs.VLLM_CI_ENFORCE_EAGER
+
+    return vllm_extra_kwargs
+
+
 def dummy_hf_overrides(
     hf_config: PretrainedConfig,
     *,
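With the skip/dtype/hf_overrides/enforce_eager plumbing centralized in `get_vllm_extra_kwargs`, each pooling test reduces to one call before constructing the runner. A self-contained sketch of the merge behaviour, using hypothetical stand-ins rather than the real `ModelInfo` and CI env plumbing:

```python
from dataclasses import dataclass


@dataclass
class FakeModelInfo:
    # Stand-in for the ModelInfo fields the helper reads (hypothetical values).
    name: str = "BAAI/bge-base-en-v1.5"
    dtype: str = "half"
    hf_overrides: dict | None = None
    enable_test: bool = True


def merge_ci_kwargs(model_info, vllm_extra_kwargs=None, *,
                    ci_dtype=None, ci_head_dtype=None, ci_enforce_eager=None):
    # Same merge order as get_vllm_extra_kwargs above, minus the pytest.skip.
    kwargs = vllm_extra_kwargs or {}
    kwargs["dtype"] = ci_dtype or model_info.dtype
    if model_info.hf_overrides is not None:
        kwargs["hf_overrides"] = model_info.hf_overrides
    if ci_head_dtype is not None:
        kwargs.setdefault("hf_overrides", {})["head_dtype"] = ci_head_dtype
    if ci_enforce_eager is not None:
        kwargs["enforce_eager"] = ci_enforce_eager
    return kwargs


print(merge_ci_kwargs(FakeModelInfo(), ci_head_dtype="float32", ci_enforce_eager=True))
# {'dtype': 'half', 'hf_overrides': {'head_dtype': 'float32'}, 'enforce_eager': True}
```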

View File

@@ -30,6 +30,7 @@ from vllm.transformers_utils.config import (
     get_sentence_transformer_tokenizer_config,
     is_encoder_decoder,
     is_interleaved,
+    try_get_dense_modules,
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
@@ -1681,6 +1682,20 @@ class ModelConfig:
         logger.debug_once("head dtype: %s", head_dtype)
         return head_dtype

+    @property
+    def hidden_size(self):
+        if hasattr(self.hf_config, "hidden_size"):
+            return self.hf_config.hidden_size
+        text_config = self.hf_config.get_text_config()
+        return text_config.hidden_size
+
+    @property
+    def embedding_size(self):
+        dense_modules = try_get_dense_modules(self.model, revision=self.revision)
+        if dense_modules is not None:
+            return dense_modules[-1]["out_features"]
+        return self.hidden_size
+
     def get_and_verify_max_len(self, max_model_len: int):
         # Consider max_model_len in tokenizer_config only when
         # pooling models use absolute position_embedding.
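The two properties are deliberately distinct: `hidden_size` is the transformer width from the HF config, while `embedding_size` is the output width of the last Sentence-Transformers Dense module when the checkpoint ships one, falling back to `hidden_size` otherwise. A self-contained illustration of that fallback (plain dicts stand in for the HF config and the parsed Dense modules; the 768 to 1024 projection is hypothetical):

```python
# Stand-ins for the real objects; only the lookup/fallback logic is shown.
hf_config = {"hidden_size": 768}

# What try_get_dense_modules could return for a checkpoint whose modules.json
# lists one sentence_transformers Dense head projecting 768 -> 1024.
dense_modules = [{"in_features": 768, "out_features": 1024, "folder": "2_Dense"}]

hidden_size = hf_config["hidden_size"]
embedding_size = dense_modules[-1]["out_features"] if dense_modules else hidden_size

print(hidden_size, embedding_size)  # 768 1024 -> returned embeddings are 1024-wide
```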

View File

@@ -13,7 +13,10 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.models.config import VerifyAndUpdateConfig
-from vllm.transformers_utils.config import get_hf_file_bytes, get_hf_file_to_dict
+from vllm.transformers_utils.config import (
+    get_hf_file_bytes,
+    try_get_dense_modules,
+)

 from .interfaces_base import VllmModelForPooling, is_pooling_model
@@ -35,43 +38,25 @@ _GENERATE_SUFFIXES = [

 def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
     """Load Sentence-Transformers Dense projection layers."""
+    dense_modules = try_get_dense_modules(
+        model_config.model, revision=model_config.revision
+    )
+    if dense_modules is None:
+        return
+
     try:
-        modules = get_hf_file_to_dict(
-            "modules.json", model_config.model, model_config.revision
-        )
-        if not modules:
-            return None
-
-        if isinstance(modules, dict):
-            modules = modules.get("modules", [])
-
-        dense_modules = [
-            m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
-        ]
-        if not dense_modules:
-            return None
-
         layers = []
-        for module in dense_modules:
-            folder = module.get("path", "")
-
-            config_path = f"{folder}/config.json" if folder else "config.json"
-            layer_config = get_hf_file_to_dict(
-                config_path, model_config.model, model_config.revision
-            )
-            if not layer_config:
-                continue
-
+        for layer_config in dense_modules:
+            folder = layer_config["folder"]
+
             linear = nn.Linear(
-                layer_config.get("in_features", 768),
-                layer_config.get("out_features", 768),
+                layer_config["in_features"],
+                layer_config["out_features"],
                 bias=layer_config.get("bias", True),
                 dtype=model_config.head_dtype,
             )

             if not _load_dense_weights(linear, folder, model_config):
                 continue

             layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))
@@ -303,18 +288,18 @@ def as_seq_cls_model(cls: _T) -> _T:
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
     from vllm.sequence import IntermediateTensors

-    from .utils import get_model_hidden_size, maybe_prefix
+    from .utils import maybe_prefix

     class ModelForSequenceClassification(
         _create_pooling_model_cls(cls), SupportsCrossEncoding
     ):
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             config = vllm_config.model_config.hf_config
+            model_config = vllm_config.model_config
             quant_config = vllm_config.quant_config

-            hidden_size = get_model_hidden_size(config)
             self.score = ReplicatedLinear(
-                hidden_size,
+                model_config.hidden_size,
                 config.num_labels,
                 bias=False,
                 params_dtype=torch.float32,

View File

@@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors

 from ..layers.pooler import DispatchPooler, Pooler
-from .interfaces import SupportsPP
+from .interfaces import SupportsCrossEncoding, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -321,7 +321,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
         return loader.load_weights(weights)


-class GPT2ForSequenceClassification(nn.Module):
+class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """GPT2 Model for sequence classification.

    This class expands GPT2Model with pooling and score functions - last token
@@ -358,6 +358,9 @@ class GPT2ForSequenceClassification(nn.Module):
             }
         )

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)

View File

@@ -148,37 +148,6 @@ class GritLMMeanPool(nn.Module):
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return PoolingParamsUpdate(requires_token_ids=True)

-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: torch.Tensor | None = None,
-        instr_len: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], (
-            "partial prefill not supported with MEAN pooling"
-        )
-
-        return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)
-
-    def forward_all(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
-        instr_lens: torch.Tensor,
-    ) -> list[torch.Tensor] | torch.Tensor:
-        offset = 0
-        pooled_data = list[torch.Tensor]()
-
-        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
-            pooled_data.append(
-                hidden_states[offset + instr_len : offset + prompt_len].mean(
-                    dim=0, dtype=torch.float32
-                )
-            )
-            offset += prompt_len
-
-        return pooled_data
-
     def forward(
         self,
         hidden_states: torch.Tensor | list[torch.Tensor],
@@ -190,18 +159,20 @@ class GritLMMeanPool(nn.Module):
                 self._get_instruction_len(token_ids.cpu().numpy())
                 for token_ids in get_prompt_token_ids(pooling_metadata)
             ],
-            device=prompt_lens.device,
+            device="cpu",
         )

-        if isinstance(hidden_states, list):
-            return [
-                self.forward_one(h, prompt_len, instr_len)
-                for h, prompt_len, instr_len in zip(
-                    hidden_states, prompt_lens, instr_lens
-                )
-            ]
+        offset = 0
+        pooled_data = list[torch.Tensor]()
+        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
+            pooled_data.append(
+                hidden_states[offset + instr_len : offset + prompt_len].mean(
+                    dim=0, dtype=torch.float32
+                )
+            )
+            offset += prompt_len

-        return self.forward_all(hidden_states, prompt_lens, instr_lens)
+        return pooled_data


 class GritLMPooler(Pooler):
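The rewritten `forward` assumes the hidden states of all prompts arrive as a single flattened `[sum(prompt_lens), hidden]` tensor, so each prompt is recovered by offset slicing, its instruction prefix is skipped, and the remainder is mean-pooled. A small self-contained check of that slicing arithmetic with toy sizes (not GritLM weights):

```python
import torch

hidden = 4
prompt_lens = torch.tensor([3, 5])  # two prompts packed back to back
instr_lens = torch.tensor([1, 2])   # leading instruction tokens to skip
hidden_states = torch.randn(int(prompt_lens.sum()), hidden)

offset = 0
pooled = []
for prompt_len, instr_len in zip(prompt_lens, instr_lens):
    # Skip the instruction prefix, mean-pool the rest of this prompt's tokens.
    pooled.append(
        hidden_states[offset + instr_len : offset + prompt_len].mean(
            dim=0, dtype=torch.float32
        )
    )
    offset += prompt_len

assert len(pooled) == 2 and pooled[0].shape == (hidden,)
# Prompt 0 pools rows 1..2, prompt 1 pools rows 5..7 of the flattened tensor.
torch.testing.assert_close(
    pooled[1], hidden_states[5:8].mean(dim=0, dtype=torch.float32)
)
```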

View File

@@ -777,13 +777,6 @@ def fast_topk(
     return torch.topk(values, topk, dim=dim)


-def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
-    if hasattr(hf_config, "hidden_size"):
-        return hf_config.hidden_size
-    text_config = hf_config.get_text_config()
-    return text_config.hidden_size
-
-
 # Chunk x along the num_tokens axis for sequence parallelism
 # NOTE: This is wrapped in a torch custom op to work around the following issue:
 # The output tensor can have a sequence length 0 at small input sequence lengths

View File

@@ -1049,6 +1049,40 @@ def try_get_tokenizer_config(
         return None


+@cache
+def try_get_dense_modules(
+    model: str | Path,
+    revision: str | None = None,
+) -> list[dict[str, Any]] | None:
+    try:
+        modules = get_hf_file_to_dict("modules.json", model, revision)
+        if not modules:
+            return None
+
+        if isinstance(modules, dict):
+            modules = modules.get("modules", [])
+
+        dense_modules = [
+            m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
+        ]
+        if not dense_modules:
+            return None
+
+        layer_configs = []
+        for module in dense_modules:
+            folder = module.get("path", "")
+
+            config_path = f"{folder}/config.json" if folder else "config.json"
+            layer_config = get_hf_file_to_dict(config_path, model, revision)
+            if not layer_config:
+                continue
+            layer_config["folder"] = folder
+            layer_configs.append(layer_config)
+        return layer_configs
+    except Exception:
+        return None
+
+
 def get_safetensors_params_metadata(
     model: str,
     *,
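For reference, the dicts returned by `try_get_dense_modules` are the per-module Dense `config.json` contents with the module folder attached. An illustrative (not model-specific) example of what callers such as `_load_st_projector` and `ModelConfig.embedding_size` receive:

```python
# Hypothetical Sentence-Transformers layout parsed by the helper above:
#   modules.json        -> lists a sentence_transformers.models.Dense module at "2_Dense"
#   2_Dense/config.json -> {"in_features": 768, "out_features": 1024, "bias": true, ...}
layer_configs = [
    {
        "in_features": 768,
        "out_features": 1024,
        "bias": True,
        "activation_function": "torch.nn.modules.linear.Identity",
        "folder": "2_Dense",  # added by try_get_dense_modules
    }
]

# _load_st_projector builds one nn.Linear per entry, and
# ModelConfig.embedding_size reads the last entry's "out_features".
print(layer_configs[-1]["out_features"])  # 1024
```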