Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 11:06:25 +08:00)

commit 1c3198b6c4 (parent: 260127ea54)

[Model] Consolidate pooler implementations (#20927)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
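This commit splits the old monolithic SimplePooler hierarchy into orthogonal pieces: a PoolingMethod that reduces hidden states, a PoolerHead that applies a single PoolerActivation, and thin SimplePooler/StepPooler wrappers over BasePooler. A minimal standalone sketch of that decomposition follows (plain PyTorch; the *Sketch names are hypothetical stand-ins, not the real vLLM classes, which take a PoolingMetadata rather than a prompt-length tensor):

# Standalone sketch of the pooling/head split introduced here; the real
# classes live in vllm/model_executor/layers/pooler.py.
import torch
import torch.nn.functional as F


class LastPoolSketch:
    """Reduce packed hidden states to one vector per sequence (last token)."""

    def __call__(self, hidden_states: torch.Tensor,
                 prompt_lens: torch.Tensor) -> torch.Tensor:
        # Packed layout: sequences are concatenated along dim 0, so the
        # last token of each sequence sits at cumsum(prompt_lens) - 1.
        last_indices = torch.cumsum(prompt_lens, dim=0) - 1
        return hidden_states[last_indices]


class NormalizeHeadSketch:
    """Apply a single activation (here: L2 normalization) to pooled data."""

    def __call__(self, pooled: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled.float(), p=2, dim=-1).to(pooled.dtype)


class SimplePoolerSketch:
    """Composition of a pooling method and a head, as in this commit."""

    def __init__(self, pooling, head):
        self.pooling = pooling
        self.head = head

    def __call__(self, hidden_states, prompt_lens):
        return self.head(self.pooling(hidden_states, prompt_lens))


pooler = SimplePoolerSketch(LastPoolSketch(), NormalizeHeadSketch())
hidden = torch.randn(7, 16)  # two sequences packed: 3 + 4 tokens
print(pooler(hidden, torch.tensor([3, 4])).shape)  # torch.Size([2, 16])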
@@ -1,22 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from enum import IntEnum
-from typing import Optional, Union
+from typing import Callable, Optional, TypeVar, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing_extensions import assert_never
+from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (  # noqa: E501
     PoolingMetadata as V0PoolingMetadata)
 from vllm.model_executor.pooling_metadata import PoolingTensors
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
-from vllm.transformers_utils.config import (
-    get_classification_activation_function,
-    get_cross_encoder_activation_function)
+from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
@@ -31,140 +30,202 @@ class PoolingType(IntEnum):
     MEAN = 4
 
 
-class SimplePooler(nn.Module):
-    """A layer that pools specific information from hidden states.
+@dataclass(frozen=True)
+class ResolvedPoolingConfig:
+    pooling_type: PoolingType
 
-    This layer does the following:
-    1. Extracts specific tokens or aggregates data based on pooling method.
-    2. Normalizes output if specified.
-    3. Returns structured results as `PoolerOutput`.
+    normalize: bool
+    softmax: bool
+    step_tag_id: Optional[int]
+    returned_token_ids: Optional[list[int]]
 
-    Attributes:
-        pooling_type: The type of pooling to use.
-        normalize: Whether to normalize the pooled data.
-    """
-
-    @staticmethod
-    def from_pooling_type(
+    @classmethod
+    def from_config_with_defaults(
+        cls,
+        pooler_config: PoolerConfig,
         pooling_type: PoolingType,
         *,
         normalize: bool,
         softmax: bool,
         step_tag_id: Optional[int] = None,
         returned_token_ids: Optional[list[int]] = None,
-    ) -> "SimplePooler":
-        if pooling_type == PoolingType.LAST:
-            assert step_tag_id is None and returned_token_ids is None
-            return LastPool(normalize=normalize, softmax=softmax)
-        if pooling_type == PoolingType.ALL:
-            assert step_tag_id is None and returned_token_ids is None
-            return AllPool(normalize=normalize, softmax=softmax)
-        if pooling_type == PoolingType.CLS:
-            assert step_tag_id is None and returned_token_ids is None
-            return CLSPool(normalize=normalize, softmax=softmax)
-        if pooling_type == PoolingType.MEAN:
-            assert step_tag_id is None and returned_token_ids is None
-            return MeanPool(normalize=normalize, softmax=softmax)
-        if pooling_type == PoolingType.STEP:
-            return StepPool(normalize=normalize,
-                            softmax=softmax,
-                            step_tag_id=step_tag_id,
-                            returned_token_ids=returned_token_ids)
+    ) -> "ResolvedPoolingConfig":
+        return cls(
+            pooling_type=PoolingType[pooler_config.pooling_type]
+            if pooler_config.pooling_type is not None else pooling_type,
+            normalize=pooler_config.normalize
+            if pooler_config.normalize is not None else normalize,
+            softmax=pooler_config.softmax
+            if pooler_config.softmax is not None else softmax,
+            step_tag_id=pooler_config.step_tag_id
+            if pooler_config.step_tag_id is not None else step_tag_id,
+            returned_token_ids=pooler_config.returned_token_ids
+            if pooler_config.returned_token_ids is not None else
+            returned_token_ids,
+        )
 
-        assert_never(pooling_type)
 
-    def __init__(self, *, normalize: bool, softmax: bool) -> None:
-        super().__init__()
+def get_prompt_lens(
+    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+    pooling_metadata: PoolingMetadata,
+) -> torch.Tensor:
+    if isinstance(pooling_metadata, V1PoolingMetadata):
+        return pooling_metadata.prompt_lens
 
-        self.head = PoolerHead(normalize=normalize, softmax=softmax)
+    assert isinstance(hidden_states, torch.Tensor)
+    return PoolingTensors.from_pooling_metadata(
+        pooling_metadata, hidden_states.device).prompt_lens
 
-    def get_prompt_lens(
-        self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
-        pooling_metadata: PoolingMetadata,
-    ) -> torch.Tensor:
-        if isinstance(pooling_metadata, V1PoolingMetadata):
-            return pooling_metadata.prompt_lens
-        assert isinstance(hidden_states, torch.Tensor)
-        return PoolingTensors.from_pooling_metadata(
-            pooling_metadata, hidden_states.device).prompt_lens
 
-    def extract_states(
-        self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
-        pooling_metadata: PoolingMetadata,
-    ) -> Union[list[torch.Tensor], torch.Tensor]:
-        raise NotImplementedError
+def get_classification_activation_function(config: PretrainedConfig):
+    return PoolerClassify()
 
-    def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
-        return PoolingSequenceGroupOutput(data)
 
+def get_cross_encoder_activation_function(config: PretrainedConfig):
+    function_name: Optional[str] = None
+    if (hasattr(config, "sentence_transformers")
+            and "activation_fn" in config.sentence_transformers):
+        function_name = config.sentence_transformers["activation_fn"]
+    elif (hasattr(config, "sbert_ce_default_activation_function")
+          and config.sbert_ce_default_activation_function is not None):
+        function_name = config.sbert_ce_default_activation_function
+
+    if function_name is not None:
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons")
+        fn = resolve_obj_by_qualname(function_name)()
+        return PoolerActivation.wraps(fn)
+
+    return PoolerScore()
+
+
+def build_output(all_data: torch.Tensor) -> PoolerOutput:
+    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
+    return PoolerOutput(outputs=all_outputs)
+
+
+class BasePooler(nn.Module):
+
+    @abstractmethod
     def forward(
         self,
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
-        pooled_data = self.extract_states(hidden_states, pooling_metadata)
-        pooled_data = self.head(pooled_data, pooling_metadata)
-        pooled_outputs = [self.build_output(data) for data in pooled_data]
-        return PoolerOutput(outputs=pooled_outputs)
+        raise NotImplementedError
 
 
-class CLSPool(SimplePooler):
+class PoolingMethod(nn.Module, ABC):
 
-    def extract_states(
+    @staticmethod
+    def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
+        if pooling_type == PoolingType.LAST:
+            return LastPool()
+        if pooling_type == PoolingType.ALL:
+            return AllPool()
+        if pooling_type == PoolingType.CLS:
+            return CLSPool()
+        if pooling_type == PoolingType.MEAN:
+            return MeanPool()
+
+        raise NotImplementedError(f"Unsupported method: {pooling_type}")
+
+    @abstractmethod
+    def forward_one(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Note:
+            `prompt_len=None` means `prompt_len=len(hidden_states)`.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def forward_all(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_lens: torch.Tensor,
+    ) -> Union[list[torch.Tensor], torch.Tensor]:
+        raise NotImplementedError
+
+    def forward(
         self,
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
+        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
 
         if isinstance(hidden_states, list):
-            result = []
-            for req_state, prompt_len in zip(hidden_states, prompt_lens):
-                assert prompt_len == req_state.shape[0], \
-                    "partial prefill not supported with CLS pooling"
-                result.append(req_state[0])
-            return result
+            return [
+                self.forward_one(h, prompt_len)
+                for h, prompt_len in zip(hidden_states, prompt_lens)
+            ]
+
+        return self.forward_all(hidden_states, prompt_lens)
+
+
+class CLSPool(PoolingMethod):
+
+    def forward_one(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
+            "partial prefill not supported with CLS pooling"
+
+        return hidden_states[0]
 
+    def forward_all(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_lens: torch.Tensor,
+    ) -> Union[list[torch.Tensor], torch.Tensor]:
         first_token_flat_indices = torch.zeros_like(prompt_lens)
         first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
         return hidden_states[first_token_flat_indices]
 
 
-class LastPool(SimplePooler):
+class LastPool(PoolingMethod):
 
-    def extract_states(
+    def forward_one(
         self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
-        pooling_metadata: PoolingMetadata,
+        hidden_states: torch.Tensor,
+        prompt_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return hidden_states[-1]
+
+    def forward_all(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_lens: torch.Tensor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        if isinstance(hidden_states, list):
-            return [h[-1] for h in hidden_states]
-
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
-
         last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
         return hidden_states[last_token_flat_indices]
 
 
-class AllPool(SimplePooler):
+class AllPool(PoolingMethod):
 
-    def extract_states(
+    def forward_one(
         self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
-        pooling_metadata: PoolingMetadata,
+        hidden_states: torch.Tensor,
+        prompt_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
+            "partial prefill not supported with ALL pooling"
+
+        return hidden_states
+
+    def forward_all(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_lens: torch.Tensor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
-
-        if isinstance(hidden_states, list):
-            for req_state, prompt_len in zip(hidden_states, prompt_lens):
-                assert prompt_len == req_state.shape[0], \
-                    "partial prefill not supported with ALL pooling"
-            return hidden_states
-
         offset = 0
         pooled_data = list[torch.Tensor]()
 
         for prompt_len in prompt_lens:
             pooled_data.append(hidden_states[offset:offset + prompt_len])
             offset += prompt_len
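Both CLSPool.forward_all and LastPool.forward_all above index a packed (total_tokens, hidden) tensor directly instead of looping per sequence. A small self-contained check of the index arithmetic (toy values, not vLLM code):

import torch

prompt_lens = torch.tensor([3, 4, 2])
hidden_states = torch.arange(9).unsqueeze(-1)  # 9 packed "token" rows

# First token of each sequence: shifted exclusive cumsum.
first = torch.zeros_like(prompt_lens)
first[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
print(first.tolist())  # [0, 3, 7]

# Last token of each sequence: inclusive cumsum minus one.
last = torch.cumsum(prompt_lens, dim=0) - 1
print(last.tolist())   # [2, 6, 8]

print(hidden_states[first].squeeze(-1).tolist())  # [0, 3, 7] -> CLS rows
print(hidden_states[last].squeeze(-1).tolist())   # [2, 6, 8] -> last rows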
@@ -172,24 +233,23 @@ class AllPool(SimplePooler):
         return pooled_data
 
 
-class MeanPool(SimplePooler):
+class MeanPool(PoolingMethod):
 
-    def extract_states(
+    def forward_one(
         self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
-        pooling_metadata: PoolingMetadata,
+        hidden_states: torch.Tensor,
+        prompt_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
+            "partial prefill not supported with MEAN pooling"
+
+        return hidden_states.mean(dim=0, dtype=torch.float32)
+
+    def forward_all(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_lens: torch.Tensor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
-
-        if isinstance(hidden_states, list):
-            result = []
-            for req_state, prompt_len in zip(hidden_states, prompt_lens):
-                assert prompt_len == req_state.shape[0], \
-                    "partial prefill not supported with mean pooling"
-                result.append(torch.mean(req_state, dim=0,
-                                         dtype=torch.float32))
-            return result
-
         # Use float32 for torch.cumsum in MeanPool,
         # otherwise precision will be lost significantly.
         cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
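MeanPool.forward_all computes every per-sequence mean in one pass over the packed tensor, using a float32 cumulative sum (the diff's own comment notes that low-precision cumsum loses significant accuracy). A standalone illustration of the cumsum-difference trick (assumed shapes, not the vLLM implementation):

import torch

prompt_lens = torch.tensor([3, 4])
hidden = torch.randn(7, 8, dtype=torch.bfloat16)

# Cumulative sum in float32; row i holds the sum of rows 0..i.
cumsum = torch.cumsum(hidden, dim=0, dtype=torch.float32)

end = torch.cumsum(prompt_lens, dim=0) - 1   # last row of each sequence
start = end - prompt_lens + 1                # first row of each sequence

# Sum over each sequence = cumsum[end] - cumsum[start] + hidden[start];
# dividing by the length gives the mean without a Python loop.
seq_sum = cumsum[end] - cumsum[start] + hidden[start].float()
means = seq_sum / prompt_lens.unsqueeze(1)

# Reference: loop over the packed tensor.
ref = torch.stack([hidden[0:3].float().mean(0), hidden[3:7].float().mean(0)])
print(torch.allclose(means, ref, atol=1e-3))  # True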
@@ -203,78 +263,127 @@ class MeanPool(SimplePooler):
                 hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
 
 
-class StepPool(SimplePooler):
+_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
 
-    def __init__(
-        self,
-        *,
-        normalize: bool,
-        softmax: bool,
-        step_tag_id: Optional[int] = None,
-        returned_token_ids: Optional[list[int]] = None,
-    ):
-        super().__init__(normalize=normalize, softmax=softmax)
-
-        self.step_tag_id = step_tag_id
-        self.returned_token_ids = returned_token_ids
+class BasePoolerActivation(nn.Module, ABC):
 
-    def get_prompt_token_ids(
-        self,
-        pooling_metadata: PoolingMetadata,
-    ) -> list[torch.Tensor]:
-        if isinstance(pooling_metadata, V1PoolingMetadata):
-            return [
-                pooling_metadata.prompt_token_ids[i, :num]
-                for i, num in enumerate(pooling_metadata.prompt_lens)
-            ]
-        return [
-            torch.tensor(seq_data_i.prompt_token_ids)
-            for seq_data_i in pooling_metadata.seq_data.values()
-        ]
+    @abstractmethod
+    def forward(self, pooled_data: _T) -> _T:
+        # shape:
+        # classify (& score) -> (batch_size, num_classes)
+        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
+        #          (batch_size, dimensions) or list(dimensions) if using MRL
+        raise NotImplementedError
 
-    def extract_states(
-        self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
-        pooling_metadata: PoolingMetadata,
-    ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
-        prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
-
-        pooled_data_lst = list[torch.Tensor]()
-        if isinstance(hidden_states, list):
-            for req_state, prompt_len in zip(hidden_states, prompt_lens):
-                assert prompt_len == req_state.shape[0], \
-                    "partial prefill not supported with step pooling"
-            pooled_data_lst = hidden_states
-        else:
-            offset = 0
-            for prompt_len in prompt_lens:
-                pooled_data_i = hidden_states[offset:offset + prompt_len]
-                offset += prompt_len
-                pooled_data_lst.append(pooled_data_i)
+class PoolerActivation(BasePoolerActivation):
 
-        pooled_data = list[torch.Tensor]()
-        returned_token_ids = self.returned_token_ids
-        step_tag_id = self.step_tag_id
+    @staticmethod
+    def wraps(module: nn.Module):
+        if isinstance(module, nn.Identity):
+            return PoolerIdentity()
+        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
+            return PoolerClassify()
 
-        for data, token_id in zip(pooled_data_lst, prompt_token_ids):
-            if returned_token_ids is not None and len(returned_token_ids) > 0:
-                data = data[:, returned_token_ids]
+        return LambdaPoolerActivation(module)
 
-            if step_tag_id is not None:
-                data = data[token_id == step_tag_id]
-            pooled_data.append(data)
+    @abstractmethod
+    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
 
+    def forward(self, pooled_data: _T) -> _T:
+        if isinstance(pooled_data, list):
+            return [self.forward_chunk(data) for data in pooled_data]
+
+        return self.forward_chunk(pooled_data)
+
+
+class PoolerIdentity(PoolerActivation):
+
+    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
         return pooled_data
+
+
+class PoolerNormalize(PoolerActivation):
+
+    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
+        x = F.normalize(pooled_data.float(), p=2, dim=-1)
+        return x.to(pooled_data.dtype)
+
+
+class PoolerClassify(PoolerActivation):
+
+    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
+        num_labels = pooled_data.shape[-1]
+        if num_labels < 2:
+            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
+
+        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
+
+
+class PoolerScore(PoolerActivation):
+
+    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
+        num_labels = pooled_data.shape[-1]
+        if num_labels < 2:
+            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
+
+        return pooled_data
 
 
-class PoolerHead(nn.Module):
+class LambdaPoolerActivation(PoolerActivation):
 
-    def __init__(self, *, normalize: bool, softmax: bool) -> None:
+    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
         super().__init__()
 
-        self.normalize = normalize
-        self.softmax = softmax
+        self.fn = fn
+
+    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
+        return self.fn(pooled_data)
+
+
+class PoolerHead(nn.Module):
+
+    @classmethod
+    def from_config_with_defaults(
+        cls,
+        pooler_config: PoolerConfig,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
+    ) -> "PoolerHead":
+        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
+            pooler_config=pooler_config,
+            pooling_type=pooling_type,
+            normalize=normalize,
+            softmax=softmax,
+            step_tag_id=None,
+            returned_token_ids=None,
+        )
+
+        return cls.from_config(resolved_config)
+
+    @classmethod
+    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
+        if pooler_config.normalize and pooler_config.softmax:
+            raise ValueError("`normalize=True` and `softmax=True` should not "
+                             "be set together")
+
+        activation: PoolerActivation
+        if pooler_config.normalize:
+            activation = PoolerNormalize()
+        elif pooler_config.softmax:
+            activation = PoolerClassify()
+        else:
+            activation = PoolerIdentity()
+
+        return cls(activation)
+
+    def __init__(self, activation: PoolerActivation) -> None:
+        super().__init__()
+
+        self.activation = activation
 
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
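PoolerClassify above picks sigmoid for single-label outputs and softmax otherwise, computing in float32 and casting back. A quick standalone check of that dispatch (a mimic function, not the vLLM class itself):

import torch
import torch.nn.functional as F


def classify_activation(pooled: torch.Tensor) -> torch.Tensor:
    # Mirrors PoolerClassify.forward_chunk: sigmoid for num_labels < 2,
    # softmax over the label dimension otherwise, computed in float32.
    num_labels = pooled.shape[-1]
    if num_labels < 2:
        return F.sigmoid(pooled.float()).to(pooled.dtype)
    return F.softmax(pooled.float(), dim=-1).to(pooled.dtype)


one_label = torch.tensor([[0.3]], dtype=torch.float16)
two_labels = torch.tensor([[0.3, -0.1]], dtype=torch.float16)
print(classify_activation(one_label))           # sigmoid, stays in fp16
print(classify_activation(two_labels).sum(-1))  # softmax rows sum to 1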
@@ -312,35 +421,21 @@ class PoolerHead(nn.Module):
                 for vecs, d in zip(pooled_data, dimensions_list)
             ]
 
-        if self.normalize:
-            if isinstance(pooled_data, list):
-                pooled_data = [
-                    F.normalize(data, p=2, dim=-1) for data in pooled_data
-                ]
-            else:
-                pooled_data = F.normalize(pooled_data, p=2, dim=-1)
-
-        if self.softmax:
-            if isinstance(pooled_data, list):
-                pooled_data = [
-                    F.softmax(data, dim=-1)
-                    if data.shape[-1] >= 2 else F.sigmoid(data)
-                    for data in pooled_data
-                ]
-            else:
-                if pooled_data.shape[-1] >= 2:
-                    pooled_data = F.softmax(pooled_data, dim=-1)
-                else:
-                    pooled_data = F.sigmoid(pooled_data)
-
-        # shape:
-        # classify (& score) -> (batch_size, num_classes)
-        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
-        #          (batch_size, dimensions) or list(dimensions) if using MRL
-        return pooled_data
+        return self.activation(pooled_data)
 
 
-class Pooler(nn.Module):
+class SimplePooler(BasePooler):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `PoolerOutput`.
+
+    Attributes:
+        pooling_type: The type of pooling to use.
+        normalize: Whether to normalize the pooled data.
+    """
 
     @classmethod
     def from_config_with_defaults(
@@ -349,23 +444,146 @@ class Pooler(nn.Module):
         pooling_type: PoolingType,
         normalize: bool,
         softmax: bool,
+    ) -> "SimplePooler":
+        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
+            pooler_config=pooler_config,
+            pooling_type=pooling_type,
+            normalize=normalize,
+            softmax=softmax,
+        )
+        assert resolved_config.pooling_type != PoolingType.STEP
+
+        return cls.from_config(resolved_config)
+
+    @classmethod
+    def from_config(
+        cls,
+        pooler_config: ResolvedPoolingConfig,
+    ) -> "SimplePooler":
+        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
+        head = PoolerHead.from_config(pooler_config)
+
+        return cls(pooling, head)
+
+    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
+        super().__init__()
+
+        self.pooling = pooling
+        self.head = head
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_data = self.pooling(hidden_states, pooling_metadata)
+        pooled_data = self.head(pooled_data, pooling_metadata)
+        return build_output(pooled_data)
+
+
+class StepPooler(BasePooler):
+
+    @classmethod
+    def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
+        assert pooler_config.pooling_type == PoolingType.STEP
+
+        return cls(
+            PoolerHead.from_config(pooler_config),
+            step_tag_id=pooler_config.step_tag_id,
+            returned_token_ids=pooler_config.returned_token_ids,
+        )
+
+    def __init__(
+        self,
+        head: PoolerHead,
+        *,
         step_tag_id: Optional[int] = None,
         returned_token_ids: Optional[list[int]] = None,
-    ) -> SimplePooler:
-        return SimplePooler.from_pooling_type(
-            pooling_type=PoolingType[pooler_config.pooling_type]
-            if pooler_config.pooling_type is not None else pooling_type,
-            normalize=pooler_config.normalize
-            if pooler_config.normalize is not None else normalize,
-            softmax=pooler_config.softmax
-            if pooler_config.softmax is not None else softmax,
-            step_tag_id=pooler_config.step_tag_id
-            if pooler_config.step_tag_id is not None else step_tag_id,
-            returned_token_ids=pooler_config.returned_token_ids
-            if pooler_config.returned_token_ids is not None else
-            returned_token_ids,
-        )
+    ) -> None:
+        super().__init__()
+
+        self.pooling = AllPool()
+        self.head = head
+        self.step_tag_id = step_tag_id
+        self.returned_token_ids = returned_token_ids
+
+    def get_prompt_token_ids(
+        self,
+        pooling_metadata: PoolingMetadata,
+    ) -> list[torch.Tensor]:
+        if isinstance(pooling_metadata, V1PoolingMetadata):
+            return [
+                pooling_metadata.prompt_token_ids[i, :num]
+                for i, num in enumerate(pooling_metadata.prompt_lens)
+            ]
+        return [
+            torch.tensor(seq_data_i.prompt_token_ids)
+            for seq_data_i in pooling_metadata.seq_data.values()
+        ]
+
+    def extract_states(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> Union[list[torch.Tensor], torch.Tensor]:
+        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
+        prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
+
+        pooled_data = list[torch.Tensor]()
+        returned_token_ids = self.returned_token_ids
+        step_tag_id = self.step_tag_id
+
+        for data, token_id in zip(pooled_data_lst, prompt_token_ids):
+            if returned_token_ids is not None and len(returned_token_ids) > 0:
+                data = data[:, returned_token_ids]
+
+            if step_tag_id is not None:
+                data = data[token_id == step_tag_id]
+            pooled_data.append(data)
+
+        return pooled_data
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_data = self.extract_states(hidden_states, pooling_metadata)
+        pooled_data = self.head(pooled_data, pooling_metadata)
+        return build_output(pooled_data)
+
+
+class Pooler(nn.Module):
+
+    @staticmethod
+    def from_config_with_defaults(
+        pooler_config: PoolerConfig,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
+        step_tag_id: Optional[int] = None,
+        returned_token_ids: Optional[list[int]] = None,
+    ) -> BasePooler:
+        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
+            pooler_config=pooler_config,
+            pooling_type=pooling_type,
+            normalize=normalize,
+            softmax=softmax,
+            step_tag_id=step_tag_id,
+            returned_token_ids=returned_token_ids,
+        )
+
+        if pooling_type == PoolingType.STEP:
+            return StepPooler.from_config(resolved_config)
+
+        return SimplePooler.from_config(resolved_config)
+
+
+PoolingFn = Callable[
+    [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
+    Union[torch.Tensor, list[torch.Tensor]]]
+ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
 
 
 class ClassifierPooler(nn.Module):
     """A pooling layer for classification tasks.
@@ -382,69 +600,39 @@ class ClassifierPooler(nn.Module):
     def __init__(
         self,
         config: ModelConfig,
-        classifier: nn.Module,
-        pooler: Optional[nn.Module] = None,
-    ):
+        pooling: PoolingFn,
+        classifier: ClassifierFn,
+        act_fn: Optional[PoolerActivation] = None,
+    ) -> None:
         super().__init__()
 
+        self.pooling = pooling
         self.classifier = classifier
-        self.pooler = pooler
 
         self.classification_act_fn = get_classification_activation_function(
-            config.hf_config)
+            config.hf_config) if act_fn is None else act_fn
         self.cross_encoder_act_fn = get_cross_encoder_activation_function(
-            config.hf_config)
+            config.hf_config) if act_fn is None else act_fn
 
     def _get_act_fn(self, use_cross_encoder: bool):
         return (self.cross_encoder_act_fn
                 if use_cross_encoder else self.classification_act_fn)
 
-    def get_prompt_lens(
-        self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
-        pooling_metadata: PoolingMetadata,
-    ) -> torch.Tensor:
-        if isinstance(pooling_metadata, V1PoolingMetadata):
-            return pooling_metadata.prompt_lens
-        assert isinstance(hidden_states, torch.Tensor)
-        return PoolingTensors.from_pooling_metadata(
-            pooling_metadata, hidden_states.device).prompt_lens
-
     def forward(
         self,
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         """Pools sentence pair scores from the hidden_states."""
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
+        pooled_data = self.pooling(hidden_states, pooling_metadata)
 
-        pooled_data = list[torch.Tensor]()
-        if isinstance(hidden_states, list):
-            for req_state, prompt_len in zip(hidden_states, prompt_lens):
-                assert prompt_len == req_state.shape[0], \
-                    "partial prefill not supported with classifier"
-            pooled_data = hidden_states
+        # apply classifier once on the full batch if possible
+        if isinstance(pooled_data, torch.Tensor):
+            pooled_output = self.classifier(pooled_data)
+        elif len({data.shape for data in pooled_data}) <= 1:
+            pooled_output = self.classifier(torch.stack(pooled_data))
         else:
-            offset = 0
-            for prompt_len in prompt_lens:
-                pooled_data_i = hidden_states[offset:offset + prompt_len]
-                offset += prompt_len
-                pooled_data.append(pooled_data_i)
-
-        pooled_data_lst = []
-        for pooled_data_i in pooled_data:
-
-            if self.pooler is not None:
-                final_shape_tensor = self.pooler(pooled_data_i)
-            else:
-                final_shape_tensor = self.classifier(pooled_data_i)
-
-            pooled_data_lst.append(final_shape_tensor)
-
-        pooled_output = torch.stack(pooled_data_lst)
-
-        if self.pooler is not None:
-            # apply classifier once on the full batch if possible
-            pooled_output = self.classifier(pooled_output)
+            pooled_output = [self.classifier(data) for data in pooled_data]
 
         if isinstance(pooling_metadata, V0PoolingMetadata):
             use_cross_encoder_list = [
@@ -469,5 +657,4 @@ class ClassifierPooler(nn.Module):
                 pooled_output)
         ])
 
-        pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
-        return PoolerOutput(outputs=pooled_outputs)
+        return build_output(scores)
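The rewritten ClassifierPooler.forward applies the classifier to the whole batch in one call when the pooled tensors share a shape, and falls back to per-sequence calls otherwise. A standalone sketch of that dispatch (hypothetical classifier, not the vLLM module):

import torch

classifier = torch.nn.Linear(8, 2)  # stand-in for the model's score head

def classify(pooled_data):
    # One batched call if pooled_data is already a tensor or if all
    # per-sequence tensors share a shape; otherwise call per sequence.
    if isinstance(pooled_data, torch.Tensor):
        return classifier(pooled_data)
    if len({data.shape for data in pooled_data}) <= 1:
        return classifier(torch.stack(pooled_data))
    return [classifier(data) for data in pooled_data]

same = [torch.randn(8), torch.randn(8)]
ragged = [torch.randn(3, 8), torch.randn(5, 8)]  # e.g. ALL pooling outputs
print(classify(same).shape)             # torch.Size([2, 2])
print([t.shape for t in classify(ragged)])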
@@ -58,22 +58,27 @@ def _create_pooling_model_cls(
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
+           self.vllm_config = vllm_config
+
            # These are not used in pooling models
            for attr in ("lm_head", "logits_processor"):
                if hasattr(self, attr):
                    delattr(self, attr)
 
+           # If the model already defines a pooler instance, don't overwrite it
+           if not getattr(self, "_pooler", None):
+               self._init_pooler(vllm_config, prefix=prefix)
+
+       def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None
 
-           # If the model already defines a pooler instance, don't overwrite it
-           if not getattr(self, "_pooler", None):
-               self._pooler = Pooler.from_config_with_defaults(
-                   pooler_config,
-                   pooling_type=default_pooling_type,
-                   normalize=default_normalize,
-                   softmax=default_softmax,
-               )
+           self._pooler = Pooler.from_config_with_defaults(
+               pooler_config,
+               pooling_type=default_pooling_type,
+               normalize=default_normalize,
+               softmax=default_softmax,
+           )
 
        def pooler(
            self,
@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
 
     # Lazy import
     from vllm.model_executor.layers.linear import RowParallelLinear
-    from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
+    from vllm.model_executor.layers.pooler import (ClassifierPooler,
+                                                   PoolerOutput, PoolingType,
+                                                   SimplePooler)
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
     from vllm.model_executor.pooling_metadata import PoolingMetadata
     from vllm.sequence import IntermediateTensors
@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
     class ModelForSequenceClassification(ModelForPooling,
                                          SupportsCrossEncoding):
 
-        def __init__(
-            self,
-            *,
-            vllm_config: "VllmConfig",
-            prefix: str = "",
-            **kwargs: Any,
-        ) -> None:
-            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
-
+        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            config = vllm_config.model_config.hf_config
            quant_config = vllm_config.quant_config
 
-           self.vllm_config = vllm_config
-           self.task = vllm_config.model_config.task
-           self.pooling_type = (
-               vllm_config.model_config.pooler_config.pooling_type)
+           self.score = RowParallelLinear(
+               config.hidden_size,
+               config.num_labels,
+               input_is_parallel=False,
+               bias=False,
+               params_dtype=torch.float32,
+               quant_config=quant_config,
+               prefix=maybe_prefix(prefix, "score"),
+           )
 
-           self.score = RowParallelLinear(config.hidden_size,
-                                          config.num_labels,
-                                          quant_config=quant_config,
-                                          input_is_parallel=False,
-                                          bias=False,
-                                          prefix=maybe_prefix(
-                                              prefix, "score"))
+           pooler_config = vllm_config.model_config.pooler_config
+           assert pooler_config is not None
+
+           pooler = SimplePooler.from_config_with_defaults(
+               pooler_config,
+               pooling_type=PoolingType.LAST,
+               normalize=False,
+               softmax=True,
+           )
+
+           self._pooler = ClassifierPooler(
+               vllm_config.model_config,
+               pooling=pooler.pooling,
+               classifier=self._classifier,
+               act_fn=pooler.head.activation,
+           )
+
+       def _classifier(self, x: torch.Tensor):
+           x, _ = self.score(x.float())
+           return x
 
        def forward(
            self,
@@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T:
            hidden_states: Union[torch.Tensor, list[torch.Tensor]],
            pooling_metadata: PoolingMetadata,
        ) -> PoolerOutput:
-
-           def get_logits(hidden_states):
-               if isinstance(hidden_states, list):
-                   logits = [self.score(state)[0] for state in hidden_states]
-               else:
-                   logits, _ = self.score(hidden_states)
-               return logits
-
-           if self.pooling_type == PoolingType.ALL:
-               logits = get_logits(hidden_states)
-               return self._pooler(logits, pooling_metadata)
-           else:
-               hidden_states = self._pooler.extract_states(
-                   hidden_states, pooling_metadata)
-               logits = get_logits(hidden_states)
-               pooled_data = self._pooler.head(logits, pooling_metadata)
-
-               pooled_outputs = [
-                   self._pooler.build_output(data) for data in pooled_data
-               ]
-               return PoolerOutput(outputs=pooled_outputs)
+           return self._pooler(hidden_states, pooling_metadata)
 
        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            tokens = getattr(self.config, "classifier_from_token", None)
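After this change the adapter no longer special-cases PoolingType.ALL: the score projection is wrapped in _classifier (note the x.float() upcast) and handed to ClassifierPooler as a plain callable. A sketch of why the wrapper is needed, using a stand-in layer that, like RowParallelLinear, returns an (output, bias) tuple (hypothetical names, not vLLM code):

import torch


class TupleLinear(torch.nn.Linear):
    """Stand-in for RowParallelLinear, which returns (output, output_bias)."""

    def forward(self, x):
        return super().forward(x), None


score = TupleLinear(8, 2)

def _classifier(x: torch.Tensor) -> torch.Tensor:
    # Upcast to float32 and drop the tuple's second element: exactly the
    # adaptation ClassifierPooler needs from a ClassifierFn.
    out, _ = score(x.float())
    return out

print(_classifier(torch.randn(4, 8, dtype=torch.float16)).shape)  # (4, 2)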
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
-                                               PoolingType)
+                                               PoolingMethod, PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -84,14 +84,18 @@ class BertPooler(nn.Module):
 
     def __init__(self, config: BertConfig):
         super().__init__()
+
+        self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
-        first_token_tensor = hidden_states[0, :]
-        pooled_output = self.dense(first_token_tensor)
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        pooled_output = self.pooling(hidden_states, pooling_metadata)
+        pooled_output = self.dense(pooled_output)
         pooled_output = self.activation(pooled_output)
         return pooled_output
 
@@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
                               embedding_class=BertEmbedding,
                               add_pooling_layer=True)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = ClassifierPooler(vllm_config.model_config,
-                                        self.classifier, self.bert.pooler)
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=self.bert.pooler,
+            classifier=self.classifier,
+        )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
@@ -9,7 +9,7 @@ import torch.nn as nn
 
 from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import PoolerHead
+from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
@@ -49,7 +49,7 @@ class GritLMPooler(nn.Module):
         self.embed_pattern_ids = tokens_to_ids(
             ["▁<", "|", "embed", "|", ">", "<0x0A>"])
 
-        self.head = PoolerHead(normalize=True, softmax=False)
+        self.head = PoolerHead(PoolerNormalize())
 
     def _find_array(self, arr: array, target: array, start_idx: int) -> int:
         """
@@ -659,7 +659,7 @@ def supports_cross_encoding(
 def has_step_pooler(model: Union[type[object], object]) -> bool:
     """Check if the model uses step pooler."""
     return is_pooling_model(model) and any(
-        type(module).__name__ == "StepPool" for module in model.modules())
+        type(module).__name__ == "StepPooler" for module in model.modules())
 
 
 class SupportsQuant:
@@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
+                                               SimplePooler)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
         config = vllm_config.model_config.hf_config
         num_labels: int = config.num_labels
         score_bias: bool = getattr(config, 'score_bias', False)
-        self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
+
+        # TODO: The original reward weights have float32 accuracy data, we
+        # would like to load them in fp32 to get that extra precision.
+        # Currently weight_loader passes the weight which is already in bf16
+        self.score = nn.Linear(
+            config.hidden_size,
+            num_labels,
+            bias=score_bias,
+            dtype=torch.float32,
+        )
 
         pooler_config = vllm_config.model_config.pooler_config
-        self._pooler = Pooler.from_config_with_defaults(
+        assert pooler_config is not None
+
+        pooler = SimplePooler.from_config_with_defaults(
             pooler_config,
             pooling_type=PoolingType.LAST,
             normalize=False,
-            softmax=False)
+            softmax=False,
+        )
+
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=pooler.pooling,
+            classifier=self.score,
+            act_fn=pooler.head.activation,
+        )
 
     def pooler(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        hidden_states = hidden_states.float()
-        logits = self.score(hidden_states)
-        return self._pooler(logits, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        # TODO: The reward weights themselves have float32 accuracy data, we
-        # would like to load them in fp32 to get that extra precision.
-        super().load_weights(weights)
-        self.score = self.score.float()
+        return self._pooler(hidden_states, pooling_metadata)
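The Jamba change constructs the score head directly in float32 instead of casting after weight loading, because (per the TODO above) the weight loader copies tensors into whatever dtype the parameter already has. A minimal illustration (toy tensors, not Jamba's loader):

import torch

# Reward-style head built in float32 up front, as the diff now does.
score = torch.nn.Linear(8, 1, bias=False, dtype=torch.float32)

# A bf16 checkpoint tensor keeps only ~8 bits of mantissa...
ckpt = torch.randn(1, 8, dtype=torch.bfloat16)

# ...so copying it into an fp32 parameter cannot recover lost precision,
# but it avoids a second round of truncation later in the forward pass.
with torch.no_grad():
    score.weight.copy_(ckpt)

print(score.weight.dtype)  # torch.float32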
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -13,7 +13,8 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import ClassifierPooler
+from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
+                                               PoolingMethod, PoolingType)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -252,10 +253,13 @@ class ModernBertModel(nn.Module):
         return norm_outputs
 
 
-class ModernBertPooler(nn.Module):
+class ModernBertPooler(BasePooler):
 
     def __init__(self, config: ModernBertConfig):
         super().__init__()
+
+        pooling_type = PoolingType[config.classifier_pooling.upper()]
+        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                                config.classifier_bias)
-        self.pooling_type = config.classifier_pooling
@@ -264,15 +268,12 @@ class ModernBertPooler(nn.Module):
                                  eps=config.norm_eps,
                                  bias=config.norm_bias)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        pooled_output = hidden_states
-        if self.pooling_type == "mean":
-            pooled_output = pooled_output.mean(dim=0, keepdim=False)
-        elif self.pooling_type == "cls":
-            pooled_output = pooled_output[0, :]
-        else:
-            raise ValueError("Pooling type should be either `cls` or `mean`, "
-                             f"but got {self.pooling_type}")
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        pooled_output = self.pooling(hidden_states, pooling_metadata)
         pooled_output = self.norm(self.act(self.dense(pooled_output)))
         return pooled_output
 
@@ -287,9 +288,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = ClassifierPooler(vllm_config.model_config,
-                                        self.classifier,
-                                        ModernBertPooler(config))
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=ModernBertPooler(config),
+            classifier=self.classifier,
+        )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
 
@@ -9,7 +9,7 @@ from torch import nn
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import ClassifierPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
@@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
 
-    def forward(self, features, **kwargs):
-        x = features[0, :]  # take <s> token (equiv. to [CLS])
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # CLSPool has already been applied in `pooling`
         x = self.dense(x)
         x = torch.tanh(x)
         x = self.out_proj(x)
@@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
                               add_pooling_layer=False)
         self.classifier = RobertaClassificationHead(config)
 
-        self._pooler = ClassifierPooler(vllm_config.model_config,
-                                        self.classifier)
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=CLSPool(),
+            classifier=self.classifier,
+        )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
@@ -17,7 +17,6 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                    HFValidationError, LocalEntryNotFoundError,
                                    RepositoryNotFoundError,
                                    RevisionNotFoundError)
-from torch import nn
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
@@ -44,7 +43,6 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
 # yapf: enable
 from vllm.transformers_utils.configs.mistral import adapt_config_dict
 from vllm.transformers_utils.utils import check_gguf_file
-from vllm.utils import resolve_obj_by_qualname
 
 if envs.VLLM_USE_MODELSCOPE:
     from modelscope import AutoConfig
@@ -775,28 +773,6 @@ def try_get_generation_config(
     return None
 
 
-def get_classification_activation_function(config: PretrainedConfig):
-    return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
-
-
-def get_cross_encoder_activation_function(config: PretrainedConfig):
-    function_name: Optional[str] = None
-    if (hasattr(config, "sentence_transformers")
-            and "activation_fn" in config.sentence_transformers):
-        function_name = config.sentence_transformers["activation_fn"]
-    elif (hasattr(config, "sbert_ce_default_activation_function")
-          and config.sbert_ce_default_activation_function is not None):
-        function_name = config.sbert_ce_default_activation_function
-
-    if function_name is not None:
-        assert function_name.startswith("torch.nn.modules."), (
-            "Loading of activation functions is restricted to "
-            "torch.nn.modules for security reasons")
-        return resolve_obj_by_qualname(function_name)()
-
-    return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
-
-
 def try_get_safetensors_metadata(
     model: str,
     *,
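The activation-function helpers deleted here were relocated into pooler.py, keeping the guard that only dotted paths under torch.nn.modules may be instantiated from checkpoint config. A standalone sketch of that resolution pattern (using importlib directly rather than vLLM's resolve_obj_by_qualname):

import importlib

def resolve_activation(qualname: str):
    # Only allow activation classes that live under torch.nn.modules,
    # mirroring the assert in get_cross_encoder_activation_function.
    assert qualname.startswith("torch.nn.modules."), (
        "Loading of activation functions is restricted to "
        "torch.nn.modules for security reasons")
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)()

act = resolve_activation("torch.nn.modules.activation.Sigmoid")
print(type(act).__name__)  # Sigmoid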