diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d864a915a0732..b378a3db03225 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,22 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import IntEnum -from typing import Optional, Union +from typing import Callable, Optional, TypeVar, Union import torch import torch.nn as nn import torch.nn.functional as F -from typing_extensions import assert_never +from transformers import PretrainedConfig from vllm.config import ModelConfig, PoolerConfig from vllm.model_executor.pooling_metadata import ( # noqa: E501 PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput -from vllm.transformers_utils.config import ( - get_classification_activation_function, - get_cross_encoder_activation_function) +from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] @@ -31,140 +30,202 @@ class PoolingType(IntEnum): MEAN = 4 -class SimplePooler(nn.Module): - """A layer that pools specific information from hidden states. +@dataclass(frozen=True) +class ResolvedPoolingConfig: + pooling_type: PoolingType - This layer does the following: - 1. Extracts specific tokens or aggregates data based on pooling method. - 2. Normalizes output if specified. - 3. Returns structured results as `PoolerOutput`. + normalize: bool + softmax: bool + step_tag_id: Optional[int] + returned_token_ids: Optional[list[int]] - Attributes: - pooling_type: The type of pooling to use. - normalize: Whether to normalize the pooled data. - """ - - @staticmethod - def from_pooling_type( + @classmethod + def from_config_with_defaults( + cls, + pooler_config: PoolerConfig, pooling_type: PoolingType, - *, normalize: bool, softmax: bool, step_tag_id: Optional[int] = None, returned_token_ids: Optional[list[int]] = None, - ) -> "SimplePooler": - if pooling_type == PoolingType.LAST: - assert step_tag_id is None and returned_token_ids is None - return LastPool(normalize=normalize, softmax=softmax) - if pooling_type == PoolingType.ALL: - assert step_tag_id is None and returned_token_ids is None - return AllPool(normalize=normalize, softmax=softmax) - if pooling_type == PoolingType.CLS: - assert step_tag_id is None and returned_token_ids is None - return CLSPool(normalize=normalize, softmax=softmax) - if pooling_type == PoolingType.MEAN: - assert step_tag_id is None and returned_token_ids is None - return MeanPool(normalize=normalize, softmax=softmax) - if pooling_type == PoolingType.STEP: - return StepPool(normalize=normalize, - softmax=softmax, - step_tag_id=step_tag_id, - returned_token_ids=returned_token_ids) + ) -> "ResolvedPoolingConfig": + return cls( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type, + normalize=pooler_config.normalize + if pooler_config.normalize is not None else normalize, + softmax=pooler_config.softmax + if pooler_config.softmax is not None else softmax, + step_tag_id=pooler_config.step_tag_id + if pooler_config.step_tag_id is not None else step_tag_id, + returned_token_ids=pooler_config.returned_token_ids + if pooler_config.returned_token_ids is not None else + returned_token_ids, + ) - assert_never(pooling_type) - def __init__(self, *, normalize: bool, softmax: bool) -> None: - super().__init__() +def get_prompt_lens( + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, +) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens - self.head = PoolerHead(normalize=normalize, softmax=softmax) + assert isinstance(hidden_states, torch.Tensor) + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens - def get_prompt_lens( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - assert isinstance(hidden_states, torch.Tensor) - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens - def extract_states( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - raise NotImplementedError +def get_classification_activation_function(config: PretrainedConfig): + return PoolerClassify() - def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput: - return PoolingSequenceGroupOutput(data) +def get_cross_encoder_activation_function(config: PretrainedConfig): + function_name: Optional[str] = None + if (hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers): + function_name = config.sentence_transformers["activation_fn"] + elif (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + function_name = config.sbert_ce_default_activation_function + + if function_name is not None: + assert function_name.startswith("torch.nn.modules."), ( + "Loading of activation functions is restricted to " + "torch.nn.modules for security reasons") + fn = resolve_obj_by_qualname(function_name)() + return PoolerActivation.wraps(fn) + + return PoolerScore() + + +def build_output(all_data: torch.Tensor) -> PoolerOutput: + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] + return PoolerOutput(outputs=all_outputs) + + +class BasePooler(nn.Module): + + @abstractmethod def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: - pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - pooled_outputs = [self.build_output(data) for data in pooled_data] - return PoolerOutput(outputs=pooled_outputs) + raise NotImplementedError -class CLSPool(SimplePooler): +class PoolingMethod(nn.Module, ABC): - def extract_states( + @staticmethod + def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": + if pooling_type == PoolingType.LAST: + return LastPool() + if pooling_type == PoolingType.ALL: + return AllPool() + if pooling_type == PoolingType.CLS: + return CLSPool() + if pooling_type == PoolingType.MEAN: + return MeanPool() + + raise NotImplementedError(f"Unsupported method: {pooling_type}") + + @abstractmethod + def forward_one( + self, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Note: + `prompt_len=None` means `prompt_len=len(hidden_states)`. + """ + raise NotImplementedError + + @abstractmethod + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: + raise NotImplementedError + + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) if isinstance(hidden_states, list): - result = [] - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with CLS pooling" - result.append(req_state[0]) - return result + return [ + self.forward_one(h, prompt_len) + for h, prompt_len in zip(hidden_states, prompt_lens) + ] + return self.forward_all(hidden_states, prompt_lens) + + +class CLSPool(PoolingMethod): + + def forward_one( + self, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with CLS pooling" + + return hidden_states[0] + + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: first_token_flat_indices = torch.zeros_like(prompt_lens) first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] return hidden_states[first_token_flat_indices] -class LastPool(SimplePooler): +class LastPool(PoolingMethod): - def extract_states( + def forward_one( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return hidden_states[-1] + + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, ) -> Union[list[torch.Tensor], torch.Tensor]: - if isinstance(hidden_states, list): - return [h[-1] for h in hidden_states] - - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) - last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 return hidden_states[last_token_flat_indices] -class AllPool(SimplePooler): +class AllPool(PoolingMethod): - def extract_states( + def forward_one( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with ALL pooling" + + return hidden_states + + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) - - if isinstance(hidden_states, list): - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with ALL pooling" - return hidden_states - offset = 0 pooled_data = list[torch.Tensor]() + for prompt_len in prompt_lens: pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len @@ -172,24 +233,23 @@ class AllPool(SimplePooler): return pooled_data -class MeanPool(SimplePooler): +class MeanPool(PoolingMethod): - def extract_states( + def forward_one( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with MEAN pooling" + + return hidden_states.mean(dim=0, dtype=torch.float32) + + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) - - if isinstance(hidden_states, list): - result = [] - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with mean pooling" - result.append(torch.mean(req_state, dim=0, - dtype=torch.float32)) - return result - # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) @@ -203,78 +263,127 @@ class MeanPool(SimplePooler): hidden_states[start_indices]) / prompt_lens.unsqueeze(1) -class StepPool(SimplePooler): +_T = TypeVar("_T", torch.Tensor, list[torch.Tensor]) - def __init__( - self, - *, - normalize: bool, - softmax: bool, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, - ): - super().__init__(normalize=normalize, softmax=softmax) - self.step_tag_id = step_tag_id - self.returned_token_ids = returned_token_ids +class BasePoolerActivation(nn.Module, ABC): - def get_prompt_token_ids( - self, - pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] - return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() - ] + @abstractmethod + def forward(self, pooled_data: _T) -> _T: + # shape: + # classify (& score) -> (batch_size, num_classes) + # embed -> (batch_size, embedding_dim) or list(embedding_dim) + # (batch_size, dimensions) or list(dimensions) if using MRL + raise NotImplementedError - def extract_states( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) - prompt_token_ids = self.get_prompt_token_ids(pooling_metadata) - pooled_data_lst = list[torch.Tensor]() - if isinstance(hidden_states, list): - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with step pooling" - pooled_data_lst = hidden_states - else: - offset = 0 - for prompt_len in prompt_lens: - pooled_data_i = hidden_states[offset:offset + prompt_len] - offset += prompt_len - pooled_data_lst.append(pooled_data_i) +class PoolerActivation(BasePoolerActivation): - pooled_data = list[torch.Tensor]() - returned_token_ids = self.returned_token_ids - step_tag_id = self.step_tag_id + @staticmethod + def wraps(module: nn.Module): + if isinstance(module, nn.Identity): + return PoolerIdentity() + if isinstance(module, (nn.Sigmoid, nn.Softmax)): + return PoolerClassify() - for data, token_id in zip(pooled_data_lst, prompt_token_ids): - if returned_token_ids is not None and len(returned_token_ids) > 0: - data = data[:, returned_token_ids] + return LambdaPoolerActivation(module) - if step_tag_id is not None: - data = data[token_id == step_tag_id] - pooled_data.append(data) + @abstractmethod + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def forward(self, pooled_data: _T) -> _T: + if isinstance(pooled_data, list): + return [self.forward_chunk(data) for data in pooled_data] + + return self.forward_chunk(pooled_data) + + +class PoolerIdentity(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return pooled_data + + +class PoolerNormalize(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + x = F.normalize(pooled_data.float(), p=2, dim=-1) + return x.to(pooled_data.dtype) + + +class PoolerClassify(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + num_labels = pooled_data.shape[-1] + if num_labels < 2: + return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + + return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) + + +class PoolerScore(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + num_labels = pooled_data.shape[-1] + if num_labels < 2: + return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return pooled_data -class PoolerHead(nn.Module): +class LambdaPoolerActivation(PoolerActivation): - def __init__(self, *, normalize: bool, softmax: bool) -> None: + def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): super().__init__() - self.normalize = normalize - self.softmax = softmax + self.fn = fn + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return self.fn(pooled_data) + + +class PoolerHead(nn.Module): + + @classmethod + def from_config_with_defaults( + cls, + pooler_config: PoolerConfig, + pooling_type: PoolingType, + normalize: bool, + softmax: bool, + ) -> "PoolerHead": + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=pooling_type, + normalize=normalize, + softmax=softmax, + step_tag_id=None, + returned_token_ids=None, + ) + + return cls.from_config(resolved_config) + + @classmethod + def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": + if pooler_config.normalize and pooler_config.softmax: + raise ValueError("`normalize=True` and `softmax=True` should not " + "be set together") + + activation: PoolerActivation + if pooler_config.normalize: + activation = PoolerNormalize() + elif pooler_config.softmax: + activation = PoolerClassify() + else: + activation = PoolerIdentity() + + return cls(activation) + + def __init__(self, activation: PoolerActivation) -> None: + super().__init__() + + self.activation = activation def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): @@ -312,35 +421,21 @@ class PoolerHead(nn.Module): for vecs, d in zip(pooled_data, dimensions_list) ] - if self.normalize: - if isinstance(pooled_data, list): - pooled_data = [ - F.normalize(data, p=2, dim=-1) for data in pooled_data - ] - else: - pooled_data = F.normalize(pooled_data, p=2, dim=-1) - - if self.softmax: - if isinstance(pooled_data, list): - pooled_data = [ - F.softmax(data, dim=-1) - if data.shape[-1] >= 2 else F.sigmoid(data) - for data in pooled_data - ] - else: - if pooled_data.shape[-1] >= 2: - pooled_data = F.softmax(pooled_data, dim=-1) - else: - pooled_data = F.sigmoid(pooled_data) - - # shape: - # classify (& score) -> (batch_size, num_classes) - # embed -> (batch_size, embedding_dim) or list(embedding_dim) - # (batch_size, dimensions) or list(dimensions) if using MRL - return pooled_data + return self.activation(pooled_data) -class Pooler(nn.Module): +class SimplePooler(BasePooler): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use. + normalize: Whether to normalize the pooled data. + """ @classmethod def from_config_with_defaults( @@ -349,23 +444,146 @@ class Pooler(nn.Module): pooling_type: PoolingType, normalize: bool, softmax: bool, + ) -> "SimplePooler": + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=pooling_type, + normalize=normalize, + softmax=softmax, + ) + assert resolved_config.pooling_type != PoolingType.STEP + + return cls.from_config(resolved_config) + + @classmethod + def from_config( + cls, + pooler_config: ResolvedPoolingConfig, + ) -> "SimplePooler": + pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) + head = PoolerHead.from_config(pooler_config) + + return cls(pooling, head) + + def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: + super().__init__() + + self.pooling = pooling + self.head = head + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) + + +class StepPooler(BasePooler): + + @classmethod + def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler": + assert pooler_config.pooling_type == PoolingType.STEP + + return cls( + PoolerHead.from_config(pooler_config), + step_tag_id=pooler_config.step_tag_id, + returned_token_ids=pooler_config.returned_token_ids, + ) + + def __init__( + self, + head: PoolerHead, + *, step_tag_id: Optional[int] = None, returned_token_ids: Optional[list[int]] = None, - ) -> SimplePooler: - return SimplePooler.from_pooling_type( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else pooling_type, - normalize=pooler_config.normalize - if pooler_config.normalize is not None else normalize, - softmax=pooler_config.softmax - if pooler_config.softmax is not None else softmax, - step_tag_id=pooler_config.step_tag_id - if pooler_config.step_tag_id is not None else step_tag_id, - returned_token_ids=pooler_config.returned_token_ids - if pooler_config.returned_token_ids is not None else - returned_token_ids, + ) -> None: + super().__init__() + + self.pooling = AllPool() + self.head = head + self.step_tag_id = step_tag_id + self.returned_token_ids = returned_token_ids + + def get_prompt_token_ids( + self, + pooling_metadata: PoolingMetadata, + ) -> list[torch.Tensor]: + if isinstance(pooling_metadata, V1PoolingMetadata): + return [ + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) + ] + return [ + torch.tensor(seq_data_i.prompt_token_ids) + for seq_data_i in pooling_metadata.seq_data.values() + ] + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + pooled_data_lst = self.pooling(hidden_states, pooling_metadata) + prompt_token_ids = self.get_prompt_token_ids(pooling_metadata) + + pooled_data = list[torch.Tensor]() + returned_token_ids = self.returned_token_ids + step_tag_id = self.step_tag_id + + for data, token_id in zip(pooled_data_lst, prompt_token_ids): + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] + + if step_tag_id is not None: + data = data[token_id == step_tag_id] + pooled_data.append(data) + + return pooled_data + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) + + +class Pooler(nn.Module): + + @staticmethod + def from_config_with_defaults( + pooler_config: PoolerConfig, + pooling_type: PoolingType, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[list[int]] = None, + ) -> BasePooler: + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=pooling_type, + normalize=normalize, + softmax=softmax, + step_tag_id=step_tag_id, + returned_token_ids=returned_token_ids, ) + if pooling_type == PoolingType.STEP: + return StepPooler.from_config(resolved_config) + + return SimplePooler.from_config(resolved_config) + + +PoolingFn = Callable[ + [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], + Union[torch.Tensor, list[torch.Tensor]]] +ClassifierFn = Callable[[torch.Tensor], torch.Tensor] + class ClassifierPooler(nn.Module): """A pooling layer for classification tasks. @@ -382,69 +600,39 @@ class ClassifierPooler(nn.Module): def __init__( self, config: ModelConfig, - classifier: nn.Module, - pooler: Optional[nn.Module] = None, - ): + pooling: PoolingFn, + classifier: ClassifierFn, + act_fn: Optional[PoolerActivation] = None, + ) -> None: super().__init__() + + self.pooling = pooling self.classifier = classifier - self.pooler = pooler self.classification_act_fn = get_classification_activation_function( - config.hf_config) + config.hf_config) if act_fn is None else act_fn self.cross_encoder_act_fn = get_cross_encoder_activation_function( - config.hf_config) + config.hf_config) if act_fn is None else act_fn def _get_act_fn(self, use_cross_encoder: bool): return (self.cross_encoder_act_fn if use_cross_encoder else self.classification_act_fn) - def get_prompt_lens( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - assert isinstance(hidden_states, torch.Tensor) - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens - def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: """Pools sentence pair scores from the hidden_states.""" - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + pooled_data = self.pooling(hidden_states, pooling_metadata) - pooled_data = list[torch.Tensor]() - if isinstance(hidden_states, list): - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with classifier" - pooled_data = hidden_states + # apply classifier once on the full batch if possible + if isinstance(pooled_data, torch.Tensor): + pooled_output = self.classifier(pooled_data) + elif len({data.shape for data in pooled_data}) <= 1: + pooled_output = self.classifier(torch.stack(pooled_data)) else: - offset = 0 - for prompt_len in prompt_lens: - pooled_data_i = hidden_states[offset:offset + prompt_len] - offset += prompt_len - pooled_data.append(pooled_data_i) - - pooled_data_lst = [] - for pooled_data_i in pooled_data: - - if self.pooler is not None: - final_shape_tensor = self.pooler(pooled_data_i) - else: - final_shape_tensor = self.classifier(pooled_data_i) - - pooled_data_lst.append(final_shape_tensor) - - pooled_output = torch.stack(pooled_data_lst) - - if self.pooler is not None: - # apply classifier once on the full batch if possible - pooled_output = self.classifier(pooled_output) + pooled_output = [self.classifier(data) for data in pooled_data] if isinstance(pooling_metadata, V0PoolingMetadata): use_cross_encoder_list = [ @@ -469,5 +657,4 @@ class ClassifierPooler(nn.Module): pooled_output) ]) - pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] - return PoolerOutput(outputs=pooled_outputs) + return build_output(scores) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index dcdf69f773ad4..5c09ac306052b 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -58,22 +58,27 @@ def _create_pooling_model_cls( ) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config + # These are not used in pooling models for attr in ("lm_head", "logits_processor"): if hasattr(self, attr): delattr(self, attr) + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "_pooler", None): + self._init_pooler(vllm_config, prefix=prefix) + + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - # If the model already defines a pooler instance, don't overwrite it - if not getattr(self, "_pooler", None): - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + ) def pooler( self, @@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import RowParallelLinear - from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType + from vllm.model_executor.layers.pooler import (ClassifierPooler, + PoolerOutput, PoolingType, + SimplePooler) from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors @@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T: class ModelForSequenceClassification(ModelForPooling, SupportsCrossEncoding): - def __init__( - self, - *, - vllm_config: "VllmConfig", - prefix: str = "", - **kwargs: Any, - ) -> None: - super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.vllm_config = vllm_config - self.task = vllm_config.model_config.task - self.pooling_type = ( - vllm_config.model_config.pooler_config.pooling_type) + self.score = RowParallelLinear( + config.hidden_size, + config.num_labels, + input_is_parallel=False, + bias=False, + params_dtype=torch.float32, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "score"), + ) - self.score = RowParallelLinear(config.hidden_size, - config.num_labels, - quant_config=quant_config, - input_is_parallel=False, - bias=False, - prefix=maybe_prefix( - prefix, "score")) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + pooler = SimplePooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=True, + ) + + self._pooler = ClassifierPooler( + vllm_config.model_config, + pooling=pooler.pooling, + classifier=self._classifier, + act_fn=pooler.head.activation, + ) + + def _classifier(self, x: torch.Tensor): + x, _ = self.score(x.float()) + return x def forward( self, @@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T: hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: - - def get_logits(hidden_states): - if isinstance(hidden_states, list): - logits = [self.score(state)[0] for state in hidden_states] - else: - logits, _ = self.score(hidden_states) - return logits - - if self.pooling_type == PoolingType.ALL: - logits = get_logits(hidden_states) - return self._pooler(logits, pooling_metadata) - else: - hidden_states = self._pooler.extract_states( - hidden_states, pooling_metadata) - logits = get_logits(hidden_states) - pooled_data = self._pooler.head(logits, pooling_metadata) - - pooled_outputs = [ - self._pooler.build_output(data) for data in pooled_data - ] - return PoolerOutput(outputs=pooled_outputs) + return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): tokens = getattr(self.config, "classifier_from_token", None) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index a43803ed43333..65e6428f4912c 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional +from typing import Optional, Union import torch from torch import nn @@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, - PoolingType) + PoolingMethod, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -84,14 +84,18 @@ class BertPooler(nn.Module): def __init__(self, config: BertConfig): super().__init__() + + self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS) self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[0, :] - pooled_output = self.dense(first_token_tensor) + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + pooled_output = self.pooling(hidden_states, pooling_metadata) + pooled_output = self.dense(pooled_output) pooled_output = self.activation(pooled_output) return pooled_output @@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, embedding_class=BertEmbedding, add_pooling_layer=True) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = ClassifierPooler(vllm_config.model_config, - self.classifier, self.bert.pooler) + self._pooler = ClassifierPooler( + vllm_config.model_config, + pooling=self.bert.pooler, + classifier=self.classifier, + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 4273afbf46998..dfec8a51c4c20 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -9,7 +9,7 @@ import torch.nn as nn from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import PoolerHead +from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) @@ -49,7 +49,7 @@ class GritLMPooler(nn.Module): self.embed_pattern_ids = tokens_to_ids( ["▁<", "|", "embed", "|", ">", "<0x0A>"]) - self.head = PoolerHead(normalize=True, softmax=False) + self.head = PoolerHead(PoolerNormalize()) def _find_array(self, arr: array, target: array, start_idx: int) -> int: """ diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 92ecb8972d55b..9655bdf6f3e3a 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -659,7 +659,7 @@ def supports_cross_encoding( def has_step_pooler(model: Union[type[object], object]) -> bool: """Check if the model uses step pooler.""" return is_pooling_model(model) and any( - type(module).__name__ == "StepPool" for module in model.modules()) + type(module).__name__ == "StepPooler" for module in model.modules()) class SupportsQuant: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 8294f846bbd10..233c222963be2 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType, + SimplePooler) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config num_labels: int = config.num_labels score_bias: bool = getattr(config, 'score_bias', False) - self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias) + + # TODO: The original reward weights have float32 accuracy data, we + # would like to load them in fp32 to get that extra precision. + # Currently weight_loader passes the weight which is already in bf16 + self.score = nn.Linear( + config.hidden_size, + num_labels, + bias=score_bias, + dtype=torch.float32, + ) pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( + assert pooler_config is not None + + pooler = SimplePooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, normalize=False, - softmax=False) + softmax=False, + ) + + self._pooler = ClassifierPooler( + vllm_config.model_config, + pooling=pooler.pooling, + classifier=self.score, + act_fn=pooler.head.activation, + ) def pooler( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - hidden_states = hidden_states.float() - logits = self.score(hidden_states) - return self._pooler(logits, pooling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - # TODO: The reward weights themselves have float32 accuracy data, we - # would like to load them in fp32 to get that extra precision. - super().load_weights(weights) - self.score = self.score.float() + return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 9d619b38d38d6..e094ff1635720 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional +from typing import Optional, Union import torch from torch import nn @@ -13,7 +13,8 @@ from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import ClassifierPooler +from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler, + PoolingMethod, PoolingType) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -252,10 +253,13 @@ class ModernBertModel(nn.Module): return norm_outputs -class ModernBertPooler(nn.Module): +class ModernBertPooler(BasePooler): def __init__(self, config: ModernBertConfig): super().__init__() + + pooling_type = PoolingType[config.classifier_pooling.upper()] + self.pooling = PoolingMethod.from_pooling_type(pooling_type) self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) self.pooling_type = config.classifier_pooling @@ -264,15 +268,12 @@ class ModernBertPooler(nn.Module): eps=config.norm_eps, bias=config.norm_bias) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - pooled_output = hidden_states - if self.pooling_type == "mean": - pooled_output = pooled_output.mean(dim=0, keepdim=False) - elif self.pooling_type == "cls": - pooled_output = pooled_output[0, :] - else: - raise ValueError("Pooling type should be either `cls` or `mean`, " - f"but got {self.pooling_type}") + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + pooled_output = self.pooling(hidden_states, pooling_metadata) pooled_output = self.norm(self.act(self.dense(pooled_output))) return pooled_output @@ -287,9 +288,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = ClassifierPooler(vllm_config.model_config, - self.classifier, - ModernBertPooler(config)) + self._pooler = ClassifierPooler( + vllm_config.model_config, + pooling=ModernBertPooler(config), + classifier=self.classifier, + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 1d3a23a5e5445..55ebb6e9e2a44 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,7 +9,7 @@ from torch import nn from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import ClassifierPooler +from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel @@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) - def forward(self, features, **kwargs): - x = features[0, :] # take token (equiv. to [CLS]) + def forward(self, x: torch.Tensor) -> torch.Tensor: + # CLSPool has already been applied in `pooling` x = self.dense(x) x = torch.tanh(x) x = self.out_proj(x) @@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, add_pooling_layer=False) self.classifier = RobertaClassificationHead(config) - self._pooler = ClassifierPooler(vllm_config.model_config, - self.classifier) + self._pooler = ClassifierPooler( + vllm_config.model_config, + pooling=CLSPool(), + classifier=self.classifier, + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index cf3f519b027ca..db8f675bcc5ee 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -17,7 +17,6 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, HFValidationError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) -from torch import nn from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -44,7 +43,6 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, # yapf: enable from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import resolve_obj_by_qualname if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -775,28 +773,6 @@ def try_get_generation_config( return None -def get_classification_activation_function(config: PretrainedConfig): - return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax() - - -def get_cross_encoder_activation_function(config: PretrainedConfig): - function_name: Optional[str] = None - if (hasattr(config, "sentence_transformers") - and "activation_fn" in config.sentence_transformers): - function_name = config.sentence_transformers["activation_fn"] - elif (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): - function_name = config.sbert_ce_default_activation_function - - if function_name is not None: - assert function_name.startswith("torch.nn.modules."), ( - "Loading of activation functions is restricted to " - "torch.nn.modules for security reasons") - return resolve_obj_by_qualname(function_name)() - - return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() - - def try_get_safetensors_metadata( model: str, *,