mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 05:35:01 +08:00
Add explicit pooling classes for the Transformers backend (#25322)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
8328d39d40
commit
b3e1846da6
@ -657,7 +657,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_TRANSFORMERS_BACKEND_MODELS = {
|
_TRANSFORMERS_BACKEND_MODELS = {
|
||||||
"TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
|
"TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"), # noqa: E501
|
||||||
|
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
|
||||||
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
||||||
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,9 +9,16 @@ from vllm.platforms import current_platform
|
|||||||
|
|
||||||
from ..conftest import HfRunner, VllmRunner
|
from ..conftest import HfRunner, VllmRunner
|
||||||
from ..utils import multi_gpu_test, prep_prompts
|
from ..utils import multi_gpu_test, prep_prompts
|
||||||
|
from .registry import HF_EXAMPLE_MODELS
|
||||||
from .utils import check_embeddings_close, check_logprobs_close
|
from .utils import check_embeddings_close, check_logprobs_close
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(arch: str) -> str:
|
||||||
|
model_info = HF_EXAMPLE_MODELS.get_hf_info(arch)
|
||||||
|
model_info.check_transformers_version(on_fail="skip")
|
||||||
|
return model_info.default
|
||||||
|
|
||||||
|
|
||||||
def check_implementation(
|
def check_implementation(
|
||||||
runner_ref: type[Union[HfRunner, VllmRunner]],
|
runner_ref: type[Union[HfRunner, VllmRunner]],
|
||||||
runner_test: type[VllmRunner],
|
runner_test: type[VllmRunner],
|
||||||
@ -170,71 +177,47 @@ def test_embed_loading(vllm_runner, model):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"arch",
|
||||||
[
|
["TransformersEmbeddingModel", "TransformersForSequenceClassification"])
|
||||||
# Encoder model
|
def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
|
||||||
"BAAI/bge-base-en-v1.5",
|
model = get_model(arch)
|
||||||
])
|
|
||||||
def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
|
|
||||||
import transformers
|
|
||||||
from packaging.version import Version
|
|
||||||
installed = Version(transformers.__version__)
|
|
||||||
required = Version("4.57.0.dev0")
|
|
||||||
if installed < required:
|
|
||||||
pytest.skip("Encoder models with the Transformers backend require "
|
|
||||||
f"transformers>={required}, but got {installed}")
|
|
||||||
|
|
||||||
with vllm_runner(model, max_model_len=512,
|
vllm_kwargs = dict(
|
||||||
model_impl="transformers") as vllm_model:
|
max_model_len=None,
|
||||||
|
model_impl="transformers",
|
||||||
|
compilation_config=dict(cudagraph_capture_sizes=[8]),
|
||||||
|
)
|
||||||
|
|
||||||
|
hf_kwargs = dict()
|
||||||
|
if arch == "TransformersEmbeddingModel":
|
||||||
|
hf_kwargs["is_sentence_transformer"] = True
|
||||||
|
elif arch == "TransformersForSequenceClassification":
|
||||||
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
hf_kwargs["auto_cls"] = AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
# The example_prompts has ending "\n", for example:
|
||||||
|
# "Write a short story about a robot that dreams for the first time.\n"
|
||||||
|
# sentence_transformers will strip the input texts, see:
|
||||||
|
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
|
||||||
|
# This makes the input_ids different between hf_model and vllm_model.
|
||||||
|
# So we need to strip the input texts to avoid test failing.
|
||||||
|
example_prompts = [str(s).strip() for s in example_prompts]
|
||||||
|
|
||||||
|
with (vllm_runner(model, **vllm_kwargs) as
|
||||||
|
vllm_model, hf_runner(model, **hf_kwargs) as hf_model):
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
assert model_config.using_transformers_backend()
|
assert model_config.using_transformers_backend()
|
||||||
|
|
||||||
|
if arch == "TransformersEmbeddingModel":
|
||||||
vllm_outputs = vllm_model.embed(example_prompts)
|
vllm_outputs = vllm_model.embed(example_prompts)
|
||||||
|
|
||||||
with hf_runner(model, is_sentence_transformer=True) as hf_model:
|
|
||||||
hf_outputs = hf_model.encode(example_prompts)
|
hf_outputs = hf_model.encode(example_prompts)
|
||||||
|
elif arch == "TransformersForSequenceClassification":
|
||||||
|
vllm_outputs = vllm_model.classify(example_prompts)
|
||||||
|
hf_outputs = hf_model.classify(example_prompts)
|
||||||
|
|
||||||
check_embeddings_close(
|
check_embeddings_close(
|
||||||
embeddings_0_lst=hf_outputs,
|
embeddings_0_lst=hf_outputs,
|
||||||
embeddings_1_lst=vllm_outputs,
|
embeddings_1_lst=vllm_outputs,
|
||||||
name_0="hf",
|
name_0="hf",
|
||||||
name_1="vllm",
|
name_1="vllm",
|
||||||
tol=1e-2,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model",
|
|
||||||
["jason9693/Qwen2.5-1.5B-apeach"],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("dtype", ["float"])
|
|
||||||
def test_classify(
|
|
||||||
hf_runner,
|
|
||||||
vllm_runner,
|
|
||||||
example_prompts,
|
|
||||||
model: str,
|
|
||||||
dtype: str,
|
|
||||||
) -> None:
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForSequenceClassification
|
|
||||||
|
|
||||||
with vllm_runner(model,
|
|
||||||
max_model_len=512,
|
|
||||||
dtype=dtype,
|
|
||||||
model_impl="transformers") as vllm_model:
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
|
||||||
assert model_config.using_transformers_backend()
|
|
||||||
|
|
||||||
vllm_outputs = vllm_model.classify(example_prompts)
|
|
||||||
|
|
||||||
with hf_runner(model,
|
|
||||||
dtype=dtype,
|
|
||||||
auto_cls=AutoModelForSequenceClassification) as hf_model:
|
|
||||||
hf_outputs = hf_model.classify(example_prompts)
|
|
||||||
|
|
||||||
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
|
||||||
hf_output = torch.tensor(hf_output)
|
|
||||||
vllm_output = torch.tensor(vllm_output)
|
|
||||||
|
|
||||||
assert torch.allclose(hf_output, vllm_output,
|
|
||||||
1e-3 if dtype == "float" else 1e-2)
|
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import vllm.envs as envs
|
|||||||
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||||
MultiModalConfig)
|
MultiModalConfig)
|
||||||
from vllm.config.pooler import PoolerConfig
|
from vllm.config.pooler import PoolerConfig
|
||||||
|
from vllm.config.scheduler import RunnerType
|
||||||
from vllm.config.utils import assert_hashable, config
|
from vllm.config.utils import assert_hashable, config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -40,7 +41,6 @@ if TYPE_CHECKING:
|
|||||||
import vllm.model_executor.models as me_models
|
import vllm.model_executor.models as me_models
|
||||||
from vllm.config.load import LoadConfig
|
from vllm.config.load import LoadConfig
|
||||||
from vllm.config.parallel import ParallelConfig
|
from vllm.config.parallel import ParallelConfig
|
||||||
from vllm.config.scheduler import RunnerType
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||||
else:
|
else:
|
||||||
@ -52,13 +52,12 @@ else:
|
|||||||
"vllm.model_executor.models")
|
"vllm.model_executor.models")
|
||||||
LoadConfig = Any
|
LoadConfig = Any
|
||||||
ParallelConfig = Any
|
ParallelConfig = Any
|
||||||
RunnerType = Any
|
|
||||||
QuantizationMethods = Any
|
QuantizationMethods = Any
|
||||||
LogitsProcessor = Any
|
LogitsProcessor = Any
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
|
RunnerOption = Literal["auto", RunnerType]
|
||||||
ConvertType = Literal["none", "embed", "classify", "reward"]
|
ConvertType = Literal["none", "embed", "classify", "reward"]
|
||||||
ConvertOption = Literal["auto", ConvertType]
|
ConvertOption = Literal["auto", ConvertType]
|
||||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||||
@ -668,8 +667,28 @@ class ModelConfig:
|
|||||||
def _get_transformers_backend_cls(self) -> str:
|
def _get_transformers_backend_cls(self) -> str:
|
||||||
"""Determine which Transformers backend class will be used if
|
"""Determine which Transformers backend class will be used if
|
||||||
`model_impl` is set to `transformers` or `auto`."""
|
`model_impl` is set to `transformers` or `auto`."""
|
||||||
if getattr(self, "runner_type", self.runner) == "pooling":
|
# Check if the architecture we're wrapping has defaults
|
||||||
return "TransformersModel"
|
runner = None
|
||||||
|
convert = None
|
||||||
|
if defaults := try_match_architecture_defaults(self.architectures[0]):
|
||||||
|
_, (runner, convert) = defaults
|
||||||
|
# Overwrite with user-specified values
|
||||||
|
if self.runner != "auto":
|
||||||
|
runner = self.runner
|
||||||
|
if self.convert not in {"auto", "none"}:
|
||||||
|
convert = self.convert
|
||||||
|
# Fall back to default values if still not set
|
||||||
|
if runner is None:
|
||||||
|
runner = "generate"
|
||||||
|
if convert in {None, "none"}:
|
||||||
|
convert = "embed"
|
||||||
|
# Resolve Transformers backend pooling classes
|
||||||
|
if runner == "pooling":
|
||||||
|
if convert == "embed":
|
||||||
|
return "TransformersEmbeddingModel"
|
||||||
|
if convert == "classify":
|
||||||
|
return "TransformersForSequenceClassification"
|
||||||
|
# Resolve Transformers backend generate classes
|
||||||
if self.hf_config != self.hf_text_config:
|
if self.hf_config != self.hf_text_config:
|
||||||
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
|
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
|
||||||
# probably a composite config, i.e. multimodal
|
# probably a composite config, i.e. multimodal
|
||||||
@ -678,7 +697,9 @@ class ModelConfig:
|
|||||||
|
|
||||||
def using_transformers_backend(self) -> bool:
|
def using_transformers_backend(self) -> bool:
|
||||||
"""Check if the model is using the Transformers backend class."""
|
"""Check if the model is using the Transformers backend class."""
|
||||||
return self.architecture == self._get_transformers_backend_cls()
|
used_cls = self._model_info.architecture
|
||||||
|
transformers_backend_cls = self._get_transformers_backend_cls()
|
||||||
|
return used_cls == transformers_backend_cls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def registry(self):
|
def registry(self):
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
|
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
|
||||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||||
|
|
||||||
@ -52,6 +53,18 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
|||||||
f"{cls.__name__}.{name} must have a default value or default factory.")
|
f"{cls.__name__}.{name} must have a default value or default factory.")
|
||||||
|
|
||||||
|
|
||||||
|
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
|
||||||
|
"""
|
||||||
|
A helper function that retrieves an attribute from an object which may
|
||||||
|
have multiple possible names. This is useful when fetching attributes from
|
||||||
|
arbitrary `transformers.PretrainedConfig` instances.
|
||||||
|
"""
|
||||||
|
for name in names:
|
||||||
|
if hasattr(object, name):
|
||||||
|
return getattr(object, name)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
def contains_object_print(text: str) -> bool:
|
def contains_object_print(text: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the text looks like a printed Python object, e.g.
|
Check if the text looks like a printed Python object, e.g.
|
||||||
|
|||||||
@ -307,7 +307,8 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_TRANSFORMERS_BACKEND_MODELS = {
|
_TRANSFORMERS_BACKEND_MODELS = {
|
||||||
"TransformersModel": ("transformers", "TransformersModel"),
|
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
|
||||||
|
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
|
||||||
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||||
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionType
|
|||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||||
ParallelConfig, VllmConfig)
|
ParallelConfig, VllmConfig)
|
||||||
|
from vllm.config.utils import getattr_iter
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.utils import get_pp_indices
|
from vllm.distributed.utils import get_pp_indices
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -486,10 +487,13 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
|
|
||||||
# Input embeddings
|
# Input embeddings
|
||||||
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
|
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
|
||||||
|
names = ("embedding_size", "hidden_size")
|
||||||
|
embedding_dim = getattr_iter(self.text_config, names, None)
|
||||||
|
assert embedding_dim is not None
|
||||||
self.model.set_input_embeddings(
|
self.model.set_input_embeddings(
|
||||||
VocabParallelEmbedding(
|
VocabParallelEmbedding(
|
||||||
self.text_config.vocab_size,
|
self.text_config.vocab_size,
|
||||||
self.text_config.hidden_size,
|
embedding_dim=embedding_dim,
|
||||||
org_num_embeddings=self.text_config.vocab_size,
|
org_num_embeddings=self.text_config.vocab_size,
|
||||||
quant_config=self.quant_config,
|
quant_config=self.quant_config,
|
||||||
))
|
))
|
||||||
@ -645,7 +649,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
attn_type=attn_type)
|
attn_type=attn_type)
|
||||||
return attention_instances
|
return attention_instances
|
||||||
|
|
||||||
def init_parameters(self, module: nn.Module):
|
def init_parameters(self,
|
||||||
|
module: nn.Module,
|
||||||
|
dtype: Optional[torch.dtype] = None):
|
||||||
"""
|
"""
|
||||||
If a `parameter` is on the `meta` device, then its parent
|
If a `parameter` is on the `meta` device, then its parent
|
||||||
`module` is the original module created by:
|
`module` is the original module created by:
|
||||||
@ -659,11 +665,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
if param.device == torch.device("meta"):
|
if param.device == torch.device("meta"):
|
||||||
new_param = nn.Parameter(
|
new_param = nn.Parameter(
|
||||||
torch.empty_like(param.data,
|
torch.empty_like(param.data,
|
||||||
dtype=self.model_config.dtype,
|
dtype=dtype or self.model_config.dtype,
|
||||||
device=self.device_config.device))
|
device=self.device_config.device))
|
||||||
setattr(module, name, new_param)
|
setattr(module, name, new_param)
|
||||||
for child in module.children():
|
for child in module.children():
|
||||||
self.init_parameters(child)
|
self.init_parameters(child, dtype)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -712,73 +718,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
|
||||||
class TransformersModel(TransformersBase):
|
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
|
||||||
orig_to_new_prefix={
|
|
||||||
# Handle BERT-like models
|
|
||||||
"bert": "model",
|
|
||||||
# Add `model.` prefix for base model checkpoints
|
|
||||||
"": "model.",
|
|
||||||
# Remove `model.` prefix if it was already there
|
|
||||||
"model.model.": "model.",
|
|
||||||
# Pooling adapters will be adjacent to `model`
|
|
||||||
"model.pooler": "pooler",
|
|
||||||
"model.score": "score",
|
|
||||||
# Classifier adapter's classifier layer is renamed to score
|
|
||||||
"model.classifier": "score",
|
|
||||||
},
|
|
||||||
orig_to_new_suffix={
|
|
||||||
# Replace legacy suffixes used for norms
|
|
||||||
".gamma": ".weight",
|
|
||||||
".beta": ".bias",
|
|
||||||
})
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
|
||||||
|
|
||||||
# After creating a pooling model, `pooler` will be duplicated.
|
|
||||||
# The one inside `model` comes from the Transformers modelling code.
|
|
||||||
# The one after `model` is an adapter from vLLM.
|
|
||||||
# We want to use the adapter so we nullify the original pooler.
|
|
||||||
if getattr(self.model, "pooler", None) is not None:
|
|
||||||
self.skip_prefixes.append("pooler.")
|
|
||||||
self.model.pooler = torch.nn.Identity()
|
|
||||||
|
|
||||||
# Some encoder models have the position_ids buffer in the checkpoint.
|
|
||||||
# vLLM will always pass position_ids as an argument, so we skip loading
|
|
||||||
# the buffer if it exists
|
|
||||||
self.skip_substrs.append("position_ids")
|
|
||||||
|
|
||||||
# Some encoder models have the bias of the final classifier layer
|
|
||||||
# in the checkpoint. vLLM does not use this bias, so we skip loading
|
|
||||||
# it if it exists
|
|
||||||
self.skip_substrs.append("score.bias")
|
|
||||||
|
|
||||||
def create_attention_instances(
|
|
||||||
self, attn_type: AttentionType = AttentionType.DECODER):
|
|
||||||
# TODO(hmellor): Better way to detect encoder models
|
|
||||||
# In encoder models, the attention layers will have `is_causal=False`
|
|
||||||
is_encoder = lambda m: not getattr(m, "is_causal", True)
|
|
||||||
# vLLM does not support encoder-decoder models, so if any encoder layer
|
|
||||||
# is found, we assume the whole model is an encoder model
|
|
||||||
if any(is_encoder(m) for m in self.model.modules()):
|
|
||||||
attn_type = AttentionType.ENCODER_ONLY
|
|
||||||
|
|
||||||
# Check minimum transformers version for encoder models support
|
|
||||||
if attn_type == AttentionType.ENCODER_ONLY:
|
|
||||||
import transformers
|
|
||||||
from packaging.version import Version
|
|
||||||
installed = Version(transformers.__version__)
|
|
||||||
required = Version("4.57.0.dev0")
|
|
||||||
if installed < required:
|
|
||||||
raise ValueError(
|
|
||||||
"Encoder models with the Transformers backend require "
|
|
||||||
f"transformers>={required}, but got {installed}")
|
|
||||||
|
|
||||||
return super().create_attention_instances(attn_type)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
class TransformersForCausalLM(TransformersBase):
|
class TransformersForCausalLM(TransformersBase):
|
||||||
|
|
||||||
|
|||||||
200
vllm/model_executor/models/transformers_pooling.py
Normal file
200
vllm/model_executor/models/transformers_pooling.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Wrapper around `transformers` models for pooling tasks."""
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
from vllm.attention import AttentionType
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
|
||||||
|
DispatchPooler, Pooler)
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces_base import VllmModelForPooling
|
||||||
|
from .transformers import TransformersBase, can_enable_torch_compile
|
||||||
|
from .utils import WeightsMapper
|
||||||
|
|
||||||
|
|
||||||
|
class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
|
# These are applied in order, so the order matters!
|
||||||
|
orig_to_new_prefix={
|
||||||
|
# Handle BERT-like models
|
||||||
|
"roberta": "model",
|
||||||
|
"bert": "model",
|
||||||
|
# Add `model.` prefix for base model checkpoints
|
||||||
|
"": "model.",
|
||||||
|
# Remove `model.` prefix if it was already there
|
||||||
|
"model.model.": "model.",
|
||||||
|
# Classifier/scoring heads will be adjacent to `model`
|
||||||
|
"model.score": "classifier",
|
||||||
|
"model.classifier": "classifier",
|
||||||
|
},
|
||||||
|
orig_to_new_suffix={
|
||||||
|
# Replace legacy suffixes used for norms
|
||||||
|
".gamma": ".weight",
|
||||||
|
".beta": ".bias",
|
||||||
|
})
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
# Skip unsupported/unwanted output embeddings layers
|
||||||
|
self.skip_prefixes.extend([
|
||||||
|
"model.lm_head.", "model.predictions.", "model.qa_outputs.",
|
||||||
|
"model.embeddings_project.", "model.discriminator_predictions."
|
||||||
|
])
|
||||||
|
|
||||||
|
# Some encoder models have the position_ids buffer in the checkpoint.
|
||||||
|
# vLLM will always pass position_ids as an argument, so we skip loading
|
||||||
|
# the buffer if it exists
|
||||||
|
self.skip_substrs.append("position_ids")
|
||||||
|
|
||||||
|
# Some encoder models have the bias of the final classifier layer
|
||||||
|
# in the checkpoint. vLLM does not use this bias, so we skip loading
|
||||||
|
# it if it exists
|
||||||
|
self.skip_substrs.append("score.bias")
|
||||||
|
|
||||||
|
# roberta-like models an extra padding in positions.
|
||||||
|
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
|
||||||
|
# we should find a better way to handle this.
|
||||||
|
self.is_roberta = "roberta" in self.text_config.model_type
|
||||||
|
self.padding_idx = self.text_config.pad_token_id
|
||||||
|
|
||||||
|
def create_attention_instances(
|
||||||
|
self, attn_type: AttentionType = AttentionType.DECODER):
|
||||||
|
# TODO(hmellor): Better way to detect encoder models
|
||||||
|
# In encoder models, the attention layers will have `is_causal=False`
|
||||||
|
is_encoder = lambda m: not getattr(m, "is_causal", True)
|
||||||
|
# vLLM does not support encoder-decoder models, so if any encoder layer
|
||||||
|
# is found, we assume the whole model is an encoder model
|
||||||
|
if any(is_encoder(m) for m in self.model.modules()):
|
||||||
|
attn_type = AttentionType.ENCODER_ONLY
|
||||||
|
|
||||||
|
# Check minimum transformers version for encoder models support
|
||||||
|
if attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
import transformers
|
||||||
|
from packaging.version import Version
|
||||||
|
installed = Version(transformers.__version__)
|
||||||
|
required = Version("4.57.0.dev0")
|
||||||
|
if installed < required:
|
||||||
|
raise ValueError(
|
||||||
|
"Encoder models with the Transformers backend require "
|
||||||
|
f"transformers>={required}, but got {installed}")
|
||||||
|
|
||||||
|
return super().create_attention_instances(attn_type)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
if self.is_roberta:
|
||||||
|
# RoBERTa-specific positions padding
|
||||||
|
positions += self.padding_idx + 1
|
||||||
|
return super().forward(input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersEmbeddingModel(TransformersPoolingBase):
|
||||||
|
default_pooling_type = "CLS"
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
self.pooler = DispatchPooler({
|
||||||
|
"encode": Pooler.for_encode(pooler_config),
|
||||||
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersForSequenceClassification(TransformersPoolingBase):
|
||||||
|
default_pooling_type = "CLS"
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
# Certain information about the the model and classifier can only be
|
||||||
|
# inferred from the `ForSequenceClassification` class. Therefore, we
|
||||||
|
# instantiate it on the "meta" device to avoid allocating GPU memory.
|
||||||
|
with torch.device("meta"):
|
||||||
|
seq_cls_model = AutoModelForSequenceClassification.from_config(
|
||||||
|
self.config,
|
||||||
|
torch_dtype=self.model_config.dtype,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# When used for sequence classification, some models have their
|
||||||
|
# pooling layers removed. Make sure this is reflected in vLLM.
|
||||||
|
for module in seq_cls_model.modules():
|
||||||
|
if hasattr(module, "pooler") and module.pooler is None:
|
||||||
|
self.model.pooler = None
|
||||||
|
break
|
||||||
|
if self.model.pooler is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Sequence classification models with pooling layers are not "
|
||||||
|
"supported yet in the Transformers backend.")
|
||||||
|
|
||||||
|
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
|
||||||
|
self.classifier = seq_cls_model.classifier
|
||||||
|
self.init_parameters(self.classifier,
|
||||||
|
dtype=self.model_config.head_dtype)
|
||||||
|
|
||||||
|
class ClassifierWithReshape(self.classifier.__class__):
|
||||||
|
"""CLSPool has already been applied in `pooling`.
|
||||||
|
Add dim to match expected input shape of `classifier.forward`."""
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if len(args) > 0:
|
||||||
|
args = (args[0].unsqueeze(1), *args[1:])
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
self.classifier.__class__ = ClassifierWithReshape
|
||||||
|
|
||||||
|
self.pooler = DispatchPooler({
|
||||||
|
"encode":
|
||||||
|
Pooler.for_encode(pooler_config),
|
||||||
|
"classify":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=CLSPool(),
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
"score":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=CLSPool(),
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
})
|
||||||
Loading…
x
Reference in New Issue
Block a user