Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-12 19:15:35 +08:00)
Prevent the cross-encoder logic from being applied to classification tasks (#18838)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

parent 269d901734
commit 515b413ebf
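
In user-facing terms, a checkpoint served with task="classify" now gets a plain sigmoid/softmax over the classifier output, while task="score" (the cross-encoder path) keeps the sentence-transformers activation lookup. A hedged sketch of the two entry points; the model name is only an example:

    from vllm import LLM

    # task="score": cross-encoder scoring; activation comes from the HF config.
    llm = LLM(model="BAAI/bge-reranker-base", task="score")
    print(llm.score("what is vllm?", "vLLM is an inference engine."))

    # task="classify": sigmoid (num_labels == 1) or softmax otherwise.
    # llm = LLM(model="a-sequence-classification-model", task="classify")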
vllm/model_executor/layers/pooler.py

@@ -6,10 +6,9 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PretrainedConfig
 from typing_extensions import assert_never
 
-from vllm.config import PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
@@ -283,30 +282,37 @@ class Pooler(nn.Module):
         )
 
 
-class CrossEncodingPooler(nn.Module):
-    """A layer that pools specific information from hidden states.
+class ClassifierPooler(nn.Module):
+    """A pooling layer for classification tasks.
 
     This layer does the following:
-    1. Extracts specific tokens or aggregates data based on pooling method.
-    2. Normalizes output if specified.
-    3. Returns structured results as `PoolerOutput`.
-
-    Attributes:
-        pooling_type: The type of pooling to use.
-        normalize: Whether to normalize the pooled data.
+    1. Applies a classification layer to the hidden states.
+    2. Optionally applies a pooler layer.
+    3. Applies an activation function to the output. In the case of
+       classification models it is either sigmoid or softmax. In the
+       case of scoring models, the same behavior is configuration
+       dependent, as in the sentence-transformers library.
     """
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: ModelConfig,
         classifier: nn.Module,
         pooler: Optional[nn.Module] = None,
     ):
         super().__init__()
         self.classifier = classifier
         self.pooler = pooler
 
-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)
+        if config.task == "score":
+            self.default_activation_function = \
+                get_cross_encoder_activation_function(config.hf_config)
+        elif config.task == "classify":
+            self.default_activation_function = nn.Sigmoid() \
+                if config.hf_config.num_labels == 1 else nn.Softmax()
+        else:
+            raise NotImplementedError(f"task={config.task!r} is not supported"
+                                      " with the classification pooler")
 
     def forward(
         self,
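
The constructor now branches on the vLLM task instead of unconditionally applying the cross-encoder activation. A minimal sketch of the selection behavior, assuming only a task string and a label count (the real code reads both from ModelConfig):

    import torch
    import torch.nn as nn

    def select_activation(task: str, num_labels: int) -> nn.Module:
        # Mirrors the branch added above: classification heads get sigmoid
        # for a single label and softmax otherwise; "score" would defer to
        # get_cross_encoder_activation_function instead.
        if task == "classify":
            return nn.Sigmoid() if num_labels == 1 else nn.Softmax(dim=-1)
        raise NotImplementedError(f"task={task!r} is not supported")

    logits = torch.tensor([[2.0, -1.0, 0.5]])
    print(select_activation("classify", num_labels=3)(logits))  # sums to 1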
vllm/model_executor/models/bert.py

@@ -16,7 +16,7 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
+from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -470,8 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
                               embedding_class=BertEmbedding,
                               add_pooling_layer=True)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = CrossEncodingPooler(config, self.classifier,
-                                           self.bert.pooler)
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier, self.bert.pooler)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
 
vllm/model_executor/models/modernbert.py

@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -278,7 +278,8 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = CrossEncodingPooler(config, self.classifier,
-                                           ModernBertPooler(config))
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier,
+                                        ModernBertPooler(config))
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
vllm/model_executor/models/roberta.py

@@ -9,7 +9,7 @@ from torch import nn
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import CrossEncodingPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -186,7 +186,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                               embedding_class=RobertaEmbedding,
                               add_pooling_layer=False)
         self.classifier = RobertaClassificationHead(config)
-        self._pooler = CrossEncodingPooler(config, self.classifier)
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         bert_weights, task_weights = roberta_task_weights_filter(weights)
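
All three call sites migrate the same way: the pooler receives vllm_config.model_config, which knows the task, instead of the raw HF config. A sketch of the shared pattern with a hypothetical model class:

    import torch.nn as nn

    from vllm.model_executor.layers.pooler import ClassifierPooler

    class MyClassificationModel(nn.Module):  # hypothetical, for illustration
        def __init__(self, vllm_config, hidden_size: int, num_labels: int):
            super().__init__()
            self.classifier = nn.Linear(hidden_size, num_labels)
            # Pass the vLLM ModelConfig rather than the HF PretrainedConfig,
            # as in the diffs above.
            self._pooler = ClassifierPooler(vllm_config.model_config,
                                            self.classifier)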
vllm/transformers_utils/config.py

@@ -823,10 +823,17 @@ def try_get_generation_config(
 
 
 def get_cross_encoder_activation_function(config: PretrainedConfig):
-    if (hasattr(config, "sbert_ce_default_activation_function")
-            and config.sbert_ce_default_activation_function is not None):
+
+    function_name: Optional[str] = None
+    if hasattr(config, "sentence_transformers") and "activation_fn" in \
+            config.sentence_transformers:
+        function_name = config.sentence_transformers["activation_fn"]
+
+    elif (hasattr(config, "sbert_ce_default_activation_function")
+          and config.sbert_ce_default_activation_function is not None):
         function_name = config.sbert_ce_default_activation_function
 
+    if function_name is not None:
         assert function_name.startswith("torch.nn.modules."), \
             "Loading of activation functions is restricted to " \
             "torch.nn.modules for security reasons"
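
The assert keeps dynamically loaded activation names confined to torch.nn.modules. A self-contained sketch of how such a whitelisted name can be resolved; the helper below is illustrative, not vLLM's actual loader:

    import importlib

    def load_activation(function_name: str):
        # Same guard as in the diff: only torch.nn.modules is trusted.
        assert function_name.startswith("torch.nn.modules."), \
            "Loading of activation functions is restricted to " \
            "torch.nn.modules for security reasons"
        module_name, _, class_name = function_name.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)()

    print(load_activation("torch.nn.modules.activation.Sigmoid"))  # Sigmoid()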