[Bugfix] Fix Dense module loading for sentence-transformers embedding models (simplified V2) (#23408)

Signed-off-by: FFFfff1FFFfff <yifanli0919@gmail.com>
This commit is contained in:
LIYIFAN_liyifan 2025-08-24 22:39:24 -07:00 committed by GitHub
parent 787cdb3829
commit c9abb10489
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 175 additions and 2 deletions

View File

@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
from .mteb_utils import mteb_test_embed_models
# Sentence-Transformers models that ship a projector (Dense) module on top
# of the base encoder; used to exercise the Dense-loading path.
ST_PROJECTOR_MODELS = [
    CLSPoolingEmbedModelInfo("TencentBAC/Conan-embedding-v1",
                             architecture="BertModel",
                             enable_test=True),
]


@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:
    """Compare vLLM against HF on MTEB for projector-equipped ST models."""
    mteb_test_embed_models(hf_runner, vllm_runner, model_info)

View File

@@ -422,12 +422,23 @@ _ACTIVATION_REGISTRY = LazyDict({
lambda: nn.SiLU(),
"quick_gelu":
lambda: QuickGELU(),
"tanh":
lambda: nn.Tanh(),
"sigmoid":
lambda: nn.Sigmoid(),
})
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name.startswith("torch.nn.modules."):
activation_name = act_fn_name.split(".")[-1]
if activation_name == "identity":
return nn.Identity()
act_fn_name = activation_name
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")

View File

@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union
from typing import Callable, Optional, TypeVar, Union, cast
import torch
import torch.nn as nn
@@ -435,9 +435,31 @@ class EmbeddingPoolerHead(PoolerHead):
def __init__(self) -> None:
    super().__init__(activation=PoolerNormalize())

    # If the model ships a Sentence-Transformers Dense projector, attach it
    # so pooled embeddings are projected before normalization.
    from vllm.config import get_current_vllm_config
    from vllm.model_executor.models.adapters import _load_st_projector

    vllm_config = get_current_vllm_config()
    if vllm_config:
        self.projector = _load_st_projector(vllm_config.model_config)
    else:
        self.projector = None
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
# Apply ST projector
if self.projector is not None:
projector = cast(nn.Module, self.projector)
def _proj(x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
y = projector(x.to(torch.float32))
return y.to(orig_dtype)
if isinstance(pooled_data, torch.Tensor):
pooled_data = _proj(pooled_data)
else:
pooled_data = [_proj(t) for t in pooled_data]
pooling_params = get_pooling_params(pooling_metadata)
if isinstance(pooled_data, list):

View File

@@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.config import VerifyAndUpdateConfig
from vllm.transformers_utils.config import (get_hf_file_bytes,
get_hf_file_to_dict)
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
_T = TypeVar("_T", bound=type[nn.Module])
logger = init_logger(__name__)
_GENERATE_SUFFIXES = [
"ForCausalLM",
"ForConditionalGeneration",
@@ -24,6 +30,96 @@ _GENERATE_SUFFIXES = [
]
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
    """Load Sentence-Transformers Dense projection layers.

    Reads ``modules.json`` from the model repo, picks the first
    ``sentence_transformers.models.Dense`` entry, and builds an fp32
    ``nn.Sequential`` (Linear plus optional activation) from its config
    and weights. Returns None when no projector exists or loading fails.
    """
    try:
        modules = get_hf_file_to_dict("modules.json", model_config.model,
                                      model_config.revision)
        if not modules:
            return None
        # modules.json is normally a list, but tolerate a dict wrapper.
        if isinstance(modules, dict):
            modules = modules.get("modules", [])

        # Simplified loader: only the first Dense module is used.
        dense = next(
            (m for m in modules
             if m.get("type") == "sentence_transformers.models.Dense"), None)
        if dense is None:
            return None

        folder = dense.get("path", "")
        cfg_path = f"{folder}/config.json" if folder else "config.json"
        layer_cfg = get_hf_file_to_dict(cfg_path, model_config.model,
                                        model_config.revision)
        if not layer_cfg:
            return None

        linear = nn.Linear(layer_cfg.get("in_features", 768),
                           layer_cfg.get("out_features", 768),
                           bias=layer_cfg.get("bias", True),
                           dtype=torch.float32)
        if not _load_dense_weights(linear, folder, model_config):
            return None

        layers: list[nn.Module] = [linear]
        if act_name := layer_cfg.get("activation_function"):
            layers.append(get_act_fn(act_name))
        return nn.Sequential(*layers).to(dtype=torch.float32)
    except Exception:
        logger.exception("ST projector loading failed")
    return None
def _load_dense_weights(linear: nn.Linear, folder: str,
                        model_config: "ModelConfig") -> bool:
    """Load weights for a Dense projection layer into *linear*.

    Tries ``model.safetensors`` then ``pytorch_model.bin`` under *folder*
    (relative to the repo root; repo root itself when *folder* is empty)
    and copies the first matching weight/bias tensors into *linear* using
    vLLM's weight_loader pattern.

    Returns:
        True if a weight tensor was found and loaded, False otherwise.
    """
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)

    for filename in ["model.safetensors", "pytorch_model.bin"]:
        # BUG FIX: interpolate the actual candidate filename; previously a
        # literal placeholder string was embedded here, so weights stored in
        # a projector subfolder could never be resolved.
        file_path = f"{folder}/{filename}" if folder else filename
        try:
            file_bytes = get_hf_file_bytes(file_path, model_config.model,
                                           model_config.revision)
            if not file_bytes:
                continue

            if filename.endswith(".safetensors"):
                from safetensors.torch import load as load_safetensors
                state_dict = load_safetensors(file_bytes)
            else:
                import io
                state_dict = torch.load(io.BytesIO(file_bytes),
                                        map_location="cpu",
                                        weights_only=True)

            # ST checkpoints name the projection tensors differently
            # depending on how they were exported.
            for weight_key in ["weight", "linear.weight", "dense.weight"]:
                if weight_key in state_dict:
                    weight_loader = getattr(linear.weight, "weight_loader",
                                            default_weight_loader)
                    weight_loader(linear.weight,
                                  state_dict[weight_key].to(torch.float32))

                    bias_key = weight_key.replace("weight", "bias")
                    if linear.bias is not None and bias_key in state_dict:
                        bias_loader = getattr(linear.bias, "weight_loader",
                                              default_weight_loader)
                        bias_loader(linear.bias,
                                    state_dict[bias_key].to(torch.float32))
                    return True
        except Exception:
            # Best-effort: log and try the next checkpoint format.
            logger.exception("Failed to load %s", filename)
            continue
    return False
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
model_name = orig_model_name

View File

@@ -927,3 +927,25 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None):
from huggingface_hub import snapshot_download
return snapshot_download(repo_id=model, **common_kwargs)
def get_hf_file_bytes(file_name: str,
                      model: Union[str, Path],
                      revision: Optional[str] = 'main') -> Optional[bytes]:
    """Get file contents from HuggingFace repository as bytes."""
    local_path = try_get_local_file(model=model,
                                    file_name=file_name,
                                    revision=revision)
    if local_path is None:
        # Not available locally: fetch from the Hub cache/download.
        downloaded = hf_hub_download(model,
                                     file_name,
                                     revision=revision,
                                     token=_get_hf_token())
        local_path = Path(downloaded)

    if local_path is not None and local_path.is_file():
        return local_path.read_bytes()
    return None