Prevent the cross-encoder logic from being applied to classification tasks (#18838)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Maximilien de Bayser 2025-05-28 23:16:17 -03:00 committed by GitHub
parent 269d901734
commit 515b413ebf
5 changed files with 40 additions and 24 deletions


@@ -6,10 +6,9 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PretrainedConfig
 from typing_extensions import assert_never
 
-from vllm.config import PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
@@ -283,30 +282,37 @@ class Pooler(nn.Module):
         )
 
 
-class CrossEncodingPooler(nn.Module):
-    """A layer that pools specific information from hidden states.
+class ClassifierPooler(nn.Module):
+    """A pooling layer for classification tasks.
 
     This layer does the following:
-    1. Extracts specific tokens or aggregates data based on pooling method.
-    2. Normalizes output if specified.
-    3. Returns structured results as `PoolerOutput`.
-
-    Attributes:
-        pooling_type: The type of pooling to use.
-        normalize: Whether to normalize the pooled data.
+    1. Applies a classification layer to the hidden states.
+    2. Optionally applies a pooler layer.
+    3. Applies an activation function to the output. In the case of
+       classification models it is either sigmoid or softmax. In the
+       case of scoring models, the same behavior is configuration
+       dependent, as in the sentence-transformers library.
     """
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: ModelConfig,
         classifier: nn.Module,
         pooler: Optional[nn.Module] = None,
     ):
         super().__init__()
         self.classifier = classifier
         self.pooler = pooler
-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)
+        if config.task == "score":
+            self.default_activation_function = \
+                get_cross_encoder_activation_function(config.hf_config)
+        elif config.task == "classify":
+            self.default_activation_function = nn.Sigmoid() \
+                if config.hf_config.num_labels == 1 else nn.Softmax()
+        else:
+            raise NotImplementedError(f"task={config.task!r} is not supported"
+                                      " with the classification pooler")
 
     def forward(
         self,

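A standalone sketch (not part of this diff) of the activation choice introduced above: for task="classify" the pooler now defaults to sigmoid or softmax based on the number of labels, while task="score" defers to get_cross_encoder_activation_function(config.hf_config), shown in the last hunk of this commit. The helper name below is made up for illustration.

import torch.nn as nn

# Sketch only: mirrors the default activation ClassifierPooler picks for
# task="classify". One output logit -> sigmoid; several class logits -> softmax.
# (The diff calls nn.Softmax() with its default dim; dim=-1 is made explicit here.)
def default_classify_activation(num_labels: int) -> nn.Module:
    return nn.Sigmoid() if num_labels == 1 else nn.Softmax(dim=-1)

print(default_classify_activation(1))  # Sigmoid()
print(default_classify_activation(3))  # Softmax(dim=-1)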

@@ -16,7 +16,7 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
+from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -470,8 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
                               embedding_class=BertEmbedding,
                               add_pooling_layer=True)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = CrossEncodingPooler(config, self.classifier,
-                                           self.bert.pooler)
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier, self.bert.pooler)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):


@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -278,8 +278,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = CrossEncodingPooler(config, self.classifier,
-                                           ModernBertPooler(config))
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier,
+                                        ModernBertPooler(config))
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):


@@ -9,7 +9,7 @@ from torch import nn
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -186,7 +186,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                                  embedding_class=RobertaEmbedding,
                                  add_pooling_layer=False)
         self.classifier = RobertaClassificationHead(config)
-        self._pooler = CrossEncodingPooler(config, self.classifier)
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         bert_weights, task_weights = roberta_task_weights_filter(weights)


@@ -823,10 +823,17 @@ def try_get_generation_config(
 def get_cross_encoder_activation_function(config: PretrainedConfig):
-    if (hasattr(config, "sbert_ce_default_activation_function")
-            and config.sbert_ce_default_activation_function is not None):
+    function_name: Optional[str] = None
+    if hasattr(config, "sentence_transformers") and "activation_fn" in \
+            config.sentence_transformers:
+        function_name = config.sentence_transformers["activation_fn"]
+    elif (hasattr(config, "sbert_ce_default_activation_function")
+          and config.sbert_ce_default_activation_function is not None):
         function_name = config.sbert_ce_default_activation_function
+
+    if function_name is not None:
         assert function_name.startswith("torch.nn.modules."), \
             "Loading of activation functions is restricted to " \
             "torch.nn.modules for security reasons"