Prevent the cross-encoder logic from being applied to classification tasks (#18838)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Maximilien de Bayser 2025-05-28 23:16:17 -03:00 committed by GitHub
parent 269d901734
commit 515b413ebf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 24 deletions

View File

@ -6,10 +6,9 @@ from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import PoolerConfig from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata, from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors) PoolingTensors)
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
@ -283,30 +282,37 @@ class Pooler(nn.Module):
) )
class CrossEncodingPooler(nn.Module): class ClassifierPooler(nn.Module):
"""A layer that pools specific information from hidden states. """A pooling layer for classification tasks.
This layer does the following: This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method. 1. Applies a classification layer to the hidden states.
2. Normalizes output if specified. 2. Optionally applies a pooler layer.
3. Returns structured results as `PoolerOutput`. 3. Applies an activation function to the output. In the case of
classification models it is either sigmoid or softmax. In the
Attributes: case of scoring models, the same behavior is configuration
pooling_type: The type of pooling to use. dependent, as in the sentence-transformers library.
normalize: Whether to normalize the pooled data.
""" """
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: ModelConfig,
classifier: nn.Module, classifier: nn.Module,
pooler: Optional[nn.Module] = None, pooler: Optional[nn.Module] = None,
): ):
super().__init__() super().__init__()
self.classifier = classifier self.classifier = classifier
self.pooler = pooler self.pooler = pooler
if config.task == "score":
self.default_activation_function = \ self.default_activation_function = \
get_cross_encoder_activation_function(config) get_cross_encoder_activation_function(config.hf_config)
elif config.task == "classify":
self.default_activation_function = nn.Sigmoid() \
if config.hf_config.num_labels == 1 else nn.Softmax()
else:
raise NotImplementedError(f"task={config.task!r} is not supported"
" with the classification pooler")
def forward( def forward(
self, self,

View File

@ -16,7 +16,7 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingType) PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -470,8 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
embedding_class=BertEmbedding, embedding_class=BertEmbedding,
add_pooling_layer=True) add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = CrossEncodingPooler(config, self.classifier, self._pooler = ClassifierPooler(vllm_config.model_config,
self.bert.pooler) self.classifier, self.bert.pooler)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@ -12,7 +12,7 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -278,7 +278,8 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.model = ModernBertModel(vllm_config=vllm_config, self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert")) prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = CrossEncodingPooler(config, self.classifier, self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier,
ModernBertPooler(config)) ModernBertPooler(config))
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@ -9,7 +9,7 @@ from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -186,7 +186,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
embedding_class=RobertaEmbedding, embedding_class=RobertaEmbedding,
add_pooling_layer=False) add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config) self.classifier = RobertaClassificationHead(config)
self._pooler = CrossEncodingPooler(config, self.classifier)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights) bert_weights, task_weights = roberta_task_weights_filter(weights)

View File

@ -823,10 +823,17 @@ def try_get_generation_config(
def get_cross_encoder_activation_function(config: PretrainedConfig): def get_cross_encoder_activation_function(config: PretrainedConfig):
if (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None):
function_name: Optional[str] = None
if hasattr(config, "sentence_transformers") and "activation_fn" in \
config.sentence_transformers:
function_name = config.sentence_transformers["activation_fn"]
elif (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None):
function_name = config.sbert_ce_default_activation_function function_name = config.sbert_ce_default_activation_function
if function_name is not None:
assert function_name.startswith("torch.nn.modules."), \ assert function_name.startswith("torch.nn.modules."), \
"Loading of activation functions is restricted to " \ "Loading of activation functions is restricted to " \
"torch.nn.modules for security reasons" "torch.nn.modules for security reasons"