mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 22:14:56 +08:00
[Bugfix] Fix Dense module loading for sentence-transformers embedding models (simplified V2) (#23408)
Signed-off-by: FFFfff1FFFfff <yifanli0919@gmail.com>
This commit is contained in:
parent
787cdb3829
commit
c9abb10489
22
tests/models/language/pooling/test_st_projector.py
Normal file
22
tests/models/language/pooling/test_st_projector.py
Normal file
@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
from .mteb_utils import mteb_test_embed_models

# ST models with projector (Dense) layers
ST_PROJECTOR_MODELS = [
    CLSPoolingEmbedModelInfo(
        "TencentBAC/Conan-embedding-v1",
        architecture="BertModel",
        enable_test=True,
    ),
]


@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:
    """Compare vLLM vs. HF embeddings on MTEB for models that carry a
    sentence-transformers Dense projector, by delegating to the shared
    ``mteb_test_embed_models`` harness."""
    mteb_test_embed_models(hf_runner, vllm_runner, model_info)
|
||||
@ -422,12 +422,23 @@ _ACTIVATION_REGISTRY = LazyDict({
|
||||
lambda: nn.SiLU(),
|
||||
"quick_gelu":
|
||||
lambda: QuickGELU(),
|
||||
"tanh":
|
||||
lambda: nn.Tanh(),
|
||||
"sigmoid":
|
||||
lambda: nn.Sigmoid(),
|
||||
})
|
||||
|
||||
|
||||
def get_act_fn(act_fn_name: str) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
|
||||
if act_fn_name.startswith("torch.nn.modules."):
|
||||
activation_name = act_fn_name.split(".")[-1]
|
||||
if activation_name == "identity":
|
||||
return nn.Identity()
|
||||
act_fn_name = activation_name
|
||||
|
||||
if act_fn_name not in _ACTIVATION_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from itertools import groupby
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from typing import Callable, Optional, TypeVar, Union, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -435,9 +435,31 @@ class EmbeddingPoolerHead(PoolerHead):
|
||||
def __init__(self) -> None:
    """Initialize the embedding pooler head with a normalize activation
    and, when the model repo ships one, a sentence-transformers Dense
    projector applied to pooled embeddings."""
    super().__init__(activation=PoolerNormalize())

    # Load ST projector if available
    # NOTE(review): function-level imports — presumably to avoid a
    # circular import between pooler and adapters modules; confirm.
    from vllm.config import get_current_vllm_config
    from vllm.model_executor.models.adapters import _load_st_projector

    vllm_config = get_current_vllm_config()
    # None when no vLLM config is active or the checkpoint has no
    # Dense module(s); forward() skips projection in that case.
    self.projector = _load_st_projector(
        vllm_config.model_config) if vllm_config else None
|
||||
|
||||
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata):
|
||||
|
||||
# Apply ST projector
|
||||
if self.projector is not None:
|
||||
projector = cast(nn.Module, self.projector)
|
||||
|
||||
def _proj(x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
y = projector(x.to(torch.float32))
|
||||
return y.to(orig_dtype)
|
||||
|
||||
if isinstance(pooled_data, torch.Tensor):
|
||||
pooled_data = _proj(pooled_data)
|
||||
else:
|
||||
pooled_data = [_proj(t) for t in pooled_data]
|
||||
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
|
||||
if isinstance(pooled_data, list):
|
||||
|
||||
@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.models.config import VerifyAndUpdateConfig
|
||||
from vllm.transformers_utils.config import (get_hf_file_bytes,
|
||||
get_hf_file_to_dict)
|
||||
|
||||
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GENERATE_SUFFIXES = [
|
||||
"ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
@ -24,6 +30,96 @@ _GENERATE_SUFFIXES = [
|
||||
]
|
||||
|
||||
|
||||
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
    """Build the sentence-transformers Dense projection stack, if any.

    Reads ``modules.json`` from the checkpoint repo, locates the first
    ``sentence_transformers.models.Dense`` entry, and reconstructs it as an
    ``nn.Sequential`` (linear layer plus optional activation) in float32.

    Returns:
        The projector module, or None when the repo has no Dense module
        or any step of loading fails (failures are logged, not raised).
    """
    try:
        module_list = get_hf_file_to_dict("modules.json", model_config.model,
                                          model_config.revision)
        if not module_list:
            return None
        if isinstance(module_list, dict):
            module_list = module_list.get("modules", [])

        # Only the first Dense entry is materialized.
        dense_spec = next(
            (entry for entry in module_list
             if entry.get("type") == "sentence_transformers.models.Dense"),
            None)
        if dense_spec is None:
            return None

        subfolder = dense_spec.get("path", "")
        cfg_path = f"{subfolder}/config.json" if subfolder else "config.json"
        dense_cfg = get_hf_file_to_dict(cfg_path, model_config.model,
                                        model_config.revision)
        if not dense_cfg:
            return None

        # Keep the projector in float32; pooled activations are cast to
        # match it at call time.
        proj = nn.Linear(dense_cfg.get("in_features", 768),
                         dense_cfg.get("out_features", 768),
                         bias=dense_cfg.get("bias", True),
                         dtype=torch.float32)

        if _load_dense_weights(proj, subfolder, model_config):
            stack: list[nn.Module] = [proj]
            activation_name = dense_cfg.get("activation_function")
            if activation_name:
                stack.append(get_act_fn(activation_name))
            return nn.Sequential(*stack).to(dtype=torch.float32)

    except Exception:
        logger.exception("ST projector loading failed")

    return None
|
||||
|
||||
|
||||
def _load_dense_weights(linear: nn.Linear, folder: str,
                        model_config: "ModelConfig") -> bool:
    """Load weights for a Dense projection layer using vLLM's
    weight_loader pattern.

    Tries ``model.safetensors`` first, then ``pytorch_model.bin``, inside
    ``folder`` (the Dense module's subdirectory in the repo; may be "").

    Returns:
        True once a checkpoint with a recognized weight key has been
        copied into ``linear``; False if no candidate file loads.
    """
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)

    for filename in ["model.safetensors", "pytorch_model.bin"]:
        # BUGFIX: build the repo path from the actual candidate filename
        # (a corrupted literal placeholder previously broke subfolder
        # lookups, so the weights could never be found).
        file_path = f"{folder}/{filename}" if folder else filename

        try:
            file_bytes = get_hf_file_bytes(file_path, model_config.model,
                                           model_config.revision)
            if not file_bytes:
                continue

            if filename.endswith(".safetensors"):
                from safetensors.torch import load as load_safetensors
                state_dict = load_safetensors(file_bytes)
            else:
                import io
                # weights_only=True: never unpickle arbitrary objects
                # from a downloaded checkpoint.
                state_dict = torch.load(io.BytesIO(file_bytes),
                                        map_location="cpu",
                                        weights_only=True)

            # ST checkpoints name the linear parameters differently
            # across versions; use the first known key that is present.
            for weight_key in ["weight", "linear.weight", "dense.weight"]:
                if weight_key in state_dict:
                    weight_loader = getattr(linear.weight, "weight_loader",
                                            default_weight_loader)
                    weight_loader(linear.weight,
                                  state_dict[weight_key].to(torch.float32))

                    bias_key = weight_key.replace("weight", "bias")
                    if linear.bias is not None and bias_key in state_dict:
                        bias_loader = getattr(linear.bias, "weight_loader",
                                              default_weight_loader)
                        bias_loader(linear.bias,
                                    state_dict[bias_key].to(torch.float32))
                    return True
        except Exception:
            logger.exception("Failed to load %s", filename)
            continue

    return False
|
||||
|
||||
|
||||
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
||||
model_name = orig_model_name
|
||||
|
||||
|
||||
@ -927,3 +927,25 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None):
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
return snapshot_download(repo_id=model, **common_kwargs)
|
||||
|
||||
|
||||
def get_hf_file_bytes(file_name: str,
                      model: Union[str, Path],
                      revision: Optional[str] = 'main') -> Optional[bytes]:
    """Get file contents from HuggingFace repository as bytes."""
    # Prefer a locally available copy before going to the Hub.
    local_path = try_get_local_file(model=model,
                                    file_name=file_name,
                                    revision=revision)

    if local_path is None:
        downloaded = hf_hub_download(model,
                                     file_name,
                                     revision=revision,
                                     token=_get_hf_token())
        local_path = Path(downloaded)

    if local_path is not None and local_path.is_file():
        return local_path.read_bytes()

    return None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user