[Model][1/N] Support multiple poolers at model level (#21227)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-07-21 17:22:21 +08:00 committed by GitHub
parent 378d33c392
commit 042af0c8d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 549 additions and 413 deletions

View File

@ -11,26 +11,51 @@ before returning them.
As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to
pooling models as they only work on the generation or decode stage, so performance may not improve as much. pooling models as they only work on the generation or decode stage, so performance may not improve as much.
For pooling models, we support the following `--task` options. If the model doesn't implement this interface, you can set `--task` which tells vLLM
The selected option sets the default pooler used to extract the final hidden states: to convert the model into a pooling model.
| Task | Pooling Type | Normalization | Softmax | | `--task` | Model type | Supported pooling tasks |
|---------------------------------|----------------|-----------------|-----------| |------------|----------------------|-------------------------------|
| Embedding (`embed`) | `LAST` | ✅︎ | ❌ | | `embed` | Embedding model | `encode`, `embed` |
| Classification (`classify`) | `LAST` | ❌ | ✅︎ | | `classify` | Classification model | `encode`, `classify`, `score` |
| Sentence Pair Scoring (`score`) | \* | \* | \* | | `reward` | Reward model | `encode` |
\*The default pooler is always defined by the model. ## Pooling Tasks
!!! note In vLLM, we define the following pooling tasks and corresponding APIs:
If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table.
| Task | APIs |
|------------|--------------------|
| `encode` | `encode` |
| `embed` | `embed`, `score`\* |
| `classify` | `classify` |
| `score` | `score` |
\*The `score` API falls back to the `embed` task if the model does not support the `score` task.
Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.Pooler.get_supported_tasks].
By default, the pooler assigned to each task has the following attributes:
| Task | Pooling Type | Normalization | Softmax |
|------------|----------------|---------------|---------|
| `encode` | `ALL` | ❌ | ❌ |
| `embed` | `LAST` | ✅︎ | ❌ |
| `classify` | `LAST` | ❌ | ✅︎ |
These defaults may be overridden by the model's implementation in vLLM.
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
we attempt to override the default pooler based on its Sentence Transformers configuration file (`modules.json`). we attempt to override the defaults based on its Sentence Transformers configuration file (`modules.json`),
which takes priority over the model's defaults.
!!! tip You can further customize this via the `--override-pooler-config` option,
You can customize the model's pooling method via the `--override-pooler-config` option, which takes priority over both the model's and Sentence Transformers's defaults.
which takes priority over both the model's and Sentence Transformers's defaults.
!!! note
The above configuration may be disregarded if the model's implementation in vLLM defines its own pooler
that is not based on [PoolerConfig][vllm.config.PoolerConfig].
## Offline Inference ## Offline Inference

View File

@ -144,7 +144,7 @@ def test_quantization(
"model", "model",
["jason9693/Qwen2.5-1.5B-apeach"], ["jason9693/Qwen2.5-1.5B-apeach"],
) )
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["float"])
def test_classify( def test_classify(
hf_runner, hf_runner,
vllm_runner, vllm_runner,

View File

@ -8,7 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.models.gemma2 import Gemma2Model from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -26,12 +26,13 @@ class MyGemma2Embedding(nn.Module):
self.model = Gemma2Model(vllm_config=vllm_config, self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.pooler = Pooler.from_config_with_defaults( pooler_config = vllm_config.model_config.pooler_config
vllm_config.model_config.pooler_config, assert pooler_config is not None
pooling_type=PoolingType.LAST,
normalize=True, self.pooler = DispatchPooler({
softmax=False, "encode": Pooler.for_encode(pooler_config),
) "embed": Pooler.for_embed(pooler_config),
})
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)

View File

@ -94,7 +94,7 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription", "draft"] "score", "reward", "transcription", "draft"]
_ResolvedTask = Literal["generate", "transcription", "pooling", "embed", _ResolvedTask = Literal["generate", "transcription", "encode", "embed",
"classify", "reward", "draft"] "classify", "reward", "draft"]
RunnerOption = Literal["auto", "generate", "pooling", "draft"] RunnerOption = Literal["auto", "generate", "pooling", "draft"]
@ -103,7 +103,7 @@ RunnerType = Literal["generate", "pooling", "draft"]
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate", "transcription"], "generate": ["generate", "transcription"],
"pooling": ["pooling", "embed", "classify", "reward"], "pooling": ["encode", "embed", "classify", "reward"],
"draft": [], "draft": [],
} }
@ -579,7 +579,7 @@ class ModelConfig:
# user-selected task # user-selected task
if runner_type == "pooling" and self.task == "auto": if runner_type == "pooling" and self.task == "auto":
selected_task = all_supported_tasks[runner_type][-1] selected_task = all_supported_tasks[runner_type][-1]
assert selected_task != "pooling" assert selected_task != "encode"
self.task = selected_task self.task = selected_task
self.supported_runner_types = supported_runner_types self.supported_runner_types = supported_runner_types
self.runner_type = runner_type self.runner_type = runner_type
@ -884,7 +884,7 @@ class ModelConfig:
supported_tasks = list[_ResolvedTask]() supported_tasks = list[_ResolvedTask]()
if registry.is_pooling_model(architectures): if registry.is_pooling_model(architectures):
supported_tasks.append("pooling") supported_tasks.append("encode")
# For now, users must specify the task (other than "pooling") # For now, users must specify the task (other than "pooling")
# to use for pooling models # to use for pooling models

View File

@ -1668,7 +1668,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
) if "pooling" in model_config.supported_tasks else None ) if "encode" in model_config.supported_tasks else None
state.openai_serving_embedding = OpenAIServingEmbedding( state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client, engine_client,
model_config, model_config,

View File

@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping, Set
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union from typing import Callable, Optional, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from typing_extensions import assert_never
from vllm.config import ModelConfig, PoolerConfig from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501 from vllm.model_executor.pooling_metadata import ( # noqa: E501
@ -21,6 +22,10 @@ from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
class PoolingType(IntEnum): class PoolingType(IntEnum):
@ -79,37 +84,81 @@ class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM.""" """The interface required for all poolers used in pooling models in vLLM."""
@staticmethod @staticmethod
def from_config_with_defaults( def for_encode(
pooler_config: PoolerConfig, pooler_config: PoolerConfig,
pooling_type: PoolingType, *,
normalize: bool, default_pooling_type: PoolingType = PoolingType.ALL,
softmax: bool, default_normalize: bool = False,
step_tag_id: Optional[int] = None, default_softmax: bool = False,
returned_token_ids: Optional[list[int]] = None, default_step_tag_id: Optional[int] = None,
) -> "Pooler": default_returned_token_ids: Optional[list[int]] = None,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults( resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config, pooler_config=pooler_config,
pooling_type=pooling_type, pooling_type=default_pooling_type,
normalize=normalize, normalize=default_normalize,
softmax=softmax, softmax=default_softmax,
step_tag_id=step_tag_id, step_tag_id=default_step_tag_id,
returned_token_ids=returned_token_ids, returned_token_ids=default_returned_token_ids,
) )
if pooling_type == PoolingType.STEP: if resolved_config.pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config) return StepPooler.from_config(resolved_config)
return SimplePooler.from_config(resolved_config) return SimplePooler.from_config(resolved_config)
def get_pooling_updates( @staticmethod
self, def for_embed(
task: PoolingTask, pooler_config: PoolerConfig,
) -> Optional[PoolingParamsUpdate]: *,
default_pooling_type: PoolingType = PoolingType.LAST,
default_normalize: bool = True,
default_softmax: bool = False,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
return SimplePooler.from_config(resolved_config)
@staticmethod
def for_classify(
    pooler_config: PoolerConfig,
    classifier: Optional[ClassifierFn],
    *,
    default_pooling_type: PoolingType = PoolingType.LAST,
    default_normalize: bool = False,
    default_softmax: bool = True,
):
    """Build the pooler used for the `classify` task.

    The `default_*` keyword arguments are forwarded to
    `ResolvedPoolingConfig.from_config_with_defaults` together with
    `pooler_config`; a `SimplePooler` is then built from the result.

    Args:
        pooler_config: The pooler configuration to resolve.
        classifier: Optional callable applied by `ClassifierPooler`.
            When `None`, the `SimplePooler` itself is returned.
        default_pooling_type: Pooling type passed as the default.
        default_normalize: Normalization flag passed as the default.
        default_softmax: Softmax flag passed as the default.

    Returns:
        A `ClassifierPooler` reusing the simple pooler's pooling method
        and head activation when `classifier` is given, otherwise the
        `SimplePooler`.
    """
    simple_pooler = SimplePooler.from_config(
        ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=default_pooling_type,
            normalize=default_normalize,
            softmax=default_softmax,
        ))

    # Only wrap in a ClassifierPooler when there is a classifier to apply.
    if classifier is not None:
        return ClassifierPooler(
            pooling=simple_pooler.pooling,
            classifier=classifier,
            act_fn=simple_pooler.head.activation,
        )

    return simple_pooler
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Determine which pooling tasks are supported.

    Abstract: this base implementation always raises
    `NotImplementedError`, so concrete poolers must override it and
    return the set of tasks they can handle.
    """
    raise NotImplementedError
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
""" """
Construct the pooling parameters to use for a task, Construct the updated pooling parameters to use for a supported task.
or `None` if the task is not supported.
""" """
return None return PoolingParamsUpdate()
@abstractmethod @abstractmethod
def forward( def forward(
@ -127,9 +176,8 @@ def get_prompt_lens(
if isinstance(pooling_metadata, V1PoolingMetadata): if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata( return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens pooling_metadata, hidden_states[0].device).prompt_lens
def get_prompt_token_ids( def get_prompt_token_ids(
@ -149,6 +197,21 @@ def get_prompt_token_ids(
] ]
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Collect the pooling task of every entry in `pooling_metadata`.

    Args:
        pooling_metadata: Either a V0 or a V1 pooling metadata object.

    Returns:
        One task per pooling-params entry, in the same order.

    Raises:
        AssertionError: If any pooling-params entry has `task` set to
            `None`.
    """
    if isinstance(pooling_metadata, V0PoolingMetadata):
        # V0 keeps pairs in `seq_groups`; the pooling params are the
        # second element of each pair.
        params_per_entry = [
            params for _, params in pooling_metadata.seq_groups
        ]
    else:
        params_per_entry = pooling_metadata.pooling_params

    resolved_tasks: list[PoolingTask] = []
    for params in params_per_entry:
        task = params.task
        if task is not None:
            resolved_tasks.append(task)

    # A dropped `None` task would shift the returned list relative to the
    # pooling params, so every entry is required to carry a task.
    assert len(params_per_entry) == len(resolved_tasks)

    return resolved_tasks
def get_classification_activation_function(config: PretrainedConfig): def get_classification_activation_function(config: PretrainedConfig):
return PoolerClassify() return PoolerClassify()
@ -172,7 +235,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return PoolerScore() return PoolerScore()
def build_output(all_data: torch.Tensor) -> PoolerOutput: def build_output(
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
return PoolerOutput(outputs=all_outputs) return PoolerOutput(outputs=all_outputs)
@ -193,12 +257,12 @@ class PoolingMethod(nn.Module, ABC):
raise NotImplementedError(f"Unsupported method: {pooling_type}") raise NotImplementedError(f"Unsupported method: {pooling_type}")
@abstractmethod @abstractmethod
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
raise NotImplementedError raise NotImplementedError
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate()
@abstractmethod @abstractmethod
def forward_one( def forward_one(
self, self,
@ -237,16 +301,8 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod): class CLSPool(PoolingMethod):
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return {"encode", "embed", "classify", "score"}
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParamsUpdate()
assert_never(task)
def forward_one( def forward_one(
self, self,
@ -270,16 +326,8 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod): class LastPool(PoolingMethod):
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return {"encode", "embed", "classify", "score"}
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParamsUpdate()
assert_never(task)
def forward_one( def forward_one(
self, self,
@ -299,18 +347,8 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod): class AllPool(PoolingMethod):
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return {"encode"}
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
if task == "encode":
return PoolingParamsUpdate()
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
return None
assert_never(task)
def forward_one( def forward_one(
self, self,
@ -327,28 +365,13 @@ class AllPool(PoolingMethod):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
prompt_lens: torch.Tensor, prompt_lens: torch.Tensor,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
offset = 0 return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
pooled_data = list[torch.Tensor]()
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
return pooled_data
class MeanPool(PoolingMethod): class MeanPool(PoolingMethod):
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return {"encode", "embed", "classify", "score"}
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParamsUpdate()
assert_never(task)
def forward_one( def forward_one(
self, self,
@ -529,24 +552,6 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`. 3. Returns structured results as `PoolerOutput`.
""" """
@classmethod
def from_config_with_defaults( # type: ignore[override]
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
) -> "SimplePooler":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
)
assert resolved_config.pooling_type != PoolingType.STEP
return cls.from_config(resolved_config)
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
@ -563,10 +568,10 @@ class SimplePooler(Pooler):
self.pooling = pooling self.pooling = pooling
self.head = head self.head = head
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return self.pooling.get_supported_tasks()
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task) return self.pooling.get_pooling_updates(task)
def forward( def forward(
@ -627,18 +632,11 @@ class StepPooler(Pooler):
return pooled_data return pooled_data
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return {"encode"}
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
if task == "encode":
return PoolingParamsUpdate(requires_token_ids=True)
# The equalities are split up to keep mypy happy def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
if task == "embed" or task == "classify" or task == "score": return PoolingParamsUpdate(requires_token_ids=True)
return None
assert_never(task)
def forward( def forward(
self, self,
@ -650,68 +648,43 @@ class StepPooler(Pooler):
return build_output(pooled_data) return build_output(pooled_data)
PoolingFn = Callable[ class ClassifierPooler(Pooler):
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
class ClassifierPooler(nn.Module):
"""A pooling layer for classification tasks. """A pooling layer for classification tasks.
This layer does the following: This layer does the following:
1. Applies a classification layer to the hidden states. 1. Applies a classification layer to the hidden states.
2. Optionally applies a pooler layer. 2. Optionally applies a pooler layer.
3. Applies an activation function to the output. In the case of 3. Applies an activation function to the output.
classification models it is either sigmoid or softmax. In the
case of scoring models, the same behavior is configuration
dependent, as in the sentence-transformers library.
""" """
@staticmethod
def act_fn_for_seq_cls(config: ModelConfig):
return get_classification_activation_function(config.hf_config)
@staticmethod
def act_fn_for_cross_encoder(config: ModelConfig):
return get_cross_encoder_activation_function(config.hf_config)
def __init__( def __init__(
self, self,
config: ModelConfig,
pooling: PoolingFn, pooling: PoolingFn,
classifier: ClassifierFn, classifier: ClassifierFn,
act_fn: Optional[PoolerActivation] = None, act_fn: PoolerActivation,
) -> None: ) -> None:
super().__init__() super().__init__()
self.pooling = pooling self.pooling = pooling
self.classifier = classifier self.classifier = classifier
self.act_fn = act_fn
self.classification_act_fn = get_classification_activation_function( def get_supported_tasks(self) -> Set[PoolingTask]:
config.hf_config) if act_fn is None else act_fn return {"classify", "score"}
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
config.hf_config) if act_fn is None else act_fn
def _get_act_fn(self, task: PoolingTask):
if task == "encode" or task == "classify":
return self.classification_act_fn
if task == "score":
return self.cross_encoder_act_fn
raise ValueError(f"Unsupported task: {task!r}")
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if task == "encode" or task == "classify" or task == "score":
return PoolingParamsUpdate()
if task == "embed":
return None
assert_never(task)
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
# apply classifier once on the full batch if possible # apply classifier once on the full batch if possible
@ -722,28 +695,59 @@ class ClassifierPooler(nn.Module):
else: else:
pooled_output = [self.classifier(data) for data in pooled_data] pooled_output = [self.classifier(data) for data in pooled_data]
task_list: list[PoolingTask] scores = self.act_fn(pooled_output)
if isinstance(pooling_metadata, V0PoolingMetadata):
task_list = [
task for _, pooling_param in pooling_metadata.seq_groups
if (task := pooling_param.task) is not None
]
else:
task_list = [
task for pooling_param in pooling_metadata.pooling_params
if (task := pooling_param.task) is not None
]
assert len(task_list) == len(pooled_output)
# shape of scores: (batch_size, num_labels)
if len(set(task_list)) <= 1:
act_fn = self._get_act_fn(task_list[0])
scores = act_fn(pooled_output)
else:
scores = torch.stack([
self._get_act_fn(task)(vecs)
for task, vecs in zip(task_list, pooled_output)
])
return build_output(scores) return build_output(scores)
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task.

    Each task is mapped to exactly one sub-pooler. At forward time, the
    per-request tasks are read from the pooling metadata and each run of
    consecutive requests sharing a task is handed to the matching
    sub-pooler.
    """

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Validate and store the task-to-pooler mapping.

        Args:
            poolers_by_task: The sub-pooler to use for each pooling task.

        Raises:
            ValueError: If a sub-pooler does not list the task it is
                registered under among its supported tasks.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            supported_tasks = pooler.get_supported_tasks()
            if task not in supported_tasks:
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {supported_tasks}")

        # NOTE(review): the mapping is kept as a plain attribute (not an
        # `nn.ModuleDict`), so sub-poolers are presumably not registered as
        # submodules of this module - confirm this is intended.
        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered sub-pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the sub-pooler registered for `task`."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool each request with the sub-pooler matching its task.

        Args:
            hidden_states: Either a list with one entry per request, or a
                single tensor that is split by the prompt lengths.
            pooling_metadata: Metadata providing the per-request tasks
                (and prompt lengths when `hidden_states` is a tensor).

        Returns:
            The concatenated outputs of all sub-poolers, in request order.

        Raises:
            ValueError: If a request uses a task with no registered
                sub-pooler.
        """
        poolers_by_task = self.poolers_by_task

        if isinstance(hidden_states, list):
            hidden_states_lst = hidden_states
        else:
            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        # `groupby` only merges *consecutive* equal tasks, so request order
        # is preserved in the combined output.
        for task, group in groupby(get_tasks(pooling_metadata)):
            pooler = poolers_by_task.get(task)
            # Compare against None explicitly: truthiness of a module can be
            # affected by `__len__`/`__bool__` defined on a sub-pooler.
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task}. "
                    f"Supported tasks: {self.get_supported_tasks()}")

            num_items = len(list(group))
            # Assumes the pooling metadata supports slicing (defined
            # outside this block).
            group_output: PoolerOutput = pooler(
                hidden_states_lst[offset:offset + num_items],
                pooling_metadata[offset:offset + num_items],
            )

            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)

View File

@ -13,7 +13,6 @@ from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolingType
_T = TypeVar("_T", bound=type[nn.Module]) _T = TypeVar("_T", bound=type[nn.Module])
@ -34,16 +33,8 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return model_name + pooling_suffix return model_name + pooling_suffix
def _create_pooling_model_cls( def _create_pooling_model_cls(orig_cls: _T) -> _T:
orig_cls: _T,
*,
default_pooling_type: "PoolingType",
default_normalize: bool,
default_softmax: bool,
) -> _T:
# Lazy import # Lazy import
from vllm.model_executor.layers.pooler import Pooler
from .utils import AutoWeightsLoader, WeightsMapper from .utils import AutoWeightsLoader, WeightsMapper
class ModelForPooling(orig_cls, VllmModelForPooling): class ModelForPooling(orig_cls, VllmModelForPooling):
@ -71,15 +62,7 @@ def _create_pooling_model_cls(
self._init_pooler(vllm_config, prefix=prefix) self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config raise NotImplementedError
assert pooler_config is not None
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking # TODO: Support uninitialized params tracking
@ -132,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T:
return cls return cls
# Lazy import # Lazy import
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
class ModelForEmbedding(_create_pooling_model_cls(cls)):
    # Pooling-model wrapper around `cls` that installs a task-dispatching
    # pooler for the "encode" and "embed" tasks.

    def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
        """Create `self.pooler` from the model's pooler configuration."""
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # One sub-pooler per supported task, each built from the same
        # pooler configuration.
        poolers_by_task = {
            "encode": Pooler.for_encode(pooler_config),
            "embed": Pooler.for_embed(pooler_config),
        }
        self.pooler = DispatchPooler(poolers_by_task)
ModelForEmbedding = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=True,
default_softmax=False,
)
ModelForEmbedding.__name__ = \ ModelForEmbedding.__name__ = \
_get_pooling_model_name(cls.__name__, "ForEmbedding") _get_pooling_model_name(cls.__name__, "ForEmbedding")
@ -165,20 +154,14 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import # Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler, from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolingType, SimplePooler) DispatchPooler, Pooler,
PoolingMethod, PoolingType)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix from .utils import maybe_prefix
ModelForPooling = _create_pooling_model_cls( class ModelForSequenceClassification(_create_pooling_model_cls(cls),
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=False,
default_softmax=True,
)
class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding): SupportsCrossEncoding):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
@ -198,19 +181,28 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults( pooling_type_str = pooler_config.pooling_type
pooler_config, pooling_type = (PoolingType.LAST if pooling_type_str is None else
pooling_type=PoolingType.LAST, PoolingType[pooling_type_str])
normalize=False,
softmax=True,
)
self.pooler = ClassifierPooler( self.pooler = DispatchPooler({
vllm_config.model_config, "encode":
pooling=pooler.pooling, Pooler.for_encode(pooler_config),
classifier=self._classifier, "classify":
act_fn=pooler.head.activation, ClassifierPooler(
) pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def _classifier(self, x: torch.Tensor): def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float()) x, _ = self.score(x.float())
@ -259,14 +251,16 @@ def as_reward_model(cls: _T) -> _T:
return cls return cls
# Lazy import # Lazy import
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
ModelForReward = _create_pooling_model_cls( class ModelForReward(_create_pooling_model_cls(cls)):
cls,
default_pooling_type=PoolingType.ALL, def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
default_normalize=False, pooler_config = vllm_config.model_config.pooler_config
default_softmax=False, assert pooler_config is not None
)
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
ModelForReward.__name__ = \ ModelForReward.__name__ = \
_get_pooling_model_name(cls.__name__, "ForReward") _get_pooling_model_name(cls.__name__, "ForReward")

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable, Set
from typing import Optional, Union from typing import Optional, Union
import torch import torch
@ -17,7 +17,8 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod, PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
PoolingType) PoolingType)
@ -92,20 +93,29 @@ class BertPooler(Pooler):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return self.pooling.get_supported_tasks()
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task) return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata) pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output) if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output return pooled_output
@ -333,18 +343,19 @@ class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self, def __init__(
*, self,
vllm_config: VllmConfig, *,
prefix: str = "", vllm_config: VllmConfig,
embedding_class: type = BertEmbedding, prefix: str = "",
add_pooling_layer: bool = False): embedding_class: type[nn.Module] = BertEmbedding,
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config) self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config, self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None
def forward( def forward(
self, self,
@ -366,8 +377,7 @@ class BertModel(nn.Module, SupportsQuant):
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
return self.encoder(hidden_states) return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"), ("qkv_proj", "query", "q"),
@ -395,10 +405,43 @@ class BertModel(nn.Module, SupportsQuant):
if name in params_dict: if name in params_dict:
other_weights.append((name, loaded_weight)) other_weights.append((name, loaded_weight))
loader = AutoWeightsLoader( return other_weights, loaded_stacked_params
self,
skip_prefixes=(["pooler."] if self.pooler is None else []), def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
other_weights, loaded_stacked_params = self._load_weights(weights)
loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
loaded_params = loader.load_weights(other_weights)
loaded_params.update(loaded_stacked_params)
return loaded_params
class BertPoolingModel(BertModel):
is_pooling_model = True
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type[nn.Module] = BertEmbedding,
) -> None:
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
embedding_class=embedding_class,
) )
config = vllm_config.model_config.hf_config
self.pooler = BertPooler(config)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
other_weights, loaded_stacked_params = self._load_weights(weights)
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(other_weights) loaded_params = loader.load_weights(other_weights)
loaded_params.update(loaded_stacked_params) loaded_params.update(loaded_stacked_params)
return loaded_params return loaded_params
@ -421,6 +464,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
super().__init__() super().__init__()
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.model = self._build_model(vllm_config=vllm_config, self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.pooler = self._build_pooler(pooler_config) self.pooler = self._build_pooler(pooler_config)
@ -456,10 +501,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
embedding_class=BertEmbedding) embedding_class=BertEmbedding)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return Pooler.from_config_with_defaults(pooler_config, return DispatchPooler({
pooling_type=PoolingType.CLS, "encode":
normalize=True, Pooler.for_encode(pooler_config),
softmax=False) "embed":
Pooler.for_embed(
pooler_config,
default_pooling_type=PoolingType.CLS,
),
})
class BertForSequenceClassification(nn.Module, SupportsV0Only, class BertForSequenceClassification(nn.Module, SupportsV0Only,
@ -481,16 +531,32 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.bert = BertModel(vllm_config=vllm_config, self.bert = BertPoolingModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"), prefix=maybe_prefix(prefix, "bert"),
embedding_class=BertEmbedding, embedding_class=BertEmbedding)
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.pooler = ClassifierPooler(
vllm_config.model_config, pooler_config = vllm_config.model_config.pooler_config
pooling=self.bert.pooler, assert pooler_config is not None
classifier=self.classifier,
) self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)

View File

@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from ..layers.pooler import Pooler, PoolingType from ..layers.pooler import DispatchPooler, Pooler
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
@ -339,12 +339,16 @@ class GPT2ForSequenceClassification(nn.Module):
self.transformer = GPT2Model(vllm_config=vllm_config, self.transformer = GPT2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt2")) prefix=maybe_prefix(prefix, "gpt2"))
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults( assert pooler_config is not None
pooler_config,
pooling_type=PoolingType.LAST, self.pooler = DispatchPooler({
normalize=False, "encode":
softmax=True) Pooler.for_encode(pooler_config),
"classify":
Pooler.for_classify(pooler_config, classifier=None),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)

View File

@ -1,17 +1,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing_extensions import assert_never
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (Pooler, PoolerHead, from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolerNormalize, PoolerHead, PoolerNormalize,
PoolingParamsUpdate, PoolingParamsUpdate,
build_output, get_prompt_lens, build_output, get_prompt_lens,
get_prompt_token_ids) get_prompt_token_ids)
@ -135,18 +134,11 @@ class GritLMMeanPool(nn.Module):
return instruction_len return instruction_len
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return {"encode", "embed"}
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if task == "encode" or task == "embed":
return PoolingParamsUpdate(requires_token_ids=True)
if task == "classify" or task == "score": def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return None return PoolingParamsUpdate(requires_token_ids=True)
assert_never(task)
def forward_one( def forward_one(
self, self,
@ -207,10 +199,10 @@ class GritLMPooler(Pooler):
self.pooling = GritLMMeanPool(model_config) self.pooling = GritLMMeanPool(model_config)
self.head = PoolerHead(PoolerNormalize()) self.head = PoolerHead(PoolerNormalize())
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return self.pooling.get_supported_tasks()
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task) return self.pooling.get_pooling_updates(task)
def forward( def forward(
@ -262,4 +254,11 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.pooler = GritLMPooler(vllm_config.model_config) pooler_config = vllm_config.model_config.pooler_config
if pooler_config is not None:
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"embed":
GritLMPooler(vllm_config.model_config),
})

View File

@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -429,12 +429,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
) )
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults( assert pooler_config is not None
pooler_config,
pooling_type=PoolingType.ALL, self.pooler = DispatchPooler(
normalize=False, {"encode": Pooler.for_encode(pooler_config)}, )
softmax=False,
)
def forward( def forward(
self, self,

View File

@ -19,8 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType, from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
SimplePooler) PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@ -584,16 +584,15 @@ class JambaForSequenceClassification(JambaForCausalLM):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults( self.pooler = DispatchPooler({
pooler_config, "encode":
pooling_type=PoolingType.LAST, Pooler.for_encode(pooler_config),
normalize=False, "classify":
softmax=False, Pooler.for_classify(
) pooler_config,
classifier=self.score,
self.pooler = ClassifierPooler( default_pooling_type=PoolingType.LAST,
vllm_config.model_config, default_normalize=False,
pooling=pooler.pooling, default_softmax=False,
classifier=self.score, ),
act_fn=pooler.head.activation, })
)

View File

@ -12,7 +12,7 @@ from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -96,11 +96,17 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
self.score = JinaVLScorer(config) self.score = JinaVLScorer(config)
self.pooler = Pooler.from_config_with_defaults( pooler_config = vllm_config.model_config.pooler_config
pooler_config, assert pooler_config is not None
pooling_type=PoolingType.LAST,
normalize=False, self.pooler = DispatchPooler({
softmax=True) "encode":
Pooler.for_encode(pooler_config),
"classify":
Pooler.for_classify(pooler_config, classifier=None),
"score":
Pooler.for_classify(pooler_config, classifier=None),
})
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable, Set
from typing import Optional, Union from typing import Optional, Union
import torch import torch
@ -13,7 +13,8 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod, PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
PoolingType) PoolingType)
@ -271,19 +272,27 @@ class ModernBertPooler(Pooler):
eps=config.norm_eps, eps=config.norm_eps,
bias=config.norm_bias) bias=config.norm_bias)
def get_pooling_updates( def get_supported_tasks(self) -> Set[PoolingTask]:
self, return self.pooling.get_supported_tasks()
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task) return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
return self.norm(self.act(self.dense(pooled_output)))
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata) pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.norm(self.act(self.dense(pooled_output)))
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output return pooled_output
@ -299,11 +308,28 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config, self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert")) prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.pooler = ClassifierPooler(
vllm_config.model_config, pooler_config = vllm_config.model_config.pooler_config
pooling=ModernBertPooler(config), assert pooler_config is not None
classifier=self.classifier,
) self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=ModernBertPooler(config),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=ModernBertPooler(config),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@ -15,7 +15,8 @@ from torch import nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
@ -26,7 +27,7 @@ from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
is_pooling_model = True is_pooling_model = True
pooler: SimplePooler pooler: Pooler
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
@ -94,12 +95,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1 vllm_config.model_config.hf_config.num_labels = 1
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults( assert pooler_config is not None
pooler_config,
pooling_type=PoolingType.ALL, self.pooler = DispatchPooler(
normalize=False, {"encode": Pooler.for_encode(pooler_config)}, )
softmax=False)
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
@ -107,11 +108,17 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2 vllm_config.model_config.hf_config.num_labels = 2
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.pooler = Pooler.from_config_with_defaults( assert pooler_config is not None
pooler_config,
pooling_type=PoolingType.STEP, self.pooler = DispatchPooler({
normalize=False, "encode":
softmax=True, Pooler.for_encode(
step_tag_id=151651, pooler_config,
) default_pooling_type=PoolingType.STEP,
default_normalize=False,
default_softmax=True,
default_step_tag_id=151651,
)
})

View File

@ -9,7 +9,8 @@ from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
@ -63,16 +64,10 @@ class RobertaEmbedding(nn.Module):
# References: # References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
pos_list = [] seq_lens_list = seq_lens.tolist()
token_list = []
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[offset:offset + seq_len])
token_list.append(input_ids[offset:offset + seq_len])
offset += seq_len
new_pos_list = [] new_pos_list = []
for positions, tokens in zip(pos_list, token_list): for positions, tokens in zip(position_ids.split(seq_lens_list),
input_ids.split(seq_lens_list)):
# Verify assumption that incoming position are # Verify assumption that incoming position are
# always a sequence from 0 to N. # always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0], expected_pos = torch.arange(positions.size()[0],
@ -184,15 +179,30 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config, self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"), prefix=maybe_prefix(prefix, "bert"),
embedding_class=RobertaEmbedding, embedding_class=RobertaEmbedding)
add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config) self.classifier = RobertaClassificationHead(config)
self.pooler = ClassifierPooler( pooler_config = vllm_config.model_config.pooler_config
vllm_config.model_config, assert pooler_config is not None
pooling=CLSPool(),
classifier=self.classifier, self.pooler = DispatchPooler({
) "encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)

View File

@ -38,6 +38,13 @@ class PoolingMetadata:
f"seq_data={self.seq_data}, " f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens})") f"prompt_lens={self.prompt_lens})")
def __getitem__(self, indices: slice):
return PoolingMetadata(
seq_groups=self.seq_groups[indices],
seq_data=dict(list(self.seq_data.items())[indices]),
prompt_lens=self.prompt_lens[indices],
)
@dataclass @dataclass
class PoolingTensors: class PoolingTensors:

View File

@ -15,3 +15,11 @@ class PoolingMetadata:
prompt_lens: torch.Tensor prompt_lens: torch.Tensor
prompt_token_ids: Optional[torch.Tensor] prompt_token_ids: Optional[torch.Tensor]
pooling_params: list[PoolingParams] pooling_params: list[PoolingParams]
def __getitem__(self, indices: slice):
return PoolingMetadata(
prompt_lens=self.prompt_lens[indices],
prompt_token_ids=None if self.prompt_token_ids is None else
self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices],
)

View File

@ -5,7 +5,7 @@ import copy
import gc import gc
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args from typing import TYPE_CHECKING, Any, Optional, Union, cast
import numpy as np import numpy as np
import torch import torch
@ -415,15 +415,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
generator = None generator = None
if pooling_params: if pooling_params:
assert pooling_params.task is not None, ( assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API") "You did not set `task` in the API")
model = cast(VllmModelForPooling, self.model) model = cast(VllmModelForPooling, self.model)
to_update = (model.pooler.get_pooling_updates( to_update = model.pooler.get_pooling_updates(task)
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update.apply(pooling_params) to_update.apply(pooling_params)
self.requests[req_id] = CachedRequestState( self.requests[req_id] = CachedRequestState(
@ -1122,10 +1118,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not is_pooling_model(model): if not is_pooling_model(model):
return [] return []
return [ return list(model.pooler.get_supported_tasks())
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def apply_grammar_bitmask( def apply_grammar_bitmask(
self, self,
@ -2247,7 +2240,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dummy_pooling_params = PoolingParams(task=dummy_task) dummy_pooling_params = PoolingParams(task=dummy_task)
to_update = model.pooler.get_pooling_updates(dummy_task) to_update = model.pooler.get_pooling_updates(dummy_task)
assert to_update is not None
to_update.apply(dummy_pooling_params) to_update.apply(dummy_pooling_params)
dummy_metadata = PoolingMetadata( dummy_metadata = PoolingMetadata(

View File

@ -3,7 +3,7 @@
import bisect import bisect
import gc import gc
import time import time
from typing import TYPE_CHECKING, Any, Optional, cast, get_args from typing import TYPE_CHECKING, Any, Optional, cast
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
@ -491,10 +491,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
if not is_pooling_model(model): if not is_pooling_model(model):
return [] return []
return [ return list(model.pooler.get_supported_tasks())
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
""" """

View File

@ -4,7 +4,7 @@
import dataclasses import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar, get_args) TypeVar)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]):
if not is_pooling_model(model): if not is_pooling_model(model):
return [] return []
return [ return list(model.pooler.get_supported_tasks())
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def execute_model( def execute_model(
self, self,

View File

@ -199,15 +199,11 @@ class PoolingModelRunner(
pooling_params = seq_group_metadata.pooling_params pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None assert pooling_params is not None
assert pooling_params.task is not None, ( assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API") "You did not set `task` in the API")
to_update = (cast(VllmModelForPooling, model = cast(VllmModelForPooling, self.model)
self.model).pooler.get_pooling_updates( to_update = model.pooler.get_pooling_updates(task)
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update.apply(pooling_params) to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params)) seq_groups.append((seq_ids, pooling_params))