diff --git a/tests/models/registry.py b/tests/models/registry.py
index b7a2514d8bc0c..1068f97cb5a8c 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -657,7 +657,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
+    "TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"),  # noqa: E501
     "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
     "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
 }
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index e4b5e7c244539..733ac8de67a30 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -9,9 +9,16 @@ from vllm.platforms import current_platform
 
 from ..conftest import HfRunner, VllmRunner
 from ..utils import multi_gpu_test, prep_prompts
+from .registry import HF_EXAMPLE_MODELS
 from .utils import check_embeddings_close, check_logprobs_close
 
 
+def get_model(arch: str) -> str:
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(arch)
+    model_info.check_transformers_version(on_fail="skip")
+    return model_info.default
+
+
 def check_implementation(
     runner_ref: type[Union[HfRunner, VllmRunner]],
     runner_test: type[VllmRunner],
@@ -170,71 +177,47 @@ def test_embed_loading(vllm_runner, model):
 
 
 @pytest.mark.parametrize(
-    "model",
-    [
-        # Encoder model
-        "BAAI/bge-base-en-v1.5",
-    ])
-def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
-    import transformers
-    from packaging.version import Version
-    installed = Version(transformers.__version__)
-    required = Version("4.57.0.dev0")
-    if installed < required:
-        pytest.skip("Encoder models with the Transformers backend require "
-                    f"transformers>={required}, but got {installed}")
+    "arch",
+    ["TransformersEmbeddingModel", "TransformersForSequenceClassification"])
+def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
+    model = get_model(arch)
 
-    with vllm_runner(model, max_model_len=512,
-                     model_impl="transformers") as vllm_model:
+    vllm_kwargs = dict(
+        max_model_len=None,
+        model_impl="transformers",
+        compilation_config=dict(cudagraph_capture_sizes=[8]),
+    )
+
+    hf_kwargs = dict()
+    if arch == "TransformersEmbeddingModel":
+        hf_kwargs["is_sentence_transformer"] = True
+    elif arch == "TransformersForSequenceClassification":
+        from transformers import AutoModelForSequenceClassification
+        hf_kwargs["auto_cls"] = AutoModelForSequenceClassification
+
+    # The example_prompts end with "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers will strip the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids different between hf_model and vllm_model.
+    # So we need to strip the input texts to avoid the test failing.
+    example_prompts = [str(s).strip() for s in example_prompts]
+
+    with (vllm_runner(model, **vllm_kwargs) as
+          vllm_model, hf_runner(model, **hf_kwargs) as hf_model):
         model_config = vllm_model.llm.llm_engine.model_config
         assert model_config.using_transformers_backend()
 
-        vllm_outputs = vllm_model.embed(example_prompts)
-
-    with hf_runner(model, is_sentence_transformer=True) as hf_model:
-        hf_outputs = hf_model.encode(example_prompts)
+        if arch == "TransformersEmbeddingModel":
+            vllm_outputs = vllm_model.embed(example_prompts)
+            hf_outputs = hf_model.encode(example_prompts)
+        elif arch == "TransformersForSequenceClassification":
+            vllm_outputs = vllm_model.classify(example_prompts)
+            hf_outputs = hf_model.classify(example_prompts)
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
         embeddings_1_lst=vllm_outputs,
         name_0="hf",
         name_1="vllm",
-        tol=1e-2,
     )
-
-
-@pytest.mark.parametrize(
-    "model",
-    ["jason9693/Qwen2.5-1.5B-apeach"],
-)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_classify(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-) -> None:
-    import torch
-    from transformers import AutoModelForSequenceClassification
-
-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     model_impl="transformers") as vllm_model:
-        model_config = vllm_model.llm.llm_engine.model_config
-        assert model_config.using_transformers_backend()
-
-        vllm_outputs = vllm_model.classify(example_prompts)
-
-    with hf_runner(model,
-                   dtype=dtype,
-                   auto_cls=AutoModelForSequenceClassification) as hf_model:
-        hf_outputs = hf_model.classify(example_prompts)
-
-    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
-        hf_output = torch.tensor(hf_output)
-        vllm_output = torch.tensor(vllm_output)
-
-        assert torch.allclose(hf_output, vllm_output,
-                              1e-3 if dtype == "float" else 1e-2)
diff --git a/vllm/config/model.py b/vllm/config/model.py
index c1392318dd8ec..e9d5b58ff2c2b 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -19,6 +19,7 @@ import vllm.envs as envs
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                     MultiModalConfig)
 from vllm.config.pooler import PoolerConfig
+from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -40,7 +41,6 @@ if TYPE_CHECKING:
     import vllm.model_executor.models as me_models
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
-    from vllm.config.scheduler import RunnerType
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.v1.sample.logits_processor import LogitsProcessor
 else:
@@ -52,13 +52,12 @@ else:
                            "vllm.model_executor.models")
     LoadConfig = Any
     ParallelConfig = Any
-    RunnerType = Any
     QuantizationMethods = Any
     LogitsProcessor = Any
 
 logger = init_logger(__name__)
 
-RunnerOption = Literal["auto", "generate", "pooling", "draft"]
+RunnerOption = Literal["auto", RunnerType]
 ConvertType = Literal["none", "embed", "classify", "reward"]
 ConvertOption = Literal["auto", ConvertType]
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
@@ -668,8 +667,28 @@ class ModelConfig:
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
-        if getattr(self, "runner_type", self.runner) == "pooling":
-            return "TransformersModel"
+        # Check if the architecture we're wrapping has defaults
+        runner = None
+        convert = None
+        if defaults := try_match_architecture_defaults(self.architectures[0]):
+            _, (runner, convert) = defaults
+        # Overwrite with user-specified values
+        if self.runner != "auto":
+            runner = self.runner
+        if self.convert not in {"auto", "none"}:
+            convert = self.convert
+        # Fall back to default values if still not set
+        if runner is None:
+            runner = "generate"
+        if convert in {None, "none"}:
+            convert = "embed"
+        # Resolve Transformers backend pooling classes
+        if runner == "pooling":
+            if convert == "embed":
+                return "TransformersEmbeddingModel"
+            if convert == "classify":
+                return "TransformersForSequenceClassification"
+        # Resolve Transformers backend generate classes
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
@@ -678,7 +697,9 @@ class ModelConfig:
 
     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""
-        return self.architecture == self._get_transformers_backend_cls()
+        used_cls = self._model_info.architecture
+        transformers_backend_cls = self._get_transformers_backend_cls()
+        return used_cls == transformers_backend_cls
 
     @property
     def registry(self):
diff --git a/vllm/config/utils.py b/vllm/config/utils.py
index 2da30cbf149c2..d355ff3a9023d 100644
--- a/vllm/config/utils.py
+++ b/vllm/config/utils.py
@@ -4,6 +4,7 @@
 import ast
 import inspect
 import textwrap
+from collections.abc import Iterable
 from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
 from typing import TYPE_CHECKING, Any, Protocol, TypeVar
 
@@ -52,6 +53,18 @@ def get_field(cls: ConfigType, name: str) -> Field:
         f"{cls.__name__}.{name} must have a default value or default factory.")
 
 
+def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
+    """
+    A helper function that retrieves an attribute from an object which may
+    have multiple possible names. This is useful when fetching attributes from
+    arbitrary `transformers.PretrainedConfig` instances.
+    """
+    for name in names:
+        if hasattr(object, name):
+            return getattr(object, name)
+    return default
+
+
 def contains_object_print(text: str) -> bool:
     """
     Check if the text looks like a printed Python object, e.g.
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 732181265a97c..eb572dc30810a 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -307,9 +307,10 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersModel": ("transformers", "TransformersModel"),
+    "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"),  # noqa: E501
+    "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"),  # noqa: E501
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
-    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
+    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
 }
 # yapf: enable
 
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 00d87f560e70a..d168398aa1826 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, VllmConfig)
+from vllm.config.utils import getattr_iter
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
@@ -486,10 +487,13 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
 
         # Input embeddings
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
+            names = ("embedding_size", "hidden_size")
+            embedding_dim = getattr_iter(self.text_config, names, None)
+            assert embedding_dim is not None
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
                     self.text_config.vocab_size,
-                    self.text_config.hidden_size,
+                    embedding_dim=embedding_dim,
                     org_num_embeddings=self.text_config.vocab_size,
                     quant_config=self.quant_config,
                 ))
@@ -645,7 +649,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                     attn_type=attn_type)
         return attention_instances
 
-    def init_parameters(self, module: nn.Module):
+    def init_parameters(self,
+                        module: nn.Module,
+                        dtype: Optional[torch.dtype] = None):
         """
         If a `parameter` is on the `meta` device, then its parent
         `module` is the original module created by:
@@ -659,11 +665,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             if param.device == torch.device("meta"):
                 new_param = nn.Parameter(
                     torch.empty_like(param.data,
-                                     dtype=self.model_config.dtype,
+                                     dtype=dtype or self.model_config.dtype,
                                      device=self.device_config.device))
                 setattr(module, name, new_param)
         for child in module.children():
-            self.init_parameters(child)
+            self.init_parameters(child, dtype)
 
     def forward(
         self,
@@ -712,73 +718,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
-@support_torch_compile(enable_if=can_enable_torch_compile)
-class TransformersModel(TransformersBase):
-    hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            # Handle BERT-like models
-            "bert": "model",
-            # Add `model.` prefix for base model checkpoints
-            "": "model.",
-            # Remove `model.` prefix if it was already there
-            "model.model.": "model.",
-            # Pooling adapters will be adjacent to `model`
-            "model.pooler": "pooler",
"pooler", - "model.score": "score", - # Classifier adapter's classifier layer is renamed to score - "model.classifier": "score", - }, - orig_to_new_suffix={ - # Replace legacy suffixes used for norms - ".gamma": ".weight", - ".beta": ".bias", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - # After creating a pooling model, `pooler` will be duplicated. - # The one inside `model` comes from the Transformers modelling code. - # The one after `model` is an adapter from vLLM. - # We want to use the adapter so we nullify the original pooler. - if getattr(self.model, "pooler", None) is not None: - self.skip_prefixes.append("pooler.") - self.model.pooler = torch.nn.Identity() - - # Some encoder models have the position_ids buffer in the checkpoint. - # vLLM will always pass position_ids as an argument, so we skip loading - # the buffer if it exists - self.skip_substrs.append("position_ids") - - # Some encoder models have the bias of the final classifier layer - # in the checkpoint. vLLM does not use this bias, so we skip loading - # it if it exists - self.skip_substrs.append("score.bias") - - def create_attention_instances( - self, attn_type: AttentionType = AttentionType.DECODER): - # TODO(hmellor): Better way to detect encoder models - # In encoder models, the attention layers will have `is_causal=False` - is_encoder = lambda m: not getattr(m, "is_causal", True) - # vLLM does not support encoder-decoder models, so if any encoder layer - # is found, we assume the whole model is an encoder model - if any(is_encoder(m) for m in self.model.modules()): - attn_type = AttentionType.ENCODER_ONLY - - # Check minimum transformers version for encoder models support - if attn_type == AttentionType.ENCODER_ONLY: - import transformers - from packaging.version import Version - installed = Version(transformers.__version__) - required = Version("4.57.0.dev0") - if installed < required: - raise ValueError( - "Encoder models with the Transformers backend require " - f"transformers>={required}, but got {installed}") - - return super().create_attention_instances(attn_type) - - @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py new file mode 100644 index 0000000000000..7e262ade156a3 --- /dev/null +++ b/vllm/model_executor/models/transformers_pooling.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Wrapper around `transformers` models for pooling tasks.""" +from typing import Optional, Union + +import torch +from transformers import AutoModelForSequenceClassification + +from vllm.attention import AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, + DispatchPooler, Pooler) +from vllm.sequence import IntermediateTensors + +from .interfaces_base import VllmModelForPooling +from .transformers import TransformersBase, can_enable_torch_compile +from .utils import WeightsMapper + + +class TransformersPoolingBase(TransformersBase, VllmModelForPooling): + hf_to_vllm_mapper = WeightsMapper( + # These are applied in order, so the order matters! + orig_to_new_prefix={ + # Handle BERT-like models + "roberta": "model", + "bert": "model", + # Add `model.` prefix for base model checkpoints + "": "model.", + # Remove `model.` prefix if it was already there + "model.model.": "model.", + # Classifier/scoring heads will be adjacent to `model` + "model.score": "classifier", + "model.classifier": "classifier", + }, + orig_to_new_suffix={ + # Replace legacy suffixes used for norms + ".gamma": ".weight", + ".beta": ".bias", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Skip unsupported/unwanted output embeddings layers + self.skip_prefixes.extend([ + "model.lm_head.", "model.predictions.", "model.qa_outputs.", + "model.embeddings_project.", "model.discriminator_predictions." + ]) + + # Some encoder models have the position_ids buffer in the checkpoint. + # vLLM will always pass position_ids as an argument, so we skip loading + # the buffer if it exists + self.skip_substrs.append("position_ids") + + # Some encoder models have the bias of the final classifier layer + # in the checkpoint. vLLM does not use this bias, so we skip loading + # it if it exists + self.skip_substrs.append("score.bias") + + # roberta-like models an extra padding in positions. + # FIXME(Isotr0py): This is quite hacky for roberta edge case, + # we should find a better way to handle this. 
+        self.is_roberta = "roberta" in self.text_config.model_type
+        self.padding_idx = self.text_config.pad_token_id
+
+    def create_attention_instances(
+            self, attn_type: AttentionType = AttentionType.DECODER):
+        # TODO(hmellor): Better way to detect encoder models
+        # In encoder models, the attention layers will have `is_causal=False`
+        is_encoder = lambda m: not getattr(m, "is_causal", True)
+        # vLLM does not support encoder-decoder models, so if any encoder layer
+        # is found, we assume the whole model is an encoder model
+        if any(is_encoder(m) for m in self.model.modules()):
+            attn_type = AttentionType.ENCODER_ONLY
+
+        # Check minimum transformers version for encoder models support
+        if attn_type == AttentionType.ENCODER_ONLY:
+            import transformers
+            from packaging.version import Version
+            installed = Version(transformers.__version__)
+            required = Version("4.57.0.dev0")
+            if installed < required:
+                raise ValueError(
+                    "Encoder models with the Transformers backend require "
+                    f"transformers>={required}, but got {installed}")
+
+        return super().create_attention_instances(attn_type)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if self.is_roberta:
+            # RoBERTa-specific positions padding
+            positions += self.padding_idx + 1
+        return super().forward(input_ids=input_ids,
+                               positions=positions,
+                               intermediate_tensors=intermediate_tensors,
+                               inputs_embeds=inputs_embeds)
+
+
+@support_torch_compile(enable_if=can_enable_torch_compile)
+class TransformersEmbeddingModel(TransformersPoolingBase):
+    default_pooling_type = "CLS"
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode": Pooler.for_encode(pooler_config),
+            "embed": Pooler.for_embed(pooler_config),
+        })
+
+
+@support_torch_compile(enable_if=can_enable_torch_compile)
+class TransformersForSequenceClassification(TransformersPoolingBase):
+    default_pooling_type = "CLS"
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        # Certain information about the model and classifier can only be
+        # inferred from the `ForSequenceClassification` class. Therefore, we
+        # instantiate it on the "meta" device to avoid allocating GPU memory.
+        with torch.device("meta"):
+            seq_cls_model = AutoModelForSequenceClassification.from_config(
+                self.config,
+                torch_dtype=self.model_config.dtype,
+                trust_remote_code=self.model_config.trust_remote_code,
+            )
+
+        # When used for sequence classification, some models have their
+        # pooling layers removed. Make sure this is reflected in vLLM.
+        for module in seq_cls_model.modules():
+            if hasattr(module, "pooler") and module.pooler is None:
+                self.model.pooler = None
+                break
+        if self.model.pooler is not None:
+            raise ValueError(
+                "Sequence classification models with pooling layers are not "
+                "supported yet in the Transformers backend.")
+
+        # Unlike `lm_head`, `classifier` is not always `nn.Linear`.
+        self.classifier = seq_cls_model.classifier
+        self.init_parameters(self.classifier,
+                             dtype=self.model_config.head_dtype)
+
+        class ClassifierWithReshape(self.classifier.__class__):
+            """CLSPool has already been applied in `pooling`.
+            Add dim to match expected input shape of `classifier.forward`."""
+
+            def forward(self, *args, **kwargs):
+                if len(args) > 0:
+                    args = (args[0].unsqueeze(1), *args[1:])
+                return super().forward(*args, **kwargs)
+
+        self.classifier.__class__ = ClassifierWithReshape
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "classify":
+            ClassifierPooler(
+                pooling=CLSPool(),
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_seq_cls(
+                    vllm_config.model_config),
+            ),
+            "score":
+            ClassifierPooler(
+                pooling=CLSPool(),
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
+                    vllm_config.model_config),
+            ),
+        })
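
Beyond the updated unit tests, a quick way to exercise the two new classes end-to-end is through vLLM's offline LLM API. The sketch below is illustrative only and is not part of the patch: it reuses the example models from tests/models/registry.py, and the explicit runner/convert arguments may be redundant when try_match_architecture_defaults already resolves them from the checkpoint's architecture.

# Rough usage sketch (assumption: LLM forwards runner/convert/model_impl to
# EngineArgs as in the tests above; not a verified invocation).
from vllm import LLM

# Embedding path: _get_transformers_backend_cls() is expected to resolve to
# TransformersEmbeddingModel (runner="pooling", convert="embed").
embedder = LLM(model="BAAI/bge-base-en-v1.5",
               runner="pooling",
               model_impl="transformers")
embedding_outputs = embedder.embed(["A sentence to embed."])

# Classification path: expected to resolve to
# TransformersForSequenceClassification (runner="pooling", convert="classify").
classifier = LLM(model="papluca/xlm-roberta-base-language-detection",
                 runner="pooling",
                 convert="classify",
                 model_impl="transformers")
classification_outputs = classifier.classify(
    ["vLLM is a fast inference engine."])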