mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 04:53:33 +08:00
[Model][1/N] Support multiple poolers at model level (#21227)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
378d33c392
commit
042af0c8d3
@ -11,26 +11,51 @@ before returning them.
|
|||||||
As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to
|
As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to
|
||||||
pooling models as they only work on the generation or decode stage, so performance may not improve as much.
|
pooling models as they only work on the generation or decode stage, so performance may not improve as much.
|
||||||
|
|
||||||
For pooling models, we support the following `--task` options.
|
If the model doesn't implement this interface, you can set `--task` which tells vLLM
|
||||||
The selected option sets the default pooler used to extract the final hidden states:
|
to convert the model into a pooling model.
|
||||||
|
|
||||||
| Task | Pooling Type | Normalization | Softmax |
|
| `--task` | Model type | Supported pooling tasks |
|
||||||
|---------------------------------|----------------|-----------------|-----------|
|
|------------|----------------------|-------------------------------|
|
||||||
| Embedding (`embed`) | `LAST` | ✅︎ | ❌ |
|
| `embed` | Embedding model | `encode`, `embed` |
|
||||||
| Classification (`classify`) | `LAST` | ❌ | ✅︎ |
|
| `classify` | Classification model | `encode`, `classify`, `score` |
|
||||||
| Sentence Pair Scoring (`score`) | \* | \* | \* |
|
| `reward` | Reward model | `encode` |
|
||||||
|
|
||||||
\*The default pooler is always defined by the model.
|
## Pooling Tasks
|
||||||
|
|
||||||
!!! note
|
In vLLM, we define the following pooling tasks and corresponding APIs:
|
||||||
If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table.
|
|
||||||
|
| Task | APIs |
|
||||||
|
|------------|--------------------|
|
||||||
|
| `encode` | `encode` |
|
||||||
|
| `embed` | `embed`, `score`\* |
|
||||||
|
| `classify` | `classify` |
|
||||||
|
| `score` | `score` |
|
||||||
|
|
||||||
|
\*The `score` API falls back to the `embed` task if the model does not support the `score` task.
|
||||||
|
|
||||||
|
Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.Pooler.get_supported_tasks].
|
||||||
|
|
||||||
|
By default, the pooler assigned to each task has the following attributes:
|
||||||
|
|
||||||
|
| Task | Pooling Type | Normalization | Softmax |
|
||||||
|
|------------|----------------|---------------|---------|
|
||||||
|
| `encode` | `ALL` | ❌ | ❌ |
|
||||||
|
| `embed` | `LAST` | ✅︎ | ❌ |
|
||||||
|
| `classify` | `LAST` | ❌ | ✅︎ |
|
||||||
|
|
||||||
|
These defaults may be overridden by the model's implementation in vLLM.
|
||||||
|
|
||||||
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
|
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
|
||||||
we attempt to override the default pooler based on its Sentence Transformers configuration file (`modules.json`).
|
we attempt to override the defaults based on their Sentence Transformers configuration file (`modules.json`),
|
||||||
|
which takes priority over the model's defaults.
|
||||||
|
|
||||||
!!! tip
|
You can further customize this via the `--override-pooler-config` option,
|
||||||
You can customize the model's pooling method via the `--override-pooler-config` option,
|
which takes priority over both the model's and Sentence Transformers' defaults.
|
||||||
which takes priority over both the model's and Sentence Transformers's defaults.
|
|
||||||
|
!!! note
|
||||||
|
|
||||||
|
The above configuration may be disregarded if the model's implementation in vLLM defines its own pooler
|
||||||
|
that is not based on [PoolerConfig][vllm.config.PoolerConfig].
|
||||||
|
|
||||||
## Offline Inference
|
## Offline Inference
|
||||||
|
|
||||||
|
|||||||
@ -144,7 +144,7 @@ def test_quantization(
|
|||||||
"model",
|
"model",
|
||||||
["jason9693/Qwen2.5-1.5B-apeach"],
|
["jason9693/Qwen2.5-1.5B-apeach"],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["float"])
|
||||||
def test_classify(
|
def test_classify(
|
||||||
hf_runner,
|
hf_runner,
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -26,12 +26,13 @@ class MyGemma2Embedding(nn.Module):
|
|||||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
self.pooler = Pooler.from_config_with_defaults(
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
vllm_config.model_config.pooler_config,
|
assert pooler_config is not None
|
||||||
pooling_type=PoolingType.LAST,
|
|
||||||
normalize=True,
|
self.pooler = DispatchPooler({
|
||||||
softmax=False,
|
"encode": Pooler.for_encode(pooler_config),
|
||||||
)
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
|
})
|
||||||
|
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.model.make_empty_intermediate_tensors)
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|||||||
@ -94,7 +94,7 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
|||||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||||
"score", "reward", "transcription", "draft"]
|
"score", "reward", "transcription", "draft"]
|
||||||
|
|
||||||
_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
|
_ResolvedTask = Literal["generate", "transcription", "encode", "embed",
|
||||||
"classify", "reward", "draft"]
|
"classify", "reward", "draft"]
|
||||||
|
|
||||||
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
|
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
|
||||||
@ -103,7 +103,7 @@ RunnerType = Literal["generate", "pooling", "draft"]
|
|||||||
|
|
||||||
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
|
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
|
||||||
"generate": ["generate", "transcription"],
|
"generate": ["generate", "transcription"],
|
||||||
"pooling": ["pooling", "embed", "classify", "reward"],
|
"pooling": ["encode", "embed", "classify", "reward"],
|
||||||
"draft": [],
|
"draft": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -579,7 +579,7 @@ class ModelConfig:
|
|||||||
# user-selected task
|
# user-selected task
|
||||||
if runner_type == "pooling" and self.task == "auto":
|
if runner_type == "pooling" and self.task == "auto":
|
||||||
selected_task = all_supported_tasks[runner_type][-1]
|
selected_task = all_supported_tasks[runner_type][-1]
|
||||||
assert selected_task != "pooling"
|
assert selected_task != "encode"
|
||||||
self.task = selected_task
|
self.task = selected_task
|
||||||
self.supported_runner_types = supported_runner_types
|
self.supported_runner_types = supported_runner_types
|
||||||
self.runner_type = runner_type
|
self.runner_type = runner_type
|
||||||
@ -884,7 +884,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
supported_tasks = list[_ResolvedTask]()
|
supported_tasks = list[_ResolvedTask]()
|
||||||
if registry.is_pooling_model(architectures):
|
if registry.is_pooling_model(architectures):
|
||||||
supported_tasks.append("pooling")
|
supported_tasks.append("encode")
|
||||||
|
|
||||||
# For now, users must specify the task (other than "pooling")
|
# For now, users must specify the task (other than "encode")
|
||||||
# to use for pooling models
|
# to use for pooling models
|
||||||
|
|||||||
@ -1668,7 +1668,7 @@ async def init_app_state(
|
|||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
chat_template=resolved_chat_template,
|
chat_template=resolved_chat_template,
|
||||||
chat_template_content_format=args.chat_template_content_format,
|
chat_template_content_format=args.chat_template_content_format,
|
||||||
) if "pooling" in model_config.supported_tasks else None
|
) if "encode" in model_config.supported_tasks else None
|
||||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||||
engine_client,
|
engine_client,
|
||||||
model_config,
|
model_config,
|
||||||
|
|||||||
@ -1,15 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Mapping, Set
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
from itertools import groupby
|
||||||
from typing import Callable, Optional, TypeVar, Union
|
from typing import Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from typing_extensions import assert_never
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig, PoolerConfig
|
from vllm.config import ModelConfig, PoolerConfig
|
||||||
from vllm.model_executor.pooling_metadata import ( # noqa: E501
|
from vllm.model_executor.pooling_metadata import ( # noqa: E501
|
||||||
@ -21,6 +22,10 @@ from vllm.utils import resolve_obj_by_qualname
|
|||||||
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
||||||
|
|
||||||
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
|
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
|
||||||
|
PoolingFn = Callable[
|
||||||
|
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
|
||||||
|
Union[torch.Tensor, list[torch.Tensor]]]
|
||||||
|
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
class PoolingType(IntEnum):
|
class PoolingType(IntEnum):
|
||||||
@ -79,37 +84,81 @@ class Pooler(nn.Module, ABC):
|
|||||||
"""The interface required for all poolers used in pooling models in vLLM."""
|
"""The interface required for all poolers used in pooling models in vLLM."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config_with_defaults(
|
def for_encode(
|
||||||
pooler_config: PoolerConfig,
|
pooler_config: PoolerConfig,
|
||||||
pooling_type: PoolingType,
|
*,
|
||||||
normalize: bool,
|
default_pooling_type: PoolingType = PoolingType.ALL,
|
||||||
softmax: bool,
|
default_normalize: bool = False,
|
||||||
step_tag_id: Optional[int] = None,
|
default_softmax: bool = False,
|
||||||
returned_token_ids: Optional[list[int]] = None,
|
default_step_tag_id: Optional[int] = None,
|
||||||
) -> "Pooler":
|
default_returned_token_ids: Optional[list[int]] = None,
|
||||||
|
):
|
||||||
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
||||||
pooler_config=pooler_config,
|
pooler_config=pooler_config,
|
||||||
pooling_type=pooling_type,
|
pooling_type=default_pooling_type,
|
||||||
normalize=normalize,
|
normalize=default_normalize,
|
||||||
softmax=softmax,
|
softmax=default_softmax,
|
||||||
step_tag_id=step_tag_id,
|
step_tag_id=default_step_tag_id,
|
||||||
returned_token_ids=returned_token_ids,
|
returned_token_ids=default_returned_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
if pooling_type == PoolingType.STEP:
|
if resolved_config.pooling_type == PoolingType.STEP:
|
||||||
return StepPooler.from_config(resolved_config)
|
return StepPooler.from_config(resolved_config)
|
||||||
|
|
||||||
return SimplePooler.from_config(resolved_config)
|
return SimplePooler.from_config(resolved_config)
|
||||||
|
|
||||||
def get_pooling_updates(
|
@staticmethod
|
||||||
self,
|
def for_embed(
|
||||||
task: PoolingTask,
|
pooler_config: PoolerConfig,
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
*,
|
||||||
|
default_pooling_type: PoolingType = PoolingType.LAST,
|
||||||
|
default_normalize: bool = True,
|
||||||
|
default_softmax: bool = False,
|
||||||
|
):
|
||||||
|
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
||||||
|
pooler_config=pooler_config,
|
||||||
|
pooling_type=default_pooling_type,
|
||||||
|
normalize=default_normalize,
|
||||||
|
softmax=default_softmax,
|
||||||
|
)
|
||||||
|
|
||||||
|
return SimplePooler.from_config(resolved_config)
|
||||||
|
|
||||||
|
@staticmethod
def for_classify(
    pooler_config: PoolerConfig,
    classifier: Optional[ClassifierFn],
    *,
    default_pooling_type: PoolingType = PoolingType.LAST,
    default_normalize: bool = False,
    default_softmax: bool = True,
):
    """Build the pooler used for the ``classify`` task.

    The ``default_*`` arguments are handed to
    ``ResolvedPoolingConfig.from_config_with_defaults`` together with
    ``pooler_config`` to resolve the final pooling configuration.

    Args:
        pooler_config: The pooler configuration of the model.
        classifier: Optional callable applied to the pooled data. When it
            is ``None``, the plain ``SimplePooler`` is returned.
        default_pooling_type: Pooling type passed as the default.
        default_normalize: Normalization flag passed as the default.
        default_softmax: Softmax flag passed as the default.

    Returns:
        A ``SimplePooler`` when no classifier is given; otherwise a
        ``ClassifierPooler`` that reuses the simple pooler's pooling
        method and its head activation.
    """
    simple_pooler = SimplePooler.from_config(
        ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=default_pooling_type,
            normalize=default_normalize,
            softmax=default_softmax,
        ))

    if classifier is not None:
        # Wrap the pooling step so that the classifier runs on the pooled
        # data before the activation taken from the simple pooler's head.
        return ClassifierPooler(
            pooling=simple_pooler.pooling,
            classifier=classifier,
            act_fn=simple_pooler.head.activation,
        )

    return simple_pooler
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
"""Determine which pooling tasks are supported."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
"""
|
"""
|
||||||
Construct the pooling parameters to use for a task,
|
Construct the updated pooling parameters to use for a supported task.
|
||||||
or `None` if the task is not supported.
|
|
||||||
"""
|
"""
|
||||||
return None
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(
|
def forward(
|
||||||
@ -127,9 +176,8 @@ def get_prompt_lens(
|
|||||||
if isinstance(pooling_metadata, V1PoolingMetadata):
|
if isinstance(pooling_metadata, V1PoolingMetadata):
|
||||||
return pooling_metadata.prompt_lens
|
return pooling_metadata.prompt_lens
|
||||||
|
|
||||||
assert isinstance(hidden_states, torch.Tensor)
|
|
||||||
return PoolingTensors.from_pooling_metadata(
|
return PoolingTensors.from_pooling_metadata(
|
||||||
pooling_metadata, hidden_states.device).prompt_lens
|
pooling_metadata, hidden_states[0].device).prompt_lens
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_token_ids(
|
def get_prompt_token_ids(
|
||||||
@ -149,6 +197,21 @@ def get_prompt_token_ids(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Collect the pooling task of every request in ``pooling_metadata``.

    For ``V0PoolingMetadata`` the pooling parameters are taken from the
    second element of each ``seq_groups`` entry; otherwise they are read
    from ``pooling_metadata.pooling_params``.

    Returns:
        One task per set of pooling parameters, in the same order.

    Raises:
        AssertionError: If any pooling parameters have ``task`` set to
            ``None``.
    """
    if isinstance(pooling_metadata, V0PoolingMetadata):
        params_per_request = [
            params for _, params in pooling_metadata.seq_groups
        ]
    else:
        params_per_request = pooling_metadata.pooling_params

    tasks: list[PoolingTask] = []
    for params in params_per_request:
        task = params.task
        if task is not None:
            tasks.append(task)

    # Entries without a task were dropped above, so a length mismatch means
    # at least one request arrived without a resolved task.
    assert len(params_per_request) == len(tasks)

    return tasks
|
||||||
|
|
||||||
|
|
||||||
def get_classification_activation_function(config: PretrainedConfig):
|
def get_classification_activation_function(config: PretrainedConfig):
|
||||||
return PoolerClassify()
|
return PoolerClassify()
|
||||||
|
|
||||||
@ -172,7 +235,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
|
|||||||
return PoolerScore()
|
return PoolerScore()
|
||||||
|
|
||||||
|
|
||||||
def build_output(all_data: torch.Tensor) -> PoolerOutput:
|
def build_output(
|
||||||
|
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
|
||||||
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
|
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
|
||||||
return PoolerOutput(outputs=all_outputs)
|
return PoolerOutput(outputs=all_outputs)
|
||||||
|
|
||||||
@ -193,12 +257,12 @@ class PoolingMethod(nn.Module, ABC):
|
|||||||
raise NotImplementedError(f"Unsupported method: {pooling_type}")
|
raise NotImplementedError(f"Unsupported method: {pooling_type}")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward_one(
|
def forward_one(
|
||||||
self,
|
self,
|
||||||
@ -237,16 +301,8 @@ class PoolingMethod(nn.Module, ABC):
|
|||||||
|
|
||||||
class CLSPool(PoolingMethod):
|
class CLSPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return {"encode", "embed", "classify", "score"}
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
# The equalities are split up to keep mypy happy
|
|
||||||
if (task == "encode" or task == "embed" or task == "classify"
|
|
||||||
or task == "score"):
|
|
||||||
return PoolingParamsUpdate()
|
|
||||||
|
|
||||||
assert_never(task)
|
|
||||||
|
|
||||||
def forward_one(
|
def forward_one(
|
||||||
self,
|
self,
|
||||||
@ -270,16 +326,8 @@ class CLSPool(PoolingMethod):
|
|||||||
|
|
||||||
class LastPool(PoolingMethod):
|
class LastPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return {"encode", "embed", "classify", "score"}
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
# The equalities are split up to keep mypy happy
|
|
||||||
if (task == "encode" or task == "embed" or task == "classify"
|
|
||||||
or task == "score"):
|
|
||||||
return PoolingParamsUpdate()
|
|
||||||
|
|
||||||
assert_never(task)
|
|
||||||
|
|
||||||
def forward_one(
|
def forward_one(
|
||||||
self,
|
self,
|
||||||
@ -299,18 +347,8 @@ class LastPool(PoolingMethod):
|
|||||||
|
|
||||||
class AllPool(PoolingMethod):
|
class AllPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return {"encode"}
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
if task == "encode":
|
|
||||||
return PoolingParamsUpdate()
|
|
||||||
|
|
||||||
# The equalities are split up to keep mypy happy
|
|
||||||
if task == "embed" or task == "classify" or task == "score":
|
|
||||||
return None
|
|
||||||
|
|
||||||
assert_never(task)
|
|
||||||
|
|
||||||
def forward_one(
|
def forward_one(
|
||||||
self,
|
self,
|
||||||
@ -327,28 +365,13 @@ class AllPool(PoolingMethod):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
prompt_lens: torch.Tensor,
|
prompt_lens: torch.Tensor,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
offset = 0
|
return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
|
||||||
pooled_data = list[torch.Tensor]()
|
|
||||||
|
|
||||||
for prompt_len in prompt_lens:
|
|
||||||
pooled_data.append(hidden_states[offset:offset + prompt_len])
|
|
||||||
offset += prompt_len
|
|
||||||
|
|
||||||
return pooled_data
|
|
||||||
|
|
||||||
|
|
||||||
class MeanPool(PoolingMethod):
|
class MeanPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return {"encode", "embed", "classify", "score"}
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
# The equalities are split up to keep mypy happy
|
|
||||||
if (task == "encode" or task == "embed" or task == "classify"
|
|
||||||
or task == "score"):
|
|
||||||
return PoolingParamsUpdate()
|
|
||||||
|
|
||||||
assert_never(task)
|
|
||||||
|
|
||||||
def forward_one(
|
def forward_one(
|
||||||
self,
|
self,
|
||||||
@ -529,24 +552,6 @@ class SimplePooler(Pooler):
|
|||||||
3. Returns structured results as `PoolerOutput`.
|
3. Returns structured results as `PoolerOutput`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config_with_defaults( # type: ignore[override]
|
|
||||||
cls,
|
|
||||||
pooler_config: PoolerConfig,
|
|
||||||
pooling_type: PoolingType,
|
|
||||||
normalize: bool,
|
|
||||||
softmax: bool,
|
|
||||||
) -> "SimplePooler":
|
|
||||||
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
|
||||||
pooler_config=pooler_config,
|
|
||||||
pooling_type=pooling_type,
|
|
||||||
normalize=normalize,
|
|
||||||
softmax=softmax,
|
|
||||||
)
|
|
||||||
assert resolved_config.pooling_type != PoolingType.STEP
|
|
||||||
|
|
||||||
return cls.from_config(resolved_config)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
@ -563,10 +568,10 @@ class SimplePooler(Pooler):
|
|||||||
self.pooling = pooling
|
self.pooling = pooling
|
||||||
self.head = head
|
self.head = head
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return self.pooling.get_supported_tasks()
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
return self.pooling.get_pooling_updates(task)
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -627,18 +632,11 @@ class StepPooler(Pooler):
|
|||||||
|
|
||||||
return pooled_data
|
return pooled_data
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return {"encode"}
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
if task == "encode":
|
|
||||||
return PoolingParamsUpdate(requires_token_ids=True)
|
|
||||||
|
|
||||||
# The equalities are split up to keep mypy happy
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
if task == "embed" or task == "classify" or task == "score":
|
return PoolingParamsUpdate(requires_token_ids=True)
|
||||||
return None
|
|
||||||
|
|
||||||
assert_never(task)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -650,68 +648,43 @@ class StepPooler(Pooler):
|
|||||||
return build_output(pooled_data)
|
return build_output(pooled_data)
|
||||||
|
|
||||||
|
|
||||||
PoolingFn = Callable[
|
class ClassifierPooler(Pooler):
|
||||||
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
|
|
||||||
Union[torch.Tensor, list[torch.Tensor]]]
|
|
||||||
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierPooler(nn.Module):
|
|
||||||
"""A pooling layer for classification tasks.
|
"""A pooling layer for classification tasks.
|
||||||
|
|
||||||
This layer does the following:
|
This layer does the following:
|
||||||
1. Applies a classification layer to the hidden states.
|
1. Applies a classification layer to the hidden states.
|
||||||
2. Optionally applies a pooler layer.
|
2. Optionally applies a pooler layer.
|
||||||
3. Applies an activation function to the output. In the case of
|
3. Applies an activation function to the output.
|
||||||
classification models it is either sigmoid or softmax. In the
|
|
||||||
case of scoring models, the same behavior is configuration
|
|
||||||
dependent, as in the sentence-transformers library.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def act_fn_for_seq_cls(config: ModelConfig):
|
||||||
|
return get_classification_activation_function(config.hf_config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def act_fn_for_cross_encoder(config: ModelConfig):
|
||||||
|
return get_cross_encoder_activation_function(config.hf_config)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ModelConfig,
|
|
||||||
pooling: PoolingFn,
|
pooling: PoolingFn,
|
||||||
classifier: ClassifierFn,
|
classifier: ClassifierFn,
|
||||||
act_fn: Optional[PoolerActivation] = None,
|
act_fn: PoolerActivation,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.pooling = pooling
|
self.pooling = pooling
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
|
self.act_fn = act_fn
|
||||||
|
|
||||||
self.classification_act_fn = get_classification_activation_function(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
config.hf_config) if act_fn is None else act_fn
|
return {"classify", "score"}
|
||||||
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
|
|
||||||
config.hf_config) if act_fn is None else act_fn
|
|
||||||
|
|
||||||
def _get_act_fn(self, task: PoolingTask):
|
|
||||||
if task == "encode" or task == "classify":
|
|
||||||
return self.classification_act_fn
|
|
||||||
if task == "score":
|
|
||||||
return self.cross_encoder_act_fn
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported task: {task!r}")
|
|
||||||
|
|
||||||
def get_pooling_updates(
|
|
||||||
self,
|
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
# The equalities are split up to keep mypy happy
|
|
||||||
if task == "encode" or task == "classify" or task == "score":
|
|
||||||
return PoolingParamsUpdate()
|
|
||||||
|
|
||||||
if task == "embed":
|
|
||||||
return None
|
|
||||||
|
|
||||||
assert_never(task)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
"""Pools sentence pair scores from the hidden_states."""
|
|
||||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
# apply classifier once on the full batch if possible
|
# apply classifier once on the full batch if possible
|
||||||
@ -722,28 +695,59 @@ class ClassifierPooler(nn.Module):
|
|||||||
else:
|
else:
|
||||||
pooled_output = [self.classifier(data) for data in pooled_data]
|
pooled_output = [self.classifier(data) for data in pooled_data]
|
||||||
|
|
||||||
task_list: list[PoolingTask]
|
scores = self.act_fn(pooled_output)
|
||||||
if isinstance(pooling_metadata, V0PoolingMetadata):
|
|
||||||
task_list = [
|
|
||||||
task for _, pooling_param in pooling_metadata.seq_groups
|
|
||||||
if (task := pooling_param.task) is not None
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
task_list = [
|
|
||||||
task for pooling_param in pooling_metadata.pooling_params
|
|
||||||
if (task := pooling_param.task) is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(task_list) == len(pooled_output)
|
|
||||||
|
|
||||||
# shape of scores: (batch_size, num_labels)
|
|
||||||
if len(set(task_list)) <= 1:
|
|
||||||
act_fn = self._get_act_fn(task_list[0])
|
|
||||||
scores = act_fn(pooled_output)
|
|
||||||
else:
|
|
||||||
scores = torch.stack([
|
|
||||||
self._get_act_fn(task)(vecs)
|
|
||||||
for task, vecs in zip(task_list, pooled_output)
|
|
||||||
])
|
|
||||||
|
|
||||||
return build_output(scores)
|
return build_output(scores)
|
||||||
|
|
||||||
|
|
||||||
|
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task.

    Each request in a batch carries a pooling task. Consecutive requests
    that share a task are forwarded together to the sub-pooler registered
    for that task, and the per-request outputs are concatenated in the
    original request order.
    """

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """
        Args:
            poolers_by_task: The sub-pooler to use for each pooling task.

        Raises:
            ValueError: If a sub-pooler does not list the task it is
                registered under among its supported tasks.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}")

        # NOTE(review): when a plain mapping is passed, the sub-poolers are
        # presumably not registered as child modules of this pooler --
        # confirm that this is intended.
        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered sub-pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the sub-pooler registered for ``task``."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool each run of same-task requests with the matching sub-pooler.

        Args:
            hidden_states: Either one tensor per request, or a single tensor
                that is split into per-request chunks using the prompt
                lengths from ``pooling_metadata``.
            pooling_metadata: Metadata describing the requests in the batch.

        Returns:
            The outputs of all sub-poolers, in request order.

        Raises:
            ValueError: If a request uses a task without a registered
                sub-pooler.
        """
        poolers_by_task = self.poolers_by_task

        if isinstance(hidden_states, list):
            hidden_states_lst = hidden_states
        else:
            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        # `groupby` without sorting yields runs of consecutive equal tasks,
        # which keeps the outputs aligned with the request order.
        for task, group in groupby(get_tasks(pooling_metadata)):
            # Compare against `None` explicitly: the truthiness of a module
            # may depend on `__len__`/`__bool__`, which must not make a
            # registered pooler look unsupported.
            pooler = poolers_by_task.get(task)
            if pooler is None:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}")

            num_items = sum(1 for _ in group)
            # assumes `pooling_metadata` supports slicing by request index
            # -- TODO confirm
            group_output: PoolerOutput = pooler(
                hidden_states_lst[offset:offset + num_items],
                pooling_metadata[offset:offset + num_items],
            )

            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)
|
||||||
|
|||||||
@ -13,7 +13,6 @@ from .interfaces_base import VllmModelForPooling, is_pooling_model
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.pooler import PoolingType
|
|
||||||
|
|
||||||
_T = TypeVar("_T", bound=type[nn.Module])
|
_T = TypeVar("_T", bound=type[nn.Module])
|
||||||
|
|
||||||
@ -34,16 +33,8 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
|||||||
return model_name + pooling_suffix
|
return model_name + pooling_suffix
|
||||||
|
|
||||||
|
|
||||||
def _create_pooling_model_cls(
|
def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||||
orig_cls: _T,
|
|
||||||
*,
|
|
||||||
default_pooling_type: "PoolingType",
|
|
||||||
default_normalize: bool,
|
|
||||||
default_softmax: bool,
|
|
||||||
) -> _T:
|
|
||||||
# Lazy import
|
# Lazy import
|
||||||
from vllm.model_executor.layers.pooler import Pooler
|
|
||||||
|
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper
|
from .utils import AutoWeightsLoader, WeightsMapper
|
||||||
|
|
||||||
class ModelForPooling(orig_cls, VllmModelForPooling):
|
class ModelForPooling(orig_cls, VllmModelForPooling):
|
||||||
@ -71,15 +62,7 @@ def _create_pooling_model_cls(
|
|||||||
self._init_pooler(vllm_config, prefix=prefix)
|
self._init_pooler(vllm_config, prefix=prefix)
|
||||||
|
|
||||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
raise NotImplementedError
|
||||||
assert pooler_config is not None
|
|
||||||
|
|
||||||
self.pooler = Pooler.from_config_with_defaults(
|
|
||||||
pooler_config,
|
|
||||||
pooling_type=default_pooling_type,
|
|
||||||
normalize=default_normalize,
|
|
||||||
softmax=default_softmax,
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
# TODO: Support uninitialized params tracking
|
# TODO: Support uninitialized params tracking
|
||||||
@ -132,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T:
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
# Lazy import
|
# Lazy import
|
||||||
from vllm.model_executor.layers.pooler import PoolingType
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
|
|
||||||
|
class ModelForEmbedding(_create_pooling_model_cls(cls)):
|
||||||
|
|
||||||
|
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
self.pooler = DispatchPooler(
|
||||||
|
{
|
||||||
|
"encode": Pooler.for_encode(pooler_config),
|
||||||
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
|
}, )
|
||||||
|
|
||||||
ModelForEmbedding = _create_pooling_model_cls(
|
|
||||||
cls,
|
|
||||||
default_pooling_type=PoolingType.LAST,
|
|
||||||
default_normalize=True,
|
|
||||||
default_softmax=False,
|
|
||||||
)
|
|
||||||
ModelForEmbedding.__name__ = \
|
ModelForEmbedding.__name__ = \
|
||||||
_get_pooling_model_name(cls.__name__, "ForEmbedding")
|
_get_pooling_model_name(cls.__name__, "ForEmbedding")
|
||||||
|
|
||||||
@ -165,20 +154,14 @@ def as_seq_cls_model(cls: _T) -> _T:
|
|||||||
# Lazy import
|
# Lazy import
|
||||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||||
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||||
PoolingType, SimplePooler)
|
DispatchPooler, Pooler,
|
||||||
|
PoolingMethod, PoolingType)
|
||||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .utils import maybe_prefix
|
from .utils import maybe_prefix
|
||||||
|
|
||||||
ModelForPooling = _create_pooling_model_cls(
|
class ModelForSequenceClassification(_create_pooling_model_cls(cls),
|
||||||
cls,
|
|
||||||
default_pooling_type=PoolingType.LAST,
|
|
||||||
default_normalize=False,
|
|
||||||
default_softmax=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class ModelForSequenceClassification(ModelForPooling,
|
|
||||||
SupportsCrossEncoding):
|
SupportsCrossEncoding):
|
||||||
|
|
||||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
@ -198,19 +181,28 @@ def as_seq_cls_model(cls: _T) -> _T:
|
|||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
pooler = SimplePooler.from_config_with_defaults(
|
pooling_type_str = pooler_config.pooling_type
|
||||||
pooler_config,
|
pooling_type = (PoolingType.LAST if pooling_type_str is None else
|
||||||
pooling_type=PoolingType.LAST,
|
PoolingType[pooling_type_str])
|
||||||
normalize=False,
|
|
||||||
softmax=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pooler = ClassifierPooler(
|
self.pooler = DispatchPooler({
|
||||||
vllm_config.model_config,
|
"encode":
|
||||||
pooling=pooler.pooling,
|
Pooler.for_encode(pooler_config),
|
||||||
classifier=self._classifier,
|
"classify":
|
||||||
act_fn=pooler.head.activation,
|
ClassifierPooler(
|
||||||
)
|
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||||
|
classifier=self._classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
"score":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||||
|
classifier=self._classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
def _classifier(self, x: torch.Tensor):
|
def _classifier(self, x: torch.Tensor):
|
||||||
x, _ = self.score(x.float())
|
x, _ = self.score(x.float())
|
||||||
@ -259,14 +251,16 @@ def as_reward_model(cls: _T) -> _T:
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
# Lazy import
|
# Lazy import
|
||||||
from vllm.model_executor.layers.pooler import PoolingType
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
|
|
||||||
ModelForReward = _create_pooling_model_cls(
|
class ModelForReward(_create_pooling_model_cls(cls)):
|
||||||
cls,
|
|
||||||
default_pooling_type=PoolingType.ALL,
|
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
default_normalize=False,
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
default_softmax=False,
|
assert pooler_config is not None
|
||||||
)
|
|
||||||
|
self.pooler = DispatchPooler(
|
||||||
|
{"encode": Pooler.for_encode(pooler_config)}, )
|
||||||
|
|
||||||
ModelForReward.__name__ = \
|
ModelForReward.__name__ = \
|
||||||
_get_pooling_model_name(cls.__name__, "ForReward")
|
_get_pooling_model_name(cls.__name__, "ForReward")
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable, Set
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -17,7 +17,8 @@ from vllm.model_executor.layers.activation import get_act_fn
|
|||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||||
|
DispatchPooler, Pooler,
|
||||||
PoolingMethod,
|
PoolingMethod,
|
||||||
PoolingParamsUpdate,
|
PoolingParamsUpdate,
|
||||||
PoolingType)
|
PoolingType)
|
||||||
@ -92,20 +93,29 @@ class BertPooler(Pooler):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return self.pooling.get_supported_tasks()
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
return self.pooling.get_pooling_updates(task)
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
|
def _head(self, pooled_output: torch.Tensor):
|
||||||
|
pooled_output = self.dense(pooled_output)
|
||||||
|
pooled_output = self.activation(pooled_output)
|
||||||
|
return pooled_output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||||
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
||||||
pooled_output = self.dense(pooled_output)
|
|
||||||
pooled_output = self.activation(pooled_output)
|
if isinstance(pooled_output, list):
|
||||||
|
pooled_output = [self._head(output) for output in pooled_output]
|
||||||
|
else:
|
||||||
|
pooled_output = self._head(pooled_output)
|
||||||
|
|
||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
@ -333,18 +343,19 @@ class BertModel(nn.Module, SupportsQuant):
|
|||||||
|
|
||||||
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
*,
|
self,
|
||||||
vllm_config: VllmConfig,
|
*,
|
||||||
prefix: str = "",
|
vllm_config: VllmConfig,
|
||||||
embedding_class: type = BertEmbedding,
|
prefix: str = "",
|
||||||
add_pooling_layer: bool = False):
|
embedding_class: type[nn.Module] = BertEmbedding,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
self.embeddings = embedding_class(config)
|
self.embeddings = embedding_class(config)
|
||||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||||
prefix=f"{prefix}.encoder")
|
prefix=f"{prefix}.encoder")
|
||||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -366,8 +377,7 @@ class BertModel(nn.Module, SupportsQuant):
|
|||||||
token_type_ids=token_type_ids)
|
token_type_ids=token_type_ids)
|
||||||
return self.encoder(hidden_states)
|
return self.encoder(hidden_states)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
torch.Tensor]]) -> set[str]:
|
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("qkv_proj", "query", "q"),
|
("qkv_proj", "query", "q"),
|
||||||
@ -395,10 +405,43 @@ class BertModel(nn.Module, SupportsQuant):
|
|||||||
if name in params_dict:
|
if name in params_dict:
|
||||||
other_weights.append((name, loaded_weight))
|
other_weights.append((name, loaded_weight))
|
||||||
|
|
||||||
loader = AutoWeightsLoader(
|
return other_weights, loaded_stacked_params
|
||||||
self,
|
|
||||||
skip_prefixes=(["pooler."] if self.pooler is None else []),
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
other_weights, loaded_stacked_params = self._load_weights(weights)
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
|
||||||
|
loaded_params = loader.load_weights(other_weights)
|
||||||
|
loaded_params.update(loaded_stacked_params)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class BertPoolingModel(BertModel):
|
||||||
|
|
||||||
|
is_pooling_model = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
embedding_class: type[nn.Module] = BertEmbedding,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
|
embedding_class=embedding_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
self.pooler = BertPooler(config)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
other_weights, loaded_stacked_params = self._load_weights(weights)
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
loaded_params = loader.load_weights(other_weights)
|
loaded_params = loader.load_weights(other_weights)
|
||||||
loaded_params.update(loaded_stacked_params)
|
loaded_params.update(loaded_stacked_params)
|
||||||
return loaded_params
|
return loaded_params
|
||||||
@ -421,6 +464,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.model = self._build_model(vllm_config=vllm_config,
|
self.model = self._build_model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
self.pooler = self._build_pooler(pooler_config)
|
self.pooler = self._build_pooler(pooler_config)
|
||||||
@ -456,10 +501,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
|||||||
embedding_class=BertEmbedding)
|
embedding_class=BertEmbedding)
|
||||||
|
|
||||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||||
return Pooler.from_config_with_defaults(pooler_config,
|
return DispatchPooler({
|
||||||
pooling_type=PoolingType.CLS,
|
"encode":
|
||||||
normalize=True,
|
Pooler.for_encode(pooler_config),
|
||||||
softmax=False)
|
"embed":
|
||||||
|
Pooler.for_embed(
|
||||||
|
pooler_config,
|
||||||
|
default_pooling_type=PoolingType.CLS,
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||||
@ -481,16 +531,32 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
|||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
|
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
self.bert = BertModel(vllm_config=vllm_config,
|
self.bert = BertPoolingModel(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "bert"),
|
prefix=maybe_prefix(prefix, "bert"),
|
||||||
embedding_class=BertEmbedding,
|
embedding_class=BertEmbedding)
|
||||||
add_pooling_layer=True)
|
|
||||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
self.pooler = ClassifierPooler(
|
|
||||||
vllm_config.model_config,
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
pooling=self.bert.pooler,
|
assert pooler_config is not None
|
||||||
classifier=self.classifier,
|
|
||||||
)
|
self.pooler = DispatchPooler({
|
||||||
|
"encode":
|
||||||
|
Pooler.for_encode(pooler_config),
|
||||||
|
"classify":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=self.bert.pooler,
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
"score":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=self.bert.pooler,
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
|
|||||||
@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from ..layers.pooler import Pooler, PoolingType
|
from ..layers.pooler import DispatchPooler, Pooler
|
||||||
from .interfaces import SupportsPP
|
from .interfaces import SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
@ -339,12 +339,16 @@ class GPT2ForSequenceClassification(nn.Module):
|
|||||||
self.transformer = GPT2Model(vllm_config=vllm_config,
|
self.transformer = GPT2Model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "gpt2"))
|
prefix=maybe_prefix(prefix, "gpt2"))
|
||||||
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
|
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
self.pooler = Pooler.from_config_with_defaults(
|
assert pooler_config is not None
|
||||||
pooler_config,
|
|
||||||
pooling_type=PoolingType.LAST,
|
self.pooler = DispatchPooler({
|
||||||
normalize=False,
|
"encode":
|
||||||
softmax=True)
|
Pooler.for_encode(pooler_config),
|
||||||
|
"classify":
|
||||||
|
Pooler.for_classify(pooler_config, classifier=None),
|
||||||
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
|
|||||||
@ -1,17 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Set
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing_extensions import assert_never
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.pooler import (Pooler, PoolerHead,
|
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
||||||
PoolerNormalize,
|
PoolerHead, PoolerNormalize,
|
||||||
PoolingParamsUpdate,
|
PoolingParamsUpdate,
|
||||||
build_output, get_prompt_lens,
|
build_output, get_prompt_lens,
|
||||||
get_prompt_token_ids)
|
get_prompt_token_ids)
|
||||||
@ -135,18 +134,11 @@ class GritLMMeanPool(nn.Module):
|
|||||||
|
|
||||||
return instruction_len
|
return instruction_len
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return {"encode", "embed"}
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
|
||||||
# The equalities are split up to keep mypy happy
|
|
||||||
if task == "encode" or task == "embed":
|
|
||||||
return PoolingParamsUpdate(requires_token_ids=True)
|
|
||||||
|
|
||||||
if task == "classify" or task == "score":
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
return None
|
return PoolingParamsUpdate(requires_token_ids=True)
|
||||||
|
|
||||||
assert_never(task)
|
|
||||||
|
|
||||||
def forward_one(
|
def forward_one(
|
||||||
self,
|
self,
|
||||||
@ -207,10 +199,10 @@ class GritLMPooler(Pooler):
|
|||||||
self.pooling = GritLMMeanPool(model_config)
|
self.pooling = GritLMMeanPool(model_config)
|
||||||
self.head = PoolerHead(PoolerNormalize())
|
self.head = PoolerHead(PoolerNormalize())
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return self.pooling.get_supported_tasks()
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
return self.pooling.get_pooling_updates(task)
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -262,4 +254,11 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
|
|||||||
|
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
|
|
||||||
self.pooler = GritLMPooler(vllm_config.model_config)
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
if pooler_config is not None:
|
||||||
|
self.pooler = DispatchPooler({
|
||||||
|
"encode":
|
||||||
|
Pooler.for_encode(pooler_config),
|
||||||
|
"embed":
|
||||||
|
GritLMPooler(vllm_config.model_config),
|
||||||
|
})
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
@ -429,12 +429,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
self.pooler = Pooler.from_config_with_defaults(
|
assert pooler_config is not None
|
||||||
pooler_config,
|
|
||||||
pooling_type=PoolingType.ALL,
|
self.pooler = DispatchPooler(
|
||||||
normalize=False,
|
{"encode": Pooler.for_encode(pooler_config)}, )
|
||||||
softmax=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -19,8 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
|||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
|
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
||||||
SimplePooler)
|
PoolingType)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
@ -584,16 +584,15 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
|||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
pooler = SimplePooler.from_config_with_defaults(
|
self.pooler = DispatchPooler({
|
||||||
pooler_config,
|
"encode":
|
||||||
pooling_type=PoolingType.LAST,
|
Pooler.for_encode(pooler_config),
|
||||||
normalize=False,
|
"classify":
|
||||||
softmax=False,
|
Pooler.for_classify(
|
||||||
)
|
pooler_config,
|
||||||
|
classifier=self.score,
|
||||||
self.pooler = ClassifierPooler(
|
default_pooling_type=PoolingType.LAST,
|
||||||
vllm_config.model_config,
|
default_normalize=False,
|
||||||
pooling=pooler.pooling,
|
default_softmax=False,
|
||||||
classifier=self.score,
|
),
|
||||||
act_fn=pooler.head.activation,
|
})
|
||||||
)
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from vllm.inputs import TokensPrompt
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
@ -96,11 +96,17 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
|||||||
|
|
||||||
self.score = JinaVLScorer(config)
|
self.score = JinaVLScorer(config)
|
||||||
|
|
||||||
self.pooler = Pooler.from_config_with_defaults(
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
pooler_config,
|
assert pooler_config is not None
|
||||||
pooling_type=PoolingType.LAST,
|
|
||||||
normalize=False,
|
self.pooler = DispatchPooler({
|
||||||
softmax=True)
|
"encode":
|
||||||
|
Pooler.for_encode(pooler_config),
|
||||||
|
"classify":
|
||||||
|
Pooler.for_classify(pooler_config, classifier=None),
|
||||||
|
"score":
|
||||||
|
Pooler.for_classify(pooler_config, classifier=None),
|
||||||
|
})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable, Set
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -13,7 +13,8 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||||
|
DispatchPooler, Pooler,
|
||||||
PoolingMethod,
|
PoolingMethod,
|
||||||
PoolingParamsUpdate,
|
PoolingParamsUpdate,
|
||||||
PoolingType)
|
PoolingType)
|
||||||
@ -271,19 +272,27 @@ class ModernBertPooler(Pooler):
|
|||||||
eps=config.norm_eps,
|
eps=config.norm_eps,
|
||||||
bias=config.norm_bias)
|
bias=config.norm_bias)
|
||||||
|
|
||||||
def get_pooling_updates(
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
self,
|
return self.pooling.get_supported_tasks()
|
||||||
task: PoolingTask,
|
|
||||||
) -> Optional[PoolingParamsUpdate]:
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
return self.pooling.get_pooling_updates(task)
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
|
def _head(self, pooled_output: torch.Tensor):
|
||||||
|
return self.norm(self.act(self.dense(pooled_output)))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||||
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
||||||
pooled_output = self.norm(self.act(self.dense(pooled_output)))
|
|
||||||
|
if isinstance(pooled_output, list):
|
||||||
|
pooled_output = [self._head(output) for output in pooled_output]
|
||||||
|
else:
|
||||||
|
pooled_output = self._head(pooled_output)
|
||||||
|
|
||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
@ -299,11 +308,28 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
|||||||
self.model = ModernBertModel(vllm_config=vllm_config,
|
self.model = ModernBertModel(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "modernbert"))
|
prefix=maybe_prefix(prefix, "modernbert"))
|
||||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
self.pooler = ClassifierPooler(
|
|
||||||
vllm_config.model_config,
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
pooling=ModernBertPooler(config),
|
assert pooler_config is not None
|
||||||
classifier=self.classifier,
|
|
||||||
)
|
self.pooler = DispatchPooler({
|
||||||
|
"encode":
|
||||||
|
Pooler.for_encode(pooler_config),
|
||||||
|
"classify":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=ModernBertPooler(config),
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
"score":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=ModernBertPooler(config),
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,8 @@ from torch import nn
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
|
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
||||||
|
PoolingType)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
@ -26,7 +27,7 @@ from .utils import AutoWeightsLoader, maybe_prefix
|
|||||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
|
||||||
is_pooling_model = True
|
is_pooling_model = True
|
||||||
pooler: SimplePooler
|
pooler: Pooler
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
@ -94,12 +95,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
vllm_config.model_config.hf_config.num_labels = 1
|
vllm_config.model_config.hf_config.num_labels = 1
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
self.pooler = Pooler.from_config_with_defaults(
|
assert pooler_config is not None
|
||||||
pooler_config,
|
|
||||||
pooling_type=PoolingType.ALL,
|
self.pooler = DispatchPooler(
|
||||||
normalize=False,
|
{"encode": Pooler.for_encode(pooler_config)}, )
|
||||||
softmax=False)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||||
@ -107,11 +108,17 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
vllm_config.model_config.hf_config.num_labels = 2
|
vllm_config.model_config.hf_config.num_labels = 2
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
self.pooler = Pooler.from_config_with_defaults(
|
assert pooler_config is not None
|
||||||
pooler_config,
|
|
||||||
pooling_type=PoolingType.STEP,
|
self.pooler = DispatchPooler({
|
||||||
normalize=False,
|
"encode":
|
||||||
softmax=True,
|
Pooler.for_encode(
|
||||||
step_tag_id=151651,
|
pooler_config,
|
||||||
)
|
default_pooling_type=PoolingType.STEP,
|
||||||
|
default_normalize=False,
|
||||||
|
default_softmax=True,
|
||||||
|
default_step_tag_id=151651,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from torch import nn
|
|||||||
from transformers import RobertaConfig
|
from transformers import RobertaConfig
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
|
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
|
||||||
|
DispatchPooler, Pooler)
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
||||||
@ -63,16 +64,10 @@ class RobertaEmbedding(nn.Module):
|
|||||||
# References:
|
# References:
|
||||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
||||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
||||||
pos_list = []
|
seq_lens_list = seq_lens.tolist()
|
||||||
token_list = []
|
|
||||||
offset = 0
|
|
||||||
for seq_len in seq_lens:
|
|
||||||
pos_list.append(position_ids[offset:offset + seq_len])
|
|
||||||
token_list.append(input_ids[offset:offset + seq_len])
|
|
||||||
offset += seq_len
|
|
||||||
|
|
||||||
new_pos_list = []
|
new_pos_list = []
|
||||||
for positions, tokens in zip(pos_list, token_list):
|
for positions, tokens in zip(position_ids.split(seq_lens_list),
|
||||||
|
input_ids.split(seq_lens_list)):
|
||||||
# Verify assumption that incoming position are
|
# Verify assumption that incoming position are
|
||||||
# always a sequence from 0 to N.
|
# always a sequence from 0 to N.
|
||||||
expected_pos = torch.arange(positions.size()[0],
|
expected_pos = torch.arange(positions.size()[0],
|
||||||
@ -184,15 +179,30 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
|||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
self.roberta = BertModel(vllm_config=vllm_config,
|
self.roberta = BertModel(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "bert"),
|
prefix=maybe_prefix(prefix, "bert"),
|
||||||
embedding_class=RobertaEmbedding,
|
embedding_class=RobertaEmbedding)
|
||||||
add_pooling_layer=False)
|
|
||||||
self.classifier = RobertaClassificationHead(config)
|
self.classifier = RobertaClassificationHead(config)
|
||||||
|
|
||||||
self.pooler = ClassifierPooler(
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
vllm_config.model_config,
|
assert pooler_config is not None
|
||||||
pooling=CLSPool(),
|
|
||||||
classifier=self.classifier,
|
self.pooler = DispatchPooler({
|
||||||
)
|
"encode":
|
||||||
|
Pooler.for_encode(pooler_config),
|
||||||
|
"classify":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=CLSPool(),
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
"score":
|
||||||
|
ClassifierPooler(
|
||||||
|
pooling=CLSPool(),
|
||||||
|
classifier=self.classifier,
|
||||||
|
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||||
|
vllm_config.model_config),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
|
|||||||
@ -38,6 +38,13 @@ class PoolingMetadata:
|
|||||||
f"seq_data={self.seq_data}, "
|
f"seq_data={self.seq_data}, "
|
||||||
f"prompt_lens={self.prompt_lens})")
|
f"prompt_lens={self.prompt_lens})")
|
||||||
|
|
||||||
|
def __getitem__(self, indices: slice):
|
||||||
|
return PoolingMetadata(
|
||||||
|
seq_groups=self.seq_groups[indices],
|
||||||
|
seq_data=dict(list(self.seq_data.items())[indices]),
|
||||||
|
prompt_lens=self.prompt_lens[indices],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PoolingTensors:
|
class PoolingTensors:
|
||||||
|
|||||||
@ -15,3 +15,11 @@ class PoolingMetadata:
|
|||||||
prompt_lens: torch.Tensor
|
prompt_lens: torch.Tensor
|
||||||
prompt_token_ids: Optional[torch.Tensor]
|
prompt_token_ids: Optional[torch.Tensor]
|
||||||
pooling_params: list[PoolingParams]
|
pooling_params: list[PoolingParams]
|
||||||
|
|
||||||
|
def __getitem__(self, indices: slice):
|
||||||
|
return PoolingMetadata(
|
||||||
|
prompt_lens=self.prompt_lens[indices],
|
||||||
|
prompt_token_ids=None if self.prompt_token_ids is None else
|
||||||
|
self.prompt_token_ids[indices],
|
||||||
|
pooling_params=self.pooling_params[indices],
|
||||||
|
)
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import copy
|
|||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args
|
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -415,15 +415,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
generator = None
|
generator = None
|
||||||
|
|
||||||
if pooling_params:
|
if pooling_params:
|
||||||
assert pooling_params.task is not None, (
|
assert (task := pooling_params.task) is not None, (
|
||||||
"You did not set `task` in the API")
|
"You did not set `task` in the API")
|
||||||
|
|
||||||
model = cast(VllmModelForPooling, self.model)
|
model = cast(VllmModelForPooling, self.model)
|
||||||
to_update = (model.pooler.get_pooling_updates(
|
to_update = model.pooler.get_pooling_updates(task)
|
||||||
pooling_params.task))
|
|
||||||
assert to_update is not None, (
|
|
||||||
f"{pooling_params.task=} is not supported by the model")
|
|
||||||
|
|
||||||
to_update.apply(pooling_params)
|
to_update.apply(pooling_params)
|
||||||
|
|
||||||
self.requests[req_id] = CachedRequestState(
|
self.requests[req_id] = CachedRequestState(
|
||||||
@ -1122,10 +1118,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if not is_pooling_model(model):
|
if not is_pooling_model(model):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [
|
return list(model.pooler.get_supported_tasks())
|
||||||
task for task in get_args(PoolingTask)
|
|
||||||
if model.pooler.get_pooling_updates(task)
|
|
||||||
]
|
|
||||||
|
|
||||||
def apply_grammar_bitmask(
|
def apply_grammar_bitmask(
|
||||||
self,
|
self,
|
||||||
@ -2247,7 +2240,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dummy_pooling_params = PoolingParams(task=dummy_task)
|
dummy_pooling_params = PoolingParams(task=dummy_task)
|
||||||
|
|
||||||
to_update = model.pooler.get_pooling_updates(dummy_task)
|
to_update = model.pooler.get_pooling_updates(dummy_task)
|
||||||
assert to_update is not None
|
|
||||||
to_update.apply(dummy_pooling_params)
|
to_update.apply(dummy_pooling_params)
|
||||||
|
|
||||||
dummy_metadata = PoolingMetadata(
|
dummy_metadata = PoolingMetadata(
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import bisect
|
import bisect
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast, get_args
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -491,10 +491,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if not is_pooling_model(model):
|
if not is_pooling_model(model):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [
|
return list(model.pooler.get_supported_tasks())
|
||||||
task for task in get_args(PoolingTask)
|
|
||||||
if model.pooler.get_pooling_updates(task)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
|
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
|
||||||
TypeVar, get_args)
|
TypeVar)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]):
|
|||||||
if not is_pooling_model(model):
|
if not is_pooling_model(model):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [
|
return list(model.pooler.get_supported_tasks())
|
||||||
task for task in get_args(PoolingTask)
|
|
||||||
if model.pooler.get_pooling_updates(task)
|
|
||||||
]
|
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -199,15 +199,11 @@ class PoolingModelRunner(
|
|||||||
|
|
||||||
pooling_params = seq_group_metadata.pooling_params
|
pooling_params = seq_group_metadata.pooling_params
|
||||||
assert pooling_params is not None
|
assert pooling_params is not None
|
||||||
assert pooling_params.task is not None, (
|
assert (task := pooling_params.task) is not None, (
|
||||||
"You did not set `task` in the API")
|
"You did not set `task` in the API")
|
||||||
|
|
||||||
to_update = (cast(VllmModelForPooling,
|
model = cast(VllmModelForPooling, self.model)
|
||||||
self.model).pooler.get_pooling_updates(
|
to_update = model.pooler.get_pooling_updates(task)
|
||||||
pooling_params.task))
|
|
||||||
assert to_update is not None, (
|
|
||||||
f"{pooling_params.task=} is not supported by the model")
|
|
||||||
|
|
||||||
to_update.apply(pooling_params)
|
to_update.apply(pooling_params)
|
||||||
|
|
||||||
seq_groups.append((seq_ids, pooling_params))
|
seq_groups.append((seq_ids, pooling_params))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user