[Model] Consolidate pooler implementations (#20927)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-07-16 21:39:13 +08:00, committed by GitHub
parent 260127ea54
commit 1c3198b6c4
9 changed files with 558 additions and 372 deletions


@ -1,22 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union
from typing import Callable, Optional, TypeVar, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import assert_never
from transformers import PretrainedConfig
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
get_classification_activation_function,
get_cross_encoder_activation_function)
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
@ -31,140 +30,202 @@ class PoolingType(IntEnum):
MEAN = 4
class SimplePooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.

    Attributes:
        pooling_type: The type of pooling to use.
        normalize: Whether to normalize the pooled data.
    """

    @staticmethod
    def from_pooling_type(
        pooling_type: PoolingType,
        *,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "SimplePooler":
        if pooling_type == PoolingType.LAST:
            assert step_tag_id is None and returned_token_ids is None
            return LastPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.ALL:
            assert step_tag_id is None and returned_token_ids is None
            return AllPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.CLS:
            assert step_tag_id is None and returned_token_ids is None
            return CLSPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.MEAN:
            assert step_tag_id is None and returned_token_ids is None
            return MeanPool(normalize=normalize, softmax=softmax)
        if pooling_type == PoolingType.STEP:
            return StepPool(normalize=normalize,
                            softmax=softmax,
                            step_tag_id=step_tag_id,
                            returned_token_ids=returned_token_ids)

        assert_never(pooling_type)


@dataclass(frozen=True)
class ResolvedPoolingConfig:
    pooling_type: PoolingType
    normalize: bool
    softmax: bool
    step_tag_id: Optional[int]
    returned_token_ids: Optional[list[int]]

    @classmethod
    def from_config_with_defaults(
        cls,
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        *,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "ResolvedPoolingConfig":
        return cls(
            pooling_type=PoolingType[pooler_config.pooling_type]
            if pooler_config.pooling_type is not None else pooling_type,
            normalize=pooler_config.normalize
            if pooler_config.normalize is not None else normalize,
            softmax=pooler_config.softmax
            if pooler_config.softmax is not None else softmax,
            step_tag_id=pooler_config.step_tag_id
            if pooler_config.step_tag_id is not None else step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids
            if pooler_config.returned_token_ids is not None else
            returned_token_ids,
        )
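The resolution rule above is worth spelling out: every field of ResolvedPoolingConfig takes the user-supplied PoolerConfig value when present and falls back to the model's default otherwise. A minimal standalone sketch of that rule (FakePoolerConfig is a hypothetical stand-in for vllm.config.PoolerConfig):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakePoolerConfig:  # stand-in for vllm.config.PoolerConfig
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None

def resolve(cfg: FakePoolerConfig, *, pooling_type: str, normalize: bool):
    # a user override wins; otherwise keep the model's default
    return (
        cfg.pooling_type if cfg.pooling_type is not None else pooling_type,
        cfg.normalize if cfg.normalize is not None else normalize,
    )

print(resolve(FakePoolerConfig(), pooling_type="LAST", normalize=False))
# ('LAST', False) -- model defaults kept
print(resolve(FakePoolerConfig(normalize=True), pooling_type="LAST", normalize=False))
# ('LAST', True) -- only the overridden field changes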
def get_prompt_lens(
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    if isinstance(pooling_metadata, V1PoolingMetadata):
        return pooling_metadata.prompt_lens

    assert isinstance(hidden_states, torch.Tensor)
    return PoolingTensors.from_pooling_metadata(
        pooling_metadata, hidden_states.device).prompt_lens

    def __init__(self, *, normalize: bool, softmax: bool) -> None:
        super().__init__()

        self.head = PoolerHead(normalize=normalize, softmax=softmax)
def get_prompt_lens(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
def extract_states(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
raise NotImplementedError
    def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
        return PoolingSequenceGroupOutput(data)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        pooled_outputs = [self.build_output(data) for data in pooled_data]
        return PoolerOutput(outputs=pooled_outputs)


def get_classification_activation_function(config: PretrainedConfig):
    return PoolerClassify()


def get_cross_encoder_activation_function(config: PretrainedConfig):
    function_name: Optional[str] = None
    if (hasattr(config, "sentence_transformers")
            and "activation_fn" in config.sentence_transformers):
        function_name = config.sentence_transformers["activation_fn"]
    elif (hasattr(config, "sbert_ce_default_activation_function")
          and config.sbert_ce_default_activation_function is not None):
        function_name = config.sbert_ce_default_activation_function

    if function_name is not None:
        assert function_name.startswith("torch.nn.modules."), (
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")
        fn = resolve_obj_by_qualname(function_name)()
        return PoolerActivation.wraps(fn)

    return PoolerScore()


def build_output(all_data: torch.Tensor) -> PoolerOutput:
    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
    return PoolerOutput(outputs=all_outputs)


class BasePooler(nn.Module):

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        raise NotImplementedError
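The torch.nn.modules. prefix check in get_cross_encoder_activation_function above is a security guard: an arbitrary qualified name from a checkpoint config could otherwise import and instantiate any callable. A standalone sketch of the allowed path, using plain importlib in place of vLLM's resolve_obj_by_qualname (whose exact behavior is assumed here):

import importlib

import torch.nn as nn

function_name = "torch.nn.modules.activation.Tanh"  # e.g. from an sbert config
assert function_name.startswith("torch.nn.modules."), (
    "Loading of activation functions is restricted to torch.nn.modules")

module_path, _, attr = function_name.rpartition(".")
fn = getattr(importlib.import_module(module_path), attr)()  # an nn.Tanh()
print(isinstance(fn, nn.Tanh))  # True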
class CLSPool(SimplePooler):

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            result = []
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with CLS pooling"
                result.append(req_state[0])
            return result

        first_token_flat_indices = torch.zeros_like(prompt_lens)
        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
        return hidden_states[first_token_flat_indices]


class PoolingMethod(nn.Module, ABC):

    @staticmethod
    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
        if pooling_type == PoolingType.LAST:
            return LastPool()
        if pooling_type == PoolingType.ALL:
            return AllPool()
        if pooling_type == PoolingType.CLS:
            return CLSPool()
        if pooling_type == PoolingType.MEAN:
            return MeanPool()

        raise NotImplementedError(f"Unsupported method: {pooling_type}")

    @abstractmethod
    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Note:
            `prompt_len=None` means `prompt_len=len(hidden_states)`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            return [
                self.forward_one(h, prompt_len)
                for h, prompt_len in zip(hidden_states, prompt_lens)
            ]

        return self.forward_all(hidden_states, prompt_lens)
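The forward_one/forward_all split above is the core of the consolidation: forward_all handles the common case of one flattened [total_tokens, hidden] tensor in a single vectorized pass, while forward_one handles a list of per-request tensors. A standalone sketch of the same dispatch for last-token pooling (plain torch, no vLLM imports):

import torch

def last_pool(hidden_states, prompt_lens):
    if isinstance(hidden_states, list):
        # forward_one path: one tensor per request
        return [h[-1] for h in hidden_states]
    # forward_all path: one flattened tensor for the whole batch
    last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
    return hidden_states[last_token_flat_indices]

flat = torch.arange(5, dtype=torch.float32).unsqueeze(-1)  # 5 tokens, hidden=1
lens = torch.tensor([2, 3])
print(last_pool(flat, lens))                  # rows 1 and 4
print(last_pool([flat[:2], flat[2:]], lens))  # same result via the list path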
class CLSPool(PoolingMethod):
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
"partial prefill not supported with CLS pooling"
return hidden_states[0]
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
) -> Union[list[torch.Tensor], torch.Tensor]:
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
return hidden_states[first_token_flat_indices]
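The index arithmetic in forward_all above can be checked by hand: the first token of request i sits at the cumulative length of all earlier requests. A small numeric sketch:

import torch

prompt_lens = torch.tensor([3, 2, 4])  # 9 tokens flattened together
first = torch.zeros_like(prompt_lens)
first[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
print(first)  # tensor([0, 3, 5]) -> position of each request's [CLS] token

hidden = torch.arange(9, dtype=torch.float32).unsqueeze(-1)
print(hidden[first].squeeze(-1))  # tensor([0., 3., 5.])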
class LastPool(SimplePooler):

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        if isinstance(hidden_states, list):
            return [h[-1] for h in hidden_states]

        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_token_flat_indices]


class LastPool(PoolingMethod):

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return hidden_states[-1]

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_token_flat_indices]
class AllPool(SimplePooler):

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with ALL pooling"
            return hidden_states

        offset = 0
        pooled_data = list[torch.Tensor]()
        for prompt_len in prompt_lens:
            pooled_data.append(hidden_states[offset:offset + prompt_len])
            offset += prompt_len

        return pooled_data

@ -172,24 +233,23 @@ class AllPool(SimplePooler):

class AllPool(PoolingMethod):

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with ALL pooling"

        return hidden_states

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        offset = 0
        pooled_data = list[torch.Tensor]()
        for prompt_len in prompt_lens:
            pooled_data.append(hidden_states[offset:offset + prompt_len])
            offset += prompt_len

        return pooled_data
class MeanPool(SimplePooler):

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

        if isinstance(hidden_states, list):
            result = []
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with mean pooling"
                result.append(torch.mean(req_state, dim=0,
                                         dtype=torch.float32))
            return result


class MeanPool(PoolingMethod):

    def forward_one(
        self,
        hidden_states: torch.Tensor,
        prompt_len: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
            "partial prefill not supported with MEAN pooling"

        return hidden_states.mean(dim=0, dtype=torch.float32)

    def forward_all(
        self,
        hidden_states: torch.Tensor,
        prompt_lens: torch.Tensor,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
@ -203,78 +263,127 @@ class MeanPool(SimplePooler):
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
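The hunk elides the middle of forward_all, but the visible numerator suggests the usual cumulative-sum trick: each request's token sum is the difference of inclusive cumsums at its end and start rows, with the start row added back. A standalone numeric sketch of that identity (the start_indices/end_indices names are assumed from the visible fragment):

import torch

prompt_lens = torch.tensor([2, 3])
hidden_states = torch.tensor([[1.], [3.], [2.], [4.], [6.]])

cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
end_indices = torch.cumsum(prompt_lens, dim=0) - 1  # tensor([1, 4])
start_indices = end_indices - prompt_lens + 1       # tensor([0, 2])

means = (cumsum[end_indices] - cumsum[start_indices] +
         hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
print(means)  # tensor([[2.], [4.]]) == per-request token means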
class StepPool(SimplePooler):

    def __init__(
        self,
        *,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ):
        super().__init__(normalize=normalize, softmax=softmax)

        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids

    def get_prompt_token_ids(
        self,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor]:
        if isinstance(pooling_metadata, V1PoolingMetadata):
            return [
                pooling_metadata.prompt_token_ids[i, :num]
                for i, num in enumerate(pooling_metadata.prompt_lens)
            ]
        return [
            torch.tensor(seq_data_i.prompt_token_ids)
            for seq_data_i in pooling_metadata.seq_data.values()
        ]

    def extract_states(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[torch.Tensor], torch.Tensor]:
        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
        prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)

        pooled_data_lst = list[torch.Tensor]()
        if isinstance(hidden_states, list):
            for req_state, prompt_len in zip(hidden_states, prompt_lens):
                assert prompt_len == req_state.shape[0], \
                    "partial prefill not supported with step pooling"
            pooled_data_lst = hidden_states
        else:
            offset = 0
            for prompt_len in prompt_lens:
                pooled_data_i = hidden_states[offset:offset + prompt_len]
                offset += prompt_len
                pooled_data_lst.append(pooled_data_i)

        pooled_data = list[torch.Tensor]()
        returned_token_ids = self.returned_token_ids
        step_tag_id = self.step_tag_id

        for data, token_id in zip(pooled_data_lst, prompt_token_ids):
            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data


_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class BasePoolerActivation(nn.Module, ABC):

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError


class PoolerActivation(BasePoolerActivation):

    @staticmethod
    def wraps(module: nn.Module):
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
class PoolerIdentity(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
return pooled_data
class PoolerNormalize(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
x = F.normalize(pooled_data.float(), p=2, dim=-1)
return x.to(pooled_data.dtype)
class PoolerClassify(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
num_labels = pooled_data.shape[-1]
if num_labels < 2:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
class PoolerScore(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
num_labels = pooled_data.shape[-1]
if num_labels < 2:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return pooled_data
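PoolerClassify and PoolerScore above share one convention: a single logit is treated as a binary problem (sigmoid), two or more logits as a multi-class problem (softmax); PoolerScore additionally leaves multi-label outputs untouched. The rule in isolation:

import torch
import torch.nn.functional as F

def classify_activation(pooled: torch.Tensor) -> torch.Tensor:
    # one logit -> sigmoid, otherwise softmax over the last dimension
    if pooled.shape[-1] < 2:
        return F.sigmoid(pooled.float()).to(pooled.dtype)
    return F.softmax(pooled.float(), dim=-1).to(pooled.dtype)

print(classify_activation(torch.tensor([[0.0]])))       # tensor([[0.5000]])
print(classify_activation(torch.tensor([[1.0, 1.0]])))  # tensor([[0.5000, 0.5000]])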
class PoolerHead(nn.Module):

    def __init__(self, *, normalize: bool, softmax: bool) -> None:
        super().__init__()

        self.normalize = normalize
        self.softmax = softmax


class LambdaPoolerActivation(PoolerActivation):

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return self.fn(pooled_data)
class PoolerHead(nn.Module):
@classmethod
def from_config_with_defaults(
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
) -> "PoolerHead":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=None,
returned_token_ids=None,
)
return cls.from_config(resolved_config)
@classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
if pooler_config.normalize and pooler_config.softmax:
raise ValueError("`normalize=True` and `softmax=True` should not "
"be set together")
activation: PoolerActivation
if pooler_config.normalize:
activation = PoolerNormalize()
elif pooler_config.softmax:
activation = PoolerClassify()
else:
activation = PoolerIdentity()
return cls(activation)
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()
self.activation = activation
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
@ -312,35 +421,21 @@ class PoolerHead(nn.Module):
for vecs, d in zip(pooled_data, dimensions_list)
]
if self.normalize:
if isinstance(pooled_data, list):
pooled_data = [
F.normalize(data, p=2, dim=-1) for data in pooled_data
]
else:
pooled_data = F.normalize(pooled_data, p=2, dim=-1)
if self.softmax:
if isinstance(pooled_data, list):
pooled_data = [
F.softmax(data, dim=-1)
if data.shape[-1] >= 2 else F.sigmoid(data)
for data in pooled_data
]
else:
if pooled_data.shape[-1] >= 2:
pooled_data = F.softmax(pooled_data, dim=-1)
else:
pooled_data = F.sigmoid(pooled_data)
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
return pooled_data
return self.activation(pooled_data)
class Pooler(nn.Module):
class SimplePooler(BasePooler):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""
@classmethod
def from_config_with_defaults(
@ -349,23 +444,146 @@ class Pooler(nn.Module):
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
) -> "SimplePooler":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
)
assert resolved_config.pooling_type != PoolingType.STEP
return cls.from_config(resolved_config)
@classmethod
def from_config(
cls,
pooler_config: ResolvedPoolingConfig,
) -> "SimplePooler":
pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
head = PoolerHead.from_config(pooler_config)
return cls(pooling, head)
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
super().__init__()
self.pooling = pooling
self.head = head
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
class StepPooler(BasePooler):
@classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
assert pooler_config.pooling_type == PoolingType.STEP
return cls(
PoolerHead.from_config(pooler_config),
step_tag_id=pooler_config.step_tag_id,
returned_token_ids=pooler_config.returned_token_ids,
)
    ) -> SimplePooler:
        return SimplePooler.from_pooling_type(
            pooling_type=PoolingType[pooler_config.pooling_type]
            if pooler_config.pooling_type is not None else pooling_type,
            normalize=pooler_config.normalize
            if pooler_config.normalize is not None else normalize,
            softmax=pooler_config.softmax
            if pooler_config.softmax is not None else softmax,
            step_tag_id=pooler_config.step_tag_id
            if pooler_config.step_tag_id is not None else step_tag_id,
            returned_token_ids=pooler_config.returned_token_ids
            if pooler_config.returned_token_ids is not None else
            returned_token_ids,
        )

    def __init__(
        self,
        head: PoolerHead,
        *,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head
        self.step_tag_id = step_tag_id
        self.returned_token_ids = returned_token_ids
def get_prompt_token_ids(
self,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
]
def extract_states(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
pooled_data = list[torch.Tensor]()
returned_token_ids = self.returned_token_ids
step_tag_id = self.step_tag_id
for data, token_id in zip(pooled_data_lst, prompt_token_ids):
if returned_token_ids is not None and len(returned_token_ids) > 0:
data = data[:, returned_token_ids]
if step_tag_id is not None:
data = data[token_id == step_tag_id]
pooled_data.append(data)
return pooled_data
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
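The selection logic in StepPooler.extract_states above, in isolation: optionally narrow the per-token scores to the returned_token_ids columns, then keep only rows whose prompt token equals the step tag. A standalone numeric sketch:

import torch

token_ids = torch.tensor([5, 9, 5, 7])  # prompt token ids
data = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # [seq_len, 3]

returned_token_ids = [0, 2]
step_tag_id = 5

data = data[:, returned_token_ids]     # keep two score columns -> [4, 2]
data = data[token_ids == step_tag_id]  # keep step-tag rows     -> [2, 2]
print(data)
# tensor([[0., 2.],
#         [6., 8.]])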
class Pooler(nn.Module):
@staticmethod
def from_config_with_defaults(
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[list[int]] = None,
) -> BasePooler:
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids,
)
if pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)
return SimplePooler.from_config(resolved_config)
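A hedged usage sketch of the consolidated factory, using only names from this diff (constructing PoolerConfig with no arguments is an assumption about its defaults):

from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType

# An all-None PoolerConfig keeps the model defaults passed below, so this
# resolves to a SimplePooler(MeanPool(), PoolerHead(PoolerNormalize())).
# A PoolingType.STEP default would yield a StepPooler instead.
pooler = Pooler.from_config_with_defaults(
    PoolerConfig(),
    pooling_type=PoolingType.MEAN,
    normalize=True,
    softmax=False,
)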
PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
class ClassifierPooler(nn.Module):
"""A pooling layer for classification tasks.
@ -382,69 +600,39 @@ class ClassifierPooler(nn.Module):
def __init__(
self,
config: ModelConfig,
classifier: nn.Module,
pooler: Optional[nn.Module] = None,
):
pooling: PoolingFn,
classifier: ClassifierFn,
act_fn: Optional[PoolerActivation] = None,
) -> None:
super().__init__()
self.pooling = pooling
self.classifier = classifier
self.pooler = pooler
self.classification_act_fn = get_classification_activation_function(
config.hf_config)
config.hf_config) if act_fn is None else act_fn
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
config.hf_config)
config.hf_config) if act_fn is None else act_fn
def _get_act_fn(self, use_cross_encoder: bool):
return (self.cross_encoder_act_fn
if use_cross_encoder else self.classification_act_fn)
def get_prompt_lens(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = list[torch.Tensor]()
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with classifier"
pooled_data = hidden_states
# apply classifier once on the full batch if possible
if isinstance(pooled_data, torch.Tensor):
pooled_output = self.classifier(pooled_data)
elif len({data.shape for data in pooled_data}) <= 1:
pooled_output = self.classifier(torch.stack(pooled_data))
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)
pooled_data_lst = []
for pooled_data_i in pooled_data:
if self.pooler is not None:
final_shape_tensor = self.pooler(pooled_data_i)
else:
final_shape_tensor = self.classifier(pooled_data_i)
pooled_data_lst.append(final_shape_tensor)
pooled_output = torch.stack(pooled_data_lst)
if self.pooler is not None:
# apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output)
pooled_output = [self.classifier(data) for data in pooled_data]
if isinstance(pooling_metadata, V0PoolingMetadata):
use_cross_encoder_list = [
@ -469,5 +657,4 @@ class ClassifierPooler(nn.Module):
pooled_output)
])
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
return PoolerOutput(outputs=pooled_outputs)
return build_output(scores)
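The new ClassifierPooler.forward pools first and then tries to run the classifier once over the whole batch: if every pooled state has the same shape they are stacked into a single tensor, otherwise it falls back to per-request calls. The shape check in isolation:

import torch

classifier = torch.nn.Linear(4, 2)
pooled = [torch.randn(4), torch.randn(4), torch.randn(4)]

if len({p.shape for p in pooled}) <= 1:
    logits = classifier(torch.stack(pooled))  # one batched matmul
else:
    logits = [classifier(p) for p in pooled]  # ragged fallback
print(logits.shape)  # torch.Size([3, 2])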


@ -58,22 +58,27 @@ def _create_pooling_model_cls(
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            # If the model already defines a pooler instance, don't overwrite it
            if not getattr(self, "_pooler", None):
                self._pooler = Pooler.from_config_with_defaults(
                    pooler_config,
                    pooling_type=default_pooling_type,
                    normalize=default_normalize,
                    softmax=default_softmax,
                )

            # If the model already defines a pooler instance, don't overwrite it
            if not getattr(self, "_pooler", None):
                self._init_pooler(vllm_config, prefix=prefix)

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            self._pooler = Pooler.from_config_with_defaults(
                pooler_config,
                pooling_type=default_pooling_type,
                normalize=default_normalize,
                softmax=default_softmax,
            )
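The refactor above replaces inline pooler construction with an overridable _init_pooler hook, so subclasses (like the sequence-classification adapter below) can swap in their own pooler without re-running the base setup. The pattern in miniature (names here are illustrative, not vLLM API):

class BasePoolingModel:
    def __init__(self):
        # don't overwrite a pooler the subclass already set up
        if not getattr(self, "_pooler", None):
            self._init_pooler()

    def _init_pooler(self):
        self._pooler = "default pooler"

class SeqClsModel(BasePoolingModel):
    def _init_pooler(self):
        self._pooler = "classifier pooler"

print(SeqClsModel()._pooler)  # classifier pooler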
def pooler(
self,
@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
            config = vllm_config.model_config.hf_config
            quant_config = vllm_config.quant_config

            self.vllm_config = vllm_config
            self.task = vllm_config.model_config.task
            self.pooling_type = (
                vllm_config.model_config.pooler_config.pooling_type)

            self.score = RowParallelLinear(config.hidden_size,
                                           config.num_labels,
                                           quant_config=quant_config,
                                           input_is_parallel=False,
                                           bias=False,
                                           prefix=maybe_prefix(
                                               prefix, "score"))

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            config = vllm_config.model_config.hf_config
            quant_config = vllm_config.quant_config

            self.score = RowParallelLinear(
                config.hidden_size,
                config.num_labels,
                input_is_parallel=False,
                bias=False,
                params_dtype=torch.float32,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            pooler = SimplePooler.from_config_with_defaults(
                pooler_config,
                pooling_type=PoolingType.LAST,
                normalize=False,
                softmax=True,
            )
            self._pooler = ClassifierPooler(
                vllm_config.model_config,
                pooling=pooler.pooling,
                classifier=self._classifier,
                act_fn=pooler.head.activation,
            )

        def _classifier(self, x: torch.Tensor):
            x, _ = self.score(x.float())
            return x
def forward(
self,
@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T:
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
def get_logits(hidden_states):
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
return logits
if self.pooling_type == PoolingType.ALL:
logits = get_logits(hidden_states)
return self._pooler(logits, pooling_metadata)
else:
hidden_states = self._pooler.extract_states(
hidden_states, pooling_metadata)
logits = get_logits(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)


@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
from typing import Optional, Union
import torch
from torch import nn
@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingType)
PoolingMethod, PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -84,14 +84,18 @@ class BertPooler(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[0, :]
pooled_output = self.dense(first_token_tensor)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier, self.bert.pooler)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)


@ -9,7 +9,7 @@ import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
@ -49,7 +49,7 @@ class GritLMPooler(nn.Module):
self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])
self.head = PoolerHead(normalize=True, softmax=False)
self.head = PoolerHead(PoolerNormalize())
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
"""


@ -659,7 +659,7 @@ def supports_cross_encoding(
def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler."""
return is_pooling_model(model) and any(
type(module).__name__ == "StepPool" for module in model.modules())
type(module).__name__ == "StepPooler" for module in model.modules())
class SupportsQuant:


@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
SimplePooler)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
num_labels: int = config.num_labels
score_bias: bool = getattr(config, 'score_bias', False)
        self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)

        pooler_config = vllm_config.model_config.pooler_config
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=False)

        # TODO: The original reward weights have float32 accuracy data, we
        # would like to load them in fp32 to get that extra precision.
        # Currently weight_loader passes the weight which is already in bf16
        self.score = nn.Linear(
            config.hidden_size,
            num_labels,
            bias=score_bias,
            dtype=torch.float32,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        pooler = SimplePooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=False,
        )
        self._pooler = ClassifierPooler(
            vllm_config.model_config,
            pooling=pooler.pooling,
            classifier=self.score,
            act_fn=pooler.head.activation,
        )
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
        hidden_states = hidden_states.float()
        logits = self.score(hidden_states)
        return self._pooler(logits, pooling_metadata)

        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        # TODO: The reward weights themselves have float32 accuracy data, we
        # would like to load them in fp32 to get that extra precision.
        super().load_weights(weights)
        self.score = self.score.float()


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
from typing import Optional, Union
import torch
from torch import nn
@ -13,7 +13,8 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -252,10 +253,13 @@ class ModernBertModel(nn.Module):
return norm_outputs
class ModernBertPooler(nn.Module):
    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                               config.classifier_bias)
        self.pooling_type = config.classifier_pooling

class ModernBertPooler(BasePooler):
    def __init__(self, config: ModernBertConfig):
        super().__init__()
        pooling_type = PoolingType[config.classifier_pooling.upper()]
        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                               config.classifier_bias)
@ -264,15 +268,12 @@ class ModernBertPooler(nn.Module):
eps=config.norm_eps,
bias=config.norm_bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
pooled_output = hidden_states
if self.pooling_type == "mean":
pooled_output = pooled_output.mean(dim=0, keepdim=False)
elif self.pooling_type == "cls":
pooled_output = pooled_output[0, :]
else:
raise ValueError("Pooling type should be either `cls` or `mean`, "
f"but got {self.pooling_type}")
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.norm(self.act(self.dense(pooled_output)))
return pooled_output
@ -287,9 +288,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier,
ModernBertPooler(config))
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=ModernBertPooler(config),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):


@ -9,7 +9,7 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[0, :] # take <s> token (equiv. to [CLS])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# CLSPool has already been applied in `pooling`
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=CLSPool(),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)


@ -17,7 +17,6 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError)
from torch import nn
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
@ -44,7 +43,6 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@ -775,28 +773,6 @@ def try_get_generation_config(
return None
def get_classification_activation_function(config: PretrainedConfig):
return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
def get_cross_encoder_activation_function(config: PretrainedConfig):
function_name: Optional[str] = None
if (hasattr(config, "sentence_transformers")
and "activation_fn" in config.sentence_transformers):
function_name = config.sentence_transformers["activation_fn"]
elif (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None):
function_name = config.sbert_ce_default_activation_function
if function_name is not None:
assert function_name.startswith("torch.nn.modules."), (
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons")
return resolve_obj_by_qualname(function_name)()
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
def try_get_safetensors_metadata(
model: str,
*,