diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 6abbc90819a8..d2c42191bb3f 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -6,10 +6,9 @@ from typing import Optional, Union import torch import torch.nn as nn import torch.nn.functional as F -from transformers import PretrainedConfig from typing_extensions import assert_never -from vllm.config import PoolerConfig +from vllm.config import ModelConfig, PoolerConfig from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput @@ -283,30 +282,37 @@ class Pooler(nn.Module): ) -class CrossEncodingPooler(nn.Module): - """A layer that pools specific information from hidden states. +class ClassifierPooler(nn.Module): + """A pooling layer for classification tasks. This layer does the following: - 1. Extracts specific tokens or aggregates data based on pooling method. - 2. Normalizes output if specified. - 3. Returns structured results as `PoolerOutput`. - - Attributes: - pooling_type: The type of pooling to use. - normalize: Whether to normalize the pooled data. + 1. Applies a classification layer to the hidden states. + 2. Optionally applies a pooler layer. + 3. Applies an activation function to the output. In the case of + classification models it is either sigmoid or softmax. In the + case of scoring models, the same behavior is configuration + dependent, as in the sentence-transformers library. """ def __init__( self, - config: PretrainedConfig, + config: ModelConfig, classifier: nn.Module, pooler: Optional[nn.Module] = None, ): super().__init__() self.classifier = classifier self.pooler = pooler - self.default_activation_function = \ - get_cross_encoder_activation_function(config) + + if config.task == "score": + self.default_activation_function = \ + get_cross_encoder_activation_function(config.hf_config) + elif config.task == "classify": + self.default_activation_function = nn.Sigmoid() \ + if config.hf_config.num_labels == 1 else nn.Softmax() + else: + raise NotImplementedError(f"task={config.task!r} is not supported" + " with the classification pooler") def forward( self, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 0c6593bbe3a1..0b1d0f103408 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, +from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -470,8 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, embedding_class=BertEmbedding, add_pooling_layer=True) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = CrossEncodingPooler(config, self.classifier, - self.bert.pooler) + self._pooler = ClassifierPooler(vllm_config.model_config, + self.classifier, self.bert.pooler) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 86552aa05bf9..18eab6051736 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -12,7 +12,7 @@ from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import CrossEncodingPooler +from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -278,8 +278,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = CrossEncodingPooler(config, self.classifier, - ModernBertPooler(config)) + self._pooler = ClassifierPooler(vllm_config.model_config, + self.classifier, + ModernBertPooler(config)) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 9a4d0ab2dd4d..76008b72941d 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,7 +9,7 @@ from torch import nn from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import CrossEncodingPooler +from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -186,7 +186,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, embedding_class=RobertaEmbedding, add_pooling_layer=False) self.classifier = RobertaClassificationHead(config) - self._pooler = CrossEncodingPooler(config, self.classifier) + + self._pooler = ClassifierPooler(vllm_config.model_config, + self.classifier) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): bert_weights, task_weights = roberta_task_weights_filter(weights) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 69e7207cc350..2ed71a4d334b 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -823,10 +823,17 @@ def try_get_generation_config( def get_cross_encoder_activation_function(config: PretrainedConfig): - if (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): + function_name: Optional[str] = None + if hasattr(config, "sentence_transformers") and "activation_fn" in \ + config.sentence_transformers: + function_name = config.sentence_transformers["activation_fn"] + + elif (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): function_name = config.sbert_ce_default_activation_function + + if function_name is not None: assert function_name.startswith("torch.nn.modules."), \ "Loading of activation functions is restricted to " \ "torch.nn.modules for security reasons"