diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index f0de84a66f8b0..eef8f20e4e5c6 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -11,26 +11,51 @@ before returning them. As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to pooling models as they only work on the generation or decode stage, so performance may not improve as much. -For pooling models, we support the following `--task` options. -The selected option sets the default pooler used to extract the final hidden states: +If the model doesn't implement this interface, you can set `--task` which tells vLLM +to convert the model into a pooling model. -| Task | Pooling Type | Normalization | Softmax | -|---------------------------------|----------------|-----------------|-----------| -| Embedding (`embed`) | `LAST` | ✅︎ | ❌ | -| Classification (`classify`) | `LAST` | ❌ | ✅︎ | -| Sentence Pair Scoring (`score`) | \* | \* | \* | +| `--task` | Model type | Supported pooling tasks | +|------------|----------------------|-------------------------------| +| `embed` | Embedding model | `encode`, `embed` | +| `classify` | Classification model | `encode`, `classify`, `score` | +| `reward` | Reward model | `encode` | -\*The default pooler is always defined by the model. +## Pooling Tasks -!!! note - If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table. +In vLLM, we define the following pooling tasks and corresponding APIs: + +| Task | APIs | +|------------|--------------------| +| `encode` | `encode` | +| `embed` | `embed`, `score`\* | +| `classify` | `classify` | +| `score` | `score` | + +\*The `score` API falls back to `embed` task if the model does not support `score` task. + +Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.Pooler.get_supported_tasks]. + +By default, the pooler assigned to each task has the following attributes: + +| Task | Pooling Type | Normalization | Softmax | +|------------|----------------|---------------|---------| +| `encode` | `ALL` | ❌ | ❌ | +| `embed` | `LAST` | ✅︎ | ❌ | +| `classify` | `LAST` | ❌ | ✅︎ | + +These defaults may be overridden by the model's implementation in vLLM. When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, -we attempt to override the default pooler based on its Sentence Transformers configuration file (`modules.json`). +we attempt to override the defaults based on its Sentence Transformers configuration file (`modules.json`), +which takes priority over the model's defaults. -!!! tip - You can customize the model's pooling method via the `--override-pooler-config` option, - which takes priority over both the model's and Sentence Transformers's defaults. +You can further customize this via the `--override-pooler-config` option, +which takes priority over both the model's and Sentence Transformers's defaults. + +!!! note + + The above configuration may be disregarded if the model's implementation in vLLM defines its own pooler + that is not based on [PoolerConfig][vllm.config.PoolerConfig]. 
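To make the task-to-API mapping in the rewritten docs concrete, here is a minimal offline sketch. It is an illustration only, not part of this change: the BGE model name and the `override_pooler_config` keyword are assumptions about the surrounding vLLM API, while the classification model name is taken from the test file touched below.

```python
# Minimal sketch of the task/API mapping described in the docs above.
# Model names and the override_pooler_config argument are illustrative
# assumptions; adjust them to your own setup.
from vllm import LLM
from vllm.config import PoolerConfig

# "embed" task -> embed() API (score() falls back to embed-based scoring)
embedder = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
vectors = embedder.embed(["What is the capital of France?"])
print(len(vectors[0].outputs.embedding))

# "classify" task -> classify() and score() APIs
classifier = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
probs = classifier.classify(["vLLM makes pooling models easy to serve."])
print(probs[0].outputs.probs)

# Overriding the default pooler; this takes priority over both the
# model's and the Sentence Transformers defaults.
custom = LLM(
    model="BAAI/bge-base-en-v1.5",
    task="embed",
    override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=True),
)
```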
## Offline Inference diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b87290e96a27e..16b9bcffd2650 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -144,7 +144,7 @@ def test_quantization( "model", ["jason9693/Qwen2.5-1.5B-apeach"], ) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["float"]) def test_classify( hf_runner, vllm_runner, diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 797353e4f7a8b..fc654f20fff22 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.models.gemma2 import Gemma2Model from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.sequence import IntermediateTensors @@ -26,12 +26,13 @@ class MyGemma2Embedding(nn.Module): self.model = Gemma2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.pooler = Pooler.from_config_with_defaults( - vllm_config.model_config.pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False, - ) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + }) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/config.py b/vllm/config.py index 44106dd279b6b..4cafbc9260525 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -94,7 +94,7 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType) TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", "score", "reward", "transcription", "draft"] -_ResolvedTask = Literal["generate", "transcription", "pooling", "embed", +_ResolvedTask = Literal["generate", "transcription", "encode", "embed", "classify", "reward", "draft"] RunnerOption = Literal["auto", "generate", "pooling", "draft"] @@ -103,7 +103,7 @@ RunnerType = Literal["generate", "pooling", "draft"] _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { "generate": ["generate", "transcription"], - "pooling": ["pooling", "embed", "classify", "reward"], + "pooling": ["encode", "embed", "classify", "reward"], "draft": [], } @@ -579,7 +579,7 @@ class ModelConfig: # user-selected task if runner_type == "pooling" and self.task == "auto": selected_task = all_supported_tasks[runner_type][-1] - assert selected_task != "pooling" + assert selected_task != "encode" self.task = selected_task self.supported_runner_types = supported_runner_types self.runner_type = runner_type @@ -884,7 +884,7 @@ class ModelConfig: supported_tasks = list[_ResolvedTask]() if registry.is_pooling_model(architectures): - supported_tasks.append("pooling") + supported_tasks.append("encode") # For now, users must specify the task (other than "pooling") # to use for pooling models diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3f0c1c85dee61..57240bb4f3330 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py 
@@ -1668,7 +1668,7 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, - ) if "pooling" in model_config.supported_tasks else None + ) if "encode" in model_config.supported_tasks else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 6a474b8e73a35..c06cca080227e 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Mapping, Set from dataclasses import dataclass from enum import IntEnum +from itertools import groupby from typing import Callable, Optional, TypeVar, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig -from typing_extensions import assert_never from vllm.config import ModelConfig, PoolerConfig from vllm.model_executor.pooling_metadata import ( # noqa: E501 @@ -21,6 +22,10 @@ from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] +PoolingFn = Callable[ + [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], + Union[torch.Tensor, list[torch.Tensor]]] +ClassifierFn = Callable[[torch.Tensor], torch.Tensor] class PoolingType(IntEnum): @@ -79,37 +84,81 @@ class Pooler(nn.Module, ABC): """The interface required for all poolers used in pooling models in vLLM.""" @staticmethod - def from_config_with_defaults( + def for_encode( pooler_config: PoolerConfig, - pooling_type: PoolingType, - normalize: bool, - softmax: bool, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, - ) -> "Pooler": + *, + default_pooling_type: PoolingType = PoolingType.ALL, + default_normalize: bool = False, + default_softmax: bool = False, + default_step_tag_id: Optional[int] = None, + default_returned_token_ids: Optional[list[int]] = None, + ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( pooler_config=pooler_config, - pooling_type=pooling_type, - normalize=normalize, - softmax=softmax, - step_tag_id=step_tag_id, - returned_token_ids=returned_token_ids, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + step_tag_id=default_step_tag_id, + returned_token_ids=default_returned_token_ids, ) - if pooling_type == PoolingType.STEP: + if resolved_config.pooling_type == PoolingType.STEP: return StepPooler.from_config(resolved_config) return SimplePooler.from_config(resolved_config) - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: + @staticmethod + def for_embed( + pooler_config: PoolerConfig, + *, + default_pooling_type: PoolingType = PoolingType.LAST, + default_normalize: bool = True, + default_softmax: bool = False, + ): + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + ) + + return SimplePooler.from_config(resolved_config) + + @staticmethod + def for_classify( + pooler_config: PoolerConfig, + classifier: Optional[ClassifierFn], + *, + default_pooling_type: PoolingType = 
PoolingType.LAST, + default_normalize: bool = False, + default_softmax: bool = True, + ): + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + ) + base_pooler = SimplePooler.from_config(resolved_config) + if classifier is None: + return base_pooler + + return ClassifierPooler( + pooling=base_pooler.pooling, + classifier=classifier, + act_fn=base_pooler.head.activation, + ) + + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + """Determine which pooling tasks are supported.""" + raise NotImplementedError + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: """ - Construct the pooling parameters to use for a task, - or `None` if the task is not supported. + Construct the updated pooling parameters to use for a supported task. """ - return None + return PoolingParamsUpdate() @abstractmethod def forward( @@ -127,9 +176,8 @@ def get_prompt_lens( if isinstance(pooling_metadata, V1PoolingMetadata): return pooling_metadata.prompt_lens - assert isinstance(hidden_states, torch.Tensor) return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens + pooling_metadata, hidden_states[0].device).prompt_lens def get_prompt_token_ids( @@ -149,6 +197,21 @@ def get_prompt_token_ids( ] +def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: + if isinstance(pooling_metadata, V0PoolingMetadata): + pooling_params = [p for _, p in pooling_metadata.seq_groups] + else: + pooling_params = pooling_metadata.pooling_params + + tasks: list[PoolingTask] = [ + task for pooling_param in pooling_params + if (task := pooling_param.task) is not None + ] + assert len(pooling_params) == len(tasks) + + return tasks + + def get_classification_activation_function(config: PretrainedConfig): return PoolerClassify() @@ -172,7 +235,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): return PoolerScore() -def build_output(all_data: torch.Tensor) -> PoolerOutput: +def build_output( + all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] return PoolerOutput(outputs=all_outputs) @@ -193,12 +257,12 @@ class PoolingMethod(nn.Module, ABC): raise NotImplementedError(f"Unsupported method: {pooling_type}") @abstractmethod - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: + def get_supported_tasks(self) -> Set[PoolingTask]: raise NotImplementedError + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate() + @abstractmethod def forward_one( self, @@ -237,16 +301,8 @@ class PoolingMethod(nn.Module, ABC): class CLSPool(PoolingMethod): - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: - # The equalities are split up to keep mypy happy - if (task == "encode" or task == "embed" or task == "classify" - or task == "score"): - return PoolingParamsUpdate() - - assert_never(task) + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} def forward_one( self, @@ -270,16 +326,8 @@ class CLSPool(PoolingMethod): class LastPool(PoolingMethod): - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: - # The equalities are split up to keep mypy happy - if (task == "encode" or task == "embed" or task == 
"classify" - or task == "score"): - return PoolingParamsUpdate() - - assert_never(task) + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} def forward_one( self, @@ -299,18 +347,8 @@ class LastPool(PoolingMethod): class AllPool(PoolingMethod): - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: - if task == "encode": - return PoolingParamsUpdate() - - # The equalities are split up to keep mypy happy - if task == "embed" or task == "classify" or task == "score": - return None - - assert_never(task) + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode"} def forward_one( self, @@ -327,28 +365,13 @@ class AllPool(PoolingMethod): hidden_states: torch.Tensor, prompt_lens: torch.Tensor, ) -> Union[list[torch.Tensor], torch.Tensor]: - offset = 0 - pooled_data = list[torch.Tensor]() - - for prompt_len in prompt_lens: - pooled_data.append(hidden_states[offset:offset + prompt_len]) - offset += prompt_len - - return pooled_data + return list(hidden_states.split_with_sizes(prompt_lens.tolist())) class MeanPool(PoolingMethod): - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: - # The equalities are split up to keep mypy happy - if (task == "encode" or task == "embed" or task == "classify" - or task == "score"): - return PoolingParamsUpdate() - - assert_never(task) + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} def forward_one( self, @@ -529,24 +552,6 @@ class SimplePooler(Pooler): 3. Returns structured results as `PoolerOutput`. """ - @classmethod - def from_config_with_defaults( # type: ignore[override] - cls, - pooler_config: PoolerConfig, - pooling_type: PoolingType, - normalize: bool, - softmax: bool, - ) -> "SimplePooler": - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( - pooler_config=pooler_config, - pooling_type=pooling_type, - normalize=normalize, - softmax=softmax, - ) - assert resolved_config.pooling_type != PoolingType.STEP - - return cls.from_config(resolved_config) - @classmethod def from_config( cls, @@ -563,10 +568,10 @@ class SimplePooler(Pooler): self.pooling = pooling self.head = head - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return self.pooling.get_pooling_updates(task) def forward( @@ -627,18 +632,11 @@ class StepPooler(Pooler): return pooled_data - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: - if task == "encode": - return PoolingParamsUpdate(requires_token_ids=True) + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode"} - # The equalities are split up to keep mypy happy - if task == "embed" or task == "classify" or task == "score": - return None - - assert_never(task) + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) def forward( self, @@ -650,68 +648,43 @@ class StepPooler(Pooler): return build_output(pooled_data) -PoolingFn = Callable[ - [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], - Union[torch.Tensor, list[torch.Tensor]]] -ClassifierFn = Callable[[torch.Tensor], torch.Tensor] - - -class ClassifierPooler(nn.Module): +class ClassifierPooler(Pooler): """A pooling layer for 
classification tasks. This layer does the following: 1. Applies a classification layer to the hidden states. 2. Optionally applies a pooler layer. - 3. Applies an activation function to the output. In the case of - classification models it is either sigmoid or softmax. In the - case of scoring models, the same behavior is configuration - dependent, as in the sentence-transformers library. + 3. Applies an activation function to the output. """ + @staticmethod + def act_fn_for_seq_cls(config: ModelConfig): + return get_classification_activation_function(config.hf_config) + + @staticmethod + def act_fn_for_cross_encoder(config: ModelConfig): + return get_cross_encoder_activation_function(config.hf_config) + def __init__( self, - config: ModelConfig, pooling: PoolingFn, classifier: ClassifierFn, - act_fn: Optional[PoolerActivation] = None, + act_fn: PoolerActivation, ) -> None: super().__init__() self.pooling = pooling self.classifier = classifier + self.act_fn = act_fn - self.classification_act_fn = get_classification_activation_function( - config.hf_config) if act_fn is None else act_fn - self.cross_encoder_act_fn = get_cross_encoder_activation_function( - config.hf_config) if act_fn is None else act_fn - - def _get_act_fn(self, task: PoolingTask): - if task == "encode" or task == "classify": - return self.classification_act_fn - if task == "score": - return self.cross_encoder_act_fn - - raise ValueError(f"Unsupported task: {task!r}") - - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: - # The equalities are split up to keep mypy happy - if task == "encode" or task == "classify" or task == "score": - return PoolingParamsUpdate() - - if task == "embed": - return None - - assert_never(task) + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"classify", "score"} def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: - """Pools sentence pair scores from the hidden_states.""" pooled_data = self.pooling(hidden_states, pooling_metadata) # apply classifier once on the full batch if possible @@ -722,28 +695,59 @@ class ClassifierPooler(nn.Module): else: pooled_output = [self.classifier(data) for data in pooled_data] - task_list: list[PoolingTask] - if isinstance(pooling_metadata, V0PoolingMetadata): - task_list = [ - task for _, pooling_param in pooling_metadata.seq_groups - if (task := pooling_param.task) is not None - ] - else: - task_list = [ - task for pooling_param in pooling_metadata.pooling_params - if (task := pooling_param.task) is not None - ] - - assert len(task_list) == len(pooled_output) - - # shape of scores: (batch_size, num_labels) - if len(set(task_list)) <= 1: - act_fn = self._get_act_fn(task_list[0]) - scores = act_fn(pooled_output) - else: - scores = torch.stack([ - self._get_act_fn(task)(vecs) - for task, vecs in zip(task_list, pooled_output) - ]) + scores = self.act_fn(pooled_output) return build_output(scores) + + +class DispatchPooler(Pooler): + """Dispatches calls to a sub-pooler based on the pooling task.""" + + def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None: + super().__init__() + + for task, pooler in poolers_by_task.items(): + if task not in pooler.get_supported_tasks(): + raise ValueError( + f"{pooler=} does not support {task=}. 
" + f"Supported tasks: {pooler.get_supported_tasks()}") + + self.poolers_by_task = poolers_by_task + + def get_supported_tasks(self) -> Set[PoolingTask]: + return set(self.poolers_by_task) + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.poolers_by_task[task].get_pooling_updates(task) + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + poolers_by_task = self.poolers_by_task + + if isinstance(hidden_states, list): + hidden_states_lst = hidden_states + else: + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + hidden_states_lst = list(hidden_states.split(prompt_lens.tolist())) + + outputs = list[PoolingSequenceGroupOutput]() + offset = 0 + for task, group in groupby(get_tasks(pooling_metadata)): + if not (pooler := poolers_by_task.get(task)): + raise ValueError( + f"Unsupported task: {task} " + f"Supported tasks: {self.get_supported_tasks()}") + + num_items = len(list(group)) + group_output: PoolerOutput = pooler( + hidden_states_lst[offset:offset + num_items], + pooling_metadata[offset:offset + num_items], + ) + + outputs.extend(group_output.outputs) + offset += num_items + + return PoolerOutput(outputs) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 31b1d9a8b3c0d..867de2c68b4c5 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -13,7 +13,6 @@ from .interfaces_base import VllmModelForPooling, is_pooling_model if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.model_executor.layers.pooler import PoolingType _T = TypeVar("_T", bound=type[nn.Module]) @@ -34,16 +33,8 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: return model_name + pooling_suffix -def _create_pooling_model_cls( - orig_cls: _T, - *, - default_pooling_type: "PoolingType", - default_normalize: bool, - default_softmax: bool, -) -> _T: +def _create_pooling_model_cls(orig_cls: _T) -> _T: # Lazy import - from vllm.model_executor.layers.pooler import Pooler - from .utils import AutoWeightsLoader, WeightsMapper class ModelForPooling(orig_cls, VllmModelForPooling): @@ -71,15 +62,7 @@ def _create_pooling_model_cls( self._init_pooler(vllm_config, prefix=prefix) def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self.pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, - ) + raise NotImplementedError def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # TODO: Support uninitialized params tracking @@ -132,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.pooler import PoolingType + from vllm.model_executor.layers.pooler import DispatchPooler, Pooler + + class ModelForEmbedding(_create_pooling_model_cls(cls)): + + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + }, ) - ModelForEmbedding = _create_pooling_model_cls( - cls, - default_pooling_type=PoolingType.LAST, - default_normalize=True, - default_softmax=False, - ) 
ModelForEmbedding.__name__ = \ _get_pooling_model_name(cls.__name__, "ForEmbedding") @@ -165,20 +154,14 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.pooler import (ClassifierPooler, - PoolingType, SimplePooler) + DispatchPooler, Pooler, + PoolingMethod, PoolingType) from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.sequence import IntermediateTensors from .utils import maybe_prefix - ModelForPooling = _create_pooling_model_cls( - cls, - default_pooling_type=PoolingType.LAST, - default_normalize=False, - default_softmax=True, - ) - - class ModelForSequenceClassification(ModelForPooling, + class ModelForSequenceClassification(_create_pooling_model_cls(cls), SupportsCrossEncoding): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): @@ -198,19 +181,28 @@ def as_seq_cls_model(cls: _T) -> _T: pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - pooler = SimplePooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=False, - softmax=True, - ) + pooling_type_str = pooler_config.pooling_type + pooling_type = (PoolingType.LAST if pooling_type_str is None else + PoolingType[pooling_type_str]) - self.pooler = ClassifierPooler( - vllm_config.model_config, - pooling=pooler.pooling, - classifier=self._classifier, - act_fn=pooler.head.activation, - ) + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) def _classifier(self, x: torch.Tensor): x, _ = self.score(x.float()) @@ -259,14 +251,16 @@ def as_reward_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.pooler import PoolingType + from vllm.model_executor.layers.pooler import DispatchPooler, Pooler - ModelForReward = _create_pooling_model_cls( - cls, - default_pooling_type=PoolingType.ALL, - default_normalize=False, - default_softmax=False, - ) + class ModelForReward(_create_pooling_model_cls(cls)): + + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) ModelForReward.__name__ = \ _get_pooling_model_name(cls.__name__, "ForReward") diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 006f547bb4617..9dc6115f850ec 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable +from collections.abc import Iterable, Set from typing import Optional, Union import torch @@ -17,7 +17,8 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, +from 
vllm.model_executor.layers.pooler import (ClassifierPooler, + DispatchPooler, Pooler, PoolingMethod, PoolingParamsUpdate, PoolingType) @@ -92,20 +93,29 @@ class BertPooler(Pooler): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return self.pooling.get_pooling_updates(task) + def _head(self, pooled_output: torch.Tensor): + pooled_output = self.dense(pooled_output) + pooled_output = self.activation(pooled_output) + return pooled_output + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[torch.Tensor, list[torch.Tensor]]: pooled_output = self.pooling(hidden_states, pooling_metadata) - pooled_output = self.dense(pooled_output) - pooled_output = self.activation(pooled_output) + + if isinstance(pooled_output, list): + pooled_output = [self._head(output) for output in pooled_output] + else: + pooled_output = self._head(pooled_output) + return pooled_output @@ -333,18 +343,19 @@ class BertModel(nn.Module, SupportsQuant): packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - embedding_class: type = BertEmbedding, - add_pooling_layer: bool = False): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + embedding_class: type[nn.Module] = BertEmbedding, + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config self.embeddings = embedding_class(config) self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") - self.pooler = BertPooler(config) if add_pooling_layer else None def forward( self, @@ -366,8 +377,7 @@ class BertModel(nn.Module, SupportsQuant): token_type_ids=token_type_ids) return self.encoder(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "query", "q"), @@ -395,10 +405,43 @@ class BertModel(nn.Module, SupportsQuant): if name in params_dict: other_weights.append((name, loaded_weight)) - loader = AutoWeightsLoader( - self, - skip_prefixes=(["pooler."] if self.pooler is None else []), + return other_weights, loaded_stacked_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + other_weights, loaded_stacked_params = self._load_weights(weights) + + loader = AutoWeightsLoader(self, skip_prefixes=["pooler."]) + loaded_params = loader.load_weights(other_weights) + loaded_params.update(loaded_stacked_params) + return loaded_params + + +class BertPoolingModel(BertModel): + + is_pooling_model = True + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + embedding_class: type[nn.Module] = BertEmbedding, + ) -> None: + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + embedding_class=embedding_class, ) + + config = vllm_config.model_config.hf_config + self.pooler = BertPooler(config) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + other_weights, loaded_stacked_params = self._load_weights(weights) + + loader = AutoWeightsLoader(self) loaded_params = 
loader.load_weights(other_weights) loaded_params.update(loaded_stacked_params) return loaded_params @@ -421,6 +464,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): super().__init__() pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.model = self._build_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.pooler = self._build_pooler(pooler_config) @@ -456,10 +501,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): embedding_class=BertEmbedding) def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: - return Pooler.from_config_with_defaults(pooler_config, - pooling_type=PoolingType.CLS, - normalize=True, - softmax=False) + return DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "embed": + Pooler.for_embed( + pooler_config, + default_pooling_type=PoolingType.CLS, + ), + }) class BertForSequenceClassification(nn.Module, SupportsV0Only, @@ -481,16 +531,32 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, config = vllm_config.model_config.hf_config self.num_labels = config.num_labels - self.bert = BertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "bert"), - embedding_class=BertEmbedding, - add_pooling_layer=True) + self.bert = BertPoolingModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.pooler = ClassifierPooler( - vllm_config.model_config, - pooling=self.bert.pooler, - classifier=self.classifier, - ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 82883bfa890de..98d76337395b9 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from ..layers.pooler import Pooler, PoolingType +from ..layers.pooler import DispatchPooler, Pooler from .interfaces import SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -339,12 +339,16 @@ class GPT2ForSequenceClassification(nn.Module): self.transformer = GPT2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt2")) self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) + pooler_config = vllm_config.model_config.pooler_config - self.pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=False, - softmax=True) + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + Pooler.for_classify(pooler_config, classifier=None), + }) def 
load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 8443482119b0c..8a3fbc6a49f04 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,17 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from collections.abc import Set from typing import Optional, Union import numpy as np import torch import torch.nn as nn -from typing_extensions import assert_never from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import (Pooler, PoolerHead, - PoolerNormalize, +from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, + PoolerHead, PoolerNormalize, PoolingParamsUpdate, build_output, get_prompt_lens, get_prompt_token_ids) @@ -135,18 +134,11 @@ class GritLMMeanPool(nn.Module): return instruction_len - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: - # The equalities are split up to keep mypy happy - if task == "encode" or task == "embed": - return PoolingParamsUpdate(requires_token_ids=True) + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed"} - if task == "classify" or task == "score": - return None - - assert_never(task) + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) def forward_one( self, @@ -207,10 +199,10 @@ class GritLMPooler(Pooler): self.pooling = GritLMMeanPool(model_config) self.head = PoolerHead(PoolerNormalize()) - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return self.pooling.get_pooling_updates(task) def forward( @@ -262,4 +254,11 @@ class GritLM(LlamaForCausalLM, SupportsV0Only): super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - self.pooler = GritLMPooler(vllm_config.model_config) + pooler_config = vllm_config.model_config.pooler_config + if pooler_config is not None: + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "embed": + GritLMPooler(vllm_config.model_config), + }) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d9bbee0a2463c..d29779a35e5c9 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -429,12 +429,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ) pooler_config = vllm_config.model_config.pooler_config - self.pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.ALL, - normalize=False, - softmax=False, - ) + assert pooler_config is not None + + 
self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) def forward( self, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index e95f3491c6b6e..34281b2e99ee8 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -19,8 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer -from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType, - SimplePooler) +from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -584,16 +584,15 @@ class JambaForSequenceClassification(JambaForCausalLM): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - pooler = SimplePooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=False, - softmax=False, - ) - - self.pooler = ClassifierPooler( - vllm_config.model_config, - pooling=pooler.pooling, - classifier=self.score, - act_fn=pooler.head.activation, - ) + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + Pooler.for_classify( + pooler_config, + classifier=self.score, + default_pooling_type=PoolingType.LAST, + default_normalize=False, + default_softmax=False, + ), + }) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 6b191b09b4bfd..0c4284f7daaac 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -12,7 +12,7 @@ from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors @@ -96,11 +96,17 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, self.score = JinaVLScorer(config) - self.pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=False, - softmax=True) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + Pooler.for_classify(pooler_config, classifier=None), + "score": + Pooler.for_classify(pooler_config, classifier=None), + }) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 74986f9f57340..be1c3438d9db1 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable +from collections.abc import Iterable, Set from typing import Optional, Union import torch @@ -13,7 +13,8 @@ from vllm.config import VllmConfig from vllm.distributed import 
get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, +from vllm.model_executor.layers.pooler import (ClassifierPooler, + DispatchPooler, Pooler, PoolingMethod, PoolingParamsUpdate, PoolingType) @@ -271,19 +272,27 @@ class ModernBertPooler(Pooler): eps=config.norm_eps, bias=config.norm_bias) - def get_pooling_updates( - self, - task: PoolingTask, - ) -> Optional[PoolingParamsUpdate]: + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return self.pooling.get_pooling_updates(task) + def _head(self, pooled_output: torch.Tensor): + return self.norm(self.act(self.dense(pooled_output))) + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[torch.Tensor, list[torch.Tensor]]: pooled_output = self.pooling(hidden_states, pooling_metadata) - pooled_output = self.norm(self.act(self.dense(pooled_output))) + + if isinstance(pooled_output, list): + pooled_output = [self._head(output) for output in pooled_output] + else: + pooled_output = self._head(pooled_output) + return pooled_output @@ -299,11 +308,28 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.pooler = ClassifierPooler( - vllm_config.model_config, - pooling=ModernBertPooler(config), - classifier=self.classifier, - ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=ModernBertPooler(config), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=ModernBertPooler(config), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 58f95d6eebfb4..f12e9a041a944 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -15,7 +15,8 @@ from torch import nn from vllm.config import VllmConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler +from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, + PoolingType) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -26,7 +27,7 @@ from .utils import AutoWeightsLoader, maybe_prefix class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): is_pooling_model = True - pooler: SimplePooler + pooler: Pooler packed_modules_mapping = { "qkv_proj": [ @@ -94,12 +95,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 1 super().__init__(vllm_config=vllm_config, prefix=prefix) + pooler_config = vllm_config.model_config.pooler_config - self.pooler = 
Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.ALL, - normalize=False, - softmax=False) + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): @@ -107,11 +108,17 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 2 super().__init__(vllm_config=vllm_config, prefix=prefix) + pooler_config = vllm_config.model_config.pooler_config - self.pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.STEP, - normalize=False, - softmax=True, - step_tag_id=151651, - ) + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode( + pooler_config, + default_pooling_type=PoolingType.STEP, + default_normalize=False, + default_softmax=True, + default_step_tag_id=151651, + ) + }) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 7d3b56ced5c40..c6b4116440346 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,7 +9,8 @@ from torch import nn from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool +from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, + DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel @@ -63,16 +64,10 @@ class RobertaEmbedding(nn.Module): # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - pos_list = [] - token_list = [] - offset = 0 - for seq_len in seq_lens: - pos_list.append(position_ids[offset:offset + seq_len]) - token_list.append(input_ids[offset:offset + seq_len]) - offset += seq_len - + seq_lens_list = seq_lens.tolist() new_pos_list = [] - for positions, tokens in zip(pos_list, token_list): + for positions, tokens in zip(position_ids.split(seq_lens_list), + input_ids.split(seq_lens_list)): # Verify assumption that incoming position are # always a sequence from 0 to N. 
expected_pos = torch.arange(positions.size()[0], @@ -184,15 +179,30 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), - embedding_class=RobertaEmbedding, - add_pooling_layer=False) + embedding_class=RobertaEmbedding) self.classifier = RobertaClassificationHead(config) - self.pooler = ClassifierPooler( - vllm_config.model_config, - pooling=CLSPool(), - classifier=self.classifier, - ) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py index 4dd443bc26ea0..e6f1ca61dd291 100644 --- a/vllm/model_executor/pooling_metadata.py +++ b/vllm/model_executor/pooling_metadata.py @@ -38,6 +38,13 @@ class PoolingMetadata: f"seq_data={self.seq_data}, " f"prompt_lens={self.prompt_lens})") + def __getitem__(self, indices: slice): + return PoolingMetadata( + seq_groups=self.seq_groups[indices], + seq_data=dict(list(self.seq_data.items())[indices]), + prompt_lens=self.prompt_lens[indices], + ) + @dataclass class PoolingTensors: diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 5f321cd87c524..28af720d05fd1 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -15,3 +15,11 @@ class PoolingMetadata: prompt_lens: torch.Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] + + def __getitem__(self, indices: slice): + return PoolingMetadata( + prompt_lens=self.prompt_lens[indices], + prompt_token_ids=None if self.prompt_token_ids is None else + self.prompt_token_ids[indices], + pooling_params=self.pooling_params[indices], + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 670e653929cea..cd66d8bcd6342 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,7 +5,7 @@ import copy import gc import time from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args +from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np import torch @@ -415,15 +415,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): generator = None if pooling_params: - assert pooling_params.task is not None, ( + assert (task := pooling_params.task) is not None, ( "You did not set `task` in the API") model = cast(VllmModelForPooling, self.model) - to_update = (model.pooler.get_pooling_updates( - pooling_params.task)) - assert to_update is not None, ( - f"{pooling_params.task=} is not supported by the model") - + to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) self.requests[req_id] = CachedRequestState( @@ -1122,10 +1118,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): if not is_pooling_model(model): return [] - return [ - task for task in get_args(PoolingTask) - if model.pooler.get_pooling_updates(task) - ] + 
return list(model.pooler.get_supported_tasks()) def apply_grammar_bitmask( self, @@ -2247,7 +2240,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): dummy_pooling_params = PoolingParams(task=dummy_task) to_update = model.pooler.get_pooling_updates(dummy_task) - assert to_update is not None to_update.apply(dummy_pooling_params) dummy_metadata = PoolingMetadata( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 7ed1cf41011ba..aad45b6abd128 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Any, Optional, cast, get_args +from typing import TYPE_CHECKING, Any, Optional, cast from unittest.mock import patch import numpy as np @@ -491,10 +491,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): if not is_pooling_model(model): return [] - return [ - task for task in get_args(PoolingTask) - if model.pooler.get_pooling_updates(task) - ] + return list(model.pooler.get_supported_tasks()) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index b0737dfe31978..62f26ac57a98b 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -4,7 +4,7 @@ import dataclasses from abc import ABC, abstractmethod from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar, get_args) + TypeVar) import torch import torch.nn as nn @@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]): if not is_pooling_model(model): return [] - return [ - task for task in get_args(PoolingTask) - if model.pooler.get_pooling_updates(task) - ] + return list(model.pooler.get_supported_tasks()) def execute_model( self, diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index 2c3f4eb3ad4d4..d91b16be83d70 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -199,15 +199,11 @@ class PoolingModelRunner( pooling_params = seq_group_metadata.pooling_params assert pooling_params is not None - assert pooling_params.task is not None, ( + assert (task := pooling_params.task) is not None, ( "You did not set `task` in the API") - to_update = (cast(VllmModelForPooling, - self.model).pooler.get_pooling_updates( - pooling_params.task)) - assert to_update is not None, ( - f"{pooling_params.task=} is not supported by the model") - + model = cast(VllmModelForPooling, self.model) + to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) seq_groups.append((seq_ids, pooling_params))
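The core of the new `DispatchPooler.forward` is the `itertools.groupby` pass: it splits a batch into runs of consecutive requests that share the same pooling task, slices the hidden states and pooling metadata for each run, and routes the slice to the matching sub-pooler. The standalone toy below mirrors that control flow with plain lists; it deliberately does not import vLLM, so `FakePooler` and the sample batch are illustrative stand-ins for real `Pooler` subclasses and hidden states.

```python
# Standalone toy mirroring DispatchPooler's groupby-based routing.
# Nothing here imports vLLM; FakePooler and the sample batch are
# illustrative stand-ins for Pooler subclasses and hidden states.
from itertools import groupby


class FakePooler:
    def __init__(self, name: str):
        self.name = name

    def __call__(self, states: list[str]) -> list[str]:
        # A real pooler would reduce per-token hidden states; here we
        # just tag each item with the sub-pooler that handled it.
        return [f"{self.name}({s})" for s in states]


poolers_by_task = {
    "embed": FakePooler("embed"),
    "classify": FakePooler("classify"),
}

# One entry per request in the batch: (task, per-request hidden states).
batch = [("embed", "h0"), ("embed", "h1"), ("classify", "h2"), ("embed", "h3")]
tasks = [task for task, _ in batch]
hidden = [h for _, h in batch]

outputs, offset = [], 0
for task, group in groupby(tasks):
    num_items = len(list(group))      # size of this contiguous run
    pooler = poolers_by_task[task]    # missing key == unsupported task
    outputs.extend(pooler(hidden[offset:offset + num_items]))
    offset += num_items

print(outputs)
# ['embed(h0)', 'embed(h1)', 'classify(h2)', 'embed(h3)']
```

Because `groupby` only merges adjacent equal tasks, requests are never reordered; per-request outputs stay aligned with the incoming batch order, which is why the real implementation can simply extend a flat list of `PoolingSequenceGroupOutput`s and wrap it in a single `PoolerOutput`.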