Prevent the cross-encoder logic from being applied to classification tasks (#18838)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
commit 515b413ebf
parent 269d901734
@@ -6,10 +6,9 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PretrainedConfig
 from typing_extensions import assert_never

-from vllm.config import PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
@@ -283,30 +282,37 @@ class Pooler(nn.Module):
         )


-class CrossEncodingPooler(nn.Module):
-    """A layer that pools specific information from hidden states.
+class ClassifierPooler(nn.Module):
+    """A pooling layer for classification tasks.

     This layer does the following:
-    1. Extracts specific tokens or aggregates data based on pooling method.
-    2. Normalizes output if specified.
-    3. Returns structured results as `PoolerOutput`.
-
-    Attributes:
-        pooling_type: The type of pooling to use.
-        normalize: Whether to normalize the pooled data.
+    1. Applies a classification layer to the hidden states.
+    2. Optionally applies a pooler layer.
+    3. Applies an activation function to the output. In the case of
+       classification models it is either sigmoid or softmax. In the
+       case of scoring models, the same behavior is configuration
+       dependent, as in the sentence-transformers library.
     """

     def __init__(
         self,
-        config: PretrainedConfig,
+        config: ModelConfig,
         classifier: nn.Module,
         pooler: Optional[nn.Module] = None,
     ):
         super().__init__()
         self.classifier = classifier
         self.pooler = pooler
-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)
+
+        if config.task == "score":
+            self.default_activation_function = \
+                get_cross_encoder_activation_function(config.hf_config)
+        elif config.task == "classify":
+            self.default_activation_function = nn.Sigmoid() \
+                if config.hf_config.num_labels == 1 else nn.Softmax()
+        else:
+            raise NotImplementedError(f"task={config.task!r} is not supported"
+                                      " with the classification pooler")

     def forward(
         self,
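The hunk above is the heart of the fix: the activation function is now selected from the configured task instead of always going through the cross-encoder lookup. A minimal sketch of that dispatch in isolation (StubHFConfig and StubModelConfig are hypothetical stand-ins, not vLLM classes):

    # Minimal sketch of the task-based activation dispatch introduced above.
    # StubHFConfig/StubModelConfig are hypothetical stand-ins, not vLLM types.
    import torch.nn as nn


    class StubHFConfig:
        num_labels = 1  # a single label means binary classification


    class StubModelConfig:
        task = "classify"
        hf_config = StubHFConfig()


    def select_activation(config) -> nn.Module:
        # "classify" bypasses the cross-encoder lookup entirely:
        # sigmoid for one label, softmax otherwise.
        if config.task == "classify":
            return (nn.Sigmoid()
                    if config.hf_config.num_labels == 1 else nn.Softmax(dim=-1))
        raise NotImplementedError(f"task={config.task!r} is not supported")


    print(select_activation(StubModelConfig()))  # -> Sigmoid()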
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
+from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -470,8 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
                               embedding_class=BertEmbedding,
                               add_pooling_layer=True)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = CrossEncodingPooler(config, self.classifier,
-                                           self.bert.pooler)
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier, self.bert.pooler)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
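With BertForSequenceClassification now handing vllm_config.model_config to the pooler, classification requests no longer inherit the cross-encoder activation. A hedged usage sketch, assuming a vLLM build that includes this change (the model name is illustrative, not taken from this commit):

    # Hedged usage sketch, not part of this commit. Assumes a vLLM build
    # with this change installed; the model name below is only an example.
    from vllm import LLM

    llm = LLM(model="textattack/bert-base-uncased-SST-2", task="classify")
    outputs = llm.classify(["vLLM makes serving transformer models fast."])
    for output in outputs:
        # With task="classify", probabilities come from sigmoid/softmax,
        # not from the cross-encoder default activation.
        print(output.outputs.probs)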
@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -278,8 +278,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = CrossEncodingPooler(config, self.classifier,
-                                           ModernBertPooler(config))
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier,
+                                        ModernBertPooler(config))

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -9,7 +9,7 @@ from torch import nn
 from transformers import RobertaConfig

 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -186,7 +186,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                               embedding_class=RobertaEmbedding,
                               add_pooling_layer=False)
         self.classifier = RobertaClassificationHead(config)
-        self._pooler = CrossEncodingPooler(config, self.classifier)
+
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         bert_weights, task_weights = roberta_task_weights_filter(weights)
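RobertaForSequenceClassification gets the same wiring. For comparison, scoring stays on the cross-encoder path and still resolves its activation from the sentence-transformers configuration; a hedged sketch (model name illustrative, not from this commit):

    # Hedged sketch, not part of this commit: with task="score" the
    # sentence-transformers-style default activation still applies.
    from vllm import LLM

    llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")
    outputs = llm.score("What is the capital of France?",
                        ["Paris is the capital of France.",
                         "The Eiffel Tower is in Paris."])
    for output in outputs:
        print(output.outputs.score)  # relevance score per query/passage pair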
@@ -823,10 +823,17 @@ def try_get_generation_config(


 def get_cross_encoder_activation_function(config: PretrainedConfig):
-    if (hasattr(config, "sbert_ce_default_activation_function")
-            and config.sbert_ce_default_activation_function is not None):
-
+    function_name: Optional[str] = None
+    if hasattr(config, "sentence_transformers") and "activation_fn" in \
+            config.sentence_transformers:
+        function_name = config.sentence_transformers["activation_fn"]
+    elif (hasattr(config, "sbert_ce_default_activation_function")
+          and config.sbert_ce_default_activation_function is not None):
+        function_name = config.sbert_ce_default_activation_function
+
+    if function_name is not None:
         assert function_name.startswith("torch.nn.modules."), \
             "Loading of activation functions is restricted to " \
             "torch.nn.modules for security reasons"
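The assert above guards the dynamic loading step that follows in the full file (the hunk is cut off here): the resolved class must live under torch.nn.modules. A self-contained sketch of that pattern, with a hypothetical load_activation helper standing in for vLLM's internal resolver:

    # Self-contained sketch of the restricted loading pattern guarded by the
    # assert above. load_activation is a hypothetical helper, not vLLM's code.
    import importlib

    import torch.nn as nn


    def load_activation(function_name: str) -> nn.Module:
        # Mirror the guard: only classes under torch.nn.modules may be loaded.
        assert function_name.startswith("torch.nn.modules."), \
            "Loading of activation functions is restricted to " \
            "torch.nn.modules for security reasons"
        module_name, _, class_name = function_name.rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        return cls()


    print(load_activation("torch.nn.modules.activation.Sigmoid"))  # Sigmoid()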