Move PoolerConfig from config/__init__.py to config/pooler.py (#25181)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-09-19 12:02:55 +01:00 committed by GitHub
parent 1dfea5f4a9
commit 058525b997
14 changed files with 193 additions and 179 deletions


@@ -59,7 +59,7 @@ enabling the corresponding APIs:
#### Predefined models
If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`,
you can override some of its attributes via the `--override-pooler-config` option.
you can override some of its attributes via the `--pooler-config` option.
#### Converted models
@@ -75,7 +75,7 @@ the pooler assigned to each task has the following attributes by default:
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
You can further customize this via the `--override-pooler-config` option,
You can further customize this via the `--pooler-config` option,
which takes priority over both the model's and Sentence Transformers's defaults.
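A minimal offline-inference sketch of the renamed option (the model choice is illustrative): the `--pooler-config` CLI flag maps to the `pooler_config` engine argument, so the same override can be expressed in Python:

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Override the pooling defaults that would otherwise come from the model
# (or from its Sentence Transformers modules.json, if present).
llm = LLM(
    model="sentence-transformers/all-MiniLM-L12-v2",
    pooler_config=PoolerConfig(pooling_type="MEAN", normalize=False),
)

outputs = llm.embed(["vLLM supports pooling models."])
```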
## Offline Inference


@@ -457,7 +457,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
!!! note
`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
You need to manually set mean pooling by passing `--pooler-config '{"pooling_type": "MEAN"}'`.
!!! note
For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded.
@@ -552,7 +552,7 @@ If your model is not in the above list, we will try to automatically convert the
!!! important
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
e.g.: `--pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
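A minimal offline equivalent of that flag (using the same placeholder token IDs as the example above):

```python
from vllm import LLM
from vllm.config import PoolerConfig

# STEP pooling for a process-supervised reward model; 123/456/789 are the
# placeholder IDs from the CLI example, not real vocabulary entries.
llm = LLM(
    model="peiyi9979/math-shepherd-mistral-7b-prm",
    pooler_config=PoolerConfig(
        pooling_type="STEP",
        step_tag_id=123,
        returned_token_ids=[456, 789],
    ),
)
```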
#### Token Classification


@@ -42,7 +42,7 @@ python client.py
### Server Configuration
The key parameters for chunked processing are in the `--override-pooler-config`:
The key parameters for chunked processing are in the `--pooler-config`:
```json
{


@@ -13,7 +13,7 @@ Prerequisites:
# MEAN pooling (processes all chunks, recommended for complete coverage)
vllm serve intfloat/multilingual-e5-large \
--override-pooler-config \
--pooler-config \
'{"pooling_type": "MEAN", "normalize": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 3072000}' \
--served-model-name multilingual-e5-large \
@@ -23,7 +23,7 @@ Prerequisites:
# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
vllm serve BAAI/bge-large-en-v1.5 \
--override-pooler-config \
--pooler-config \
'{"pooling_type": "CLS", "normalize": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 1048576}' \
--served-model-name bge-large-en-v1.5 \


@@ -103,7 +103,7 @@ POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enab
vllm serve "$MODEL_NAME" \
--tensor-parallel-size "$GPU_COUNT" \
--enforce-eager \
--override-pooler-config "$POOLER_CONFIG" \
--pooler-config "$POOLER_CONFIG" \
--served-model-name ${MODEL_CODE} \
--api-key "$API_KEY" \
--trust-remote-code \


@@ -216,7 +216,7 @@ def server_with_chunked_processing():
"--enforce-eager",
"--max-model-len",
"512", # Set smaller max_model_len to trigger chunking mechanism
'--override-pooler-config',
'--pooler-config',
('{"pooling_type": "MEAN", "normalize": true, '
'"enable_chunked_processing": true, "max_embed_len": 10000}'),
"--gpu-memory-utilization",


@@ -58,7 +58,7 @@ def test_models(
vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["override_pooler_config"] = \
vllm_extra_kwargs["pooler_config"] = \
PoolerConfig(pooling_type="MEAN", normalize=False)
max_model_len: Optional[int] = 512


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.pooler import PoolerConfig
from vllm.platforms import current_platform
@@ -99,7 +100,7 @@ def test_gemma_multimodal(
convert="classify",
load_format="auto",
hf_overrides=update_config,
override_pooler_config={"pooling_type": "LAST"},
pooler_config=PoolerConfig(pooling_type="LAST"),
max_model_len=512,
enforce_eager=True,
tensor_parallel_size=1,


@@ -24,18 +24,18 @@ def test_classify_models_using_activation(
dtype: str,
) -> None:
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
override_pooler_config=PoolerConfig(
activation=False)) as vllm_model:
with vllm_runner(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(activation=False)) as vllm_model:
wo_activation_out = vllm_model.classify(example_prompts)
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
override_pooler_config=PoolerConfig(
activation=True)) as vllm_model:
with vllm_runner(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(activation=True)) as vllm_model:
w_activation_out = vllm_model.classify(example_prompts)
for wo_activation, w_activation in zip(wo_activation_out,
@@ -43,9 +43,8 @@ def test_classify_models_using_activation(
wo_activation = torch.tensor(wo_activation)
w_activation = torch.tensor(w_activation)
assert not torch.allclose(
wo_activation, w_activation,
atol=1e-2), "override_pooler_config is not working"
assert not torch.allclose(wo_activation, w_activation,
atol=1e-2), "pooler_config is not working"
assert torch.allclose(softmax(wo_activation), w_activation,
1e-3 if dtype == "float" else 1e-2)
@@ -65,23 +64,22 @@ def test_embed_models_using_normalize(
dtype: str,
) -> None:
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
override_pooler_config=PoolerConfig(
normalize=False)) as vllm_model:
wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
with vllm_runner(
model,
max_model_len=512,
dtype=dtype,
override_pooler_config=PoolerConfig(normalize=True)) as vllm_model:
pooler_config=PoolerConfig(normalize=False)) as vllm_model:
wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(normalize=True)) as vllm_model:
w_normalize = torch.tensor(vllm_model.embed(example_prompts))
assert not torch.allclose(
wo_normalize, w_normalize,
atol=1e-2), "override_pooler_config normalize is not working"
atol=1e-2), "pooler_config normalize is not working"
assert torch.allclose(
F.normalize(wo_normalize, p=2, dim=-1), w_normalize,
atol=1e-2), "w_normal should be close to normal(wo_normal)."
@@ -102,18 +100,16 @@ def test_reward_models_using_softmax(
dtype: str,
) -> None:
with vllm_runner(
model,
max_model_len=1024,
dtype=dtype,
override_pooler_config=PoolerConfig(softmax=False)) as vllm_model:
with vllm_runner(model,
max_model_len=1024,
dtype=dtype,
pooler_config=PoolerConfig(softmax=False)) as vllm_model:
wo_softmax = vllm_model.encode(example_prompts)
with vllm_runner(
model,
max_model_len=1024,
dtype=dtype,
override_pooler_config=PoolerConfig(softmax=True)) as vllm_model:
with vllm_runner(model,
max_model_len=1024,
dtype=dtype,
pooler_config=PoolerConfig(softmax=True)) as vllm_model:
w_softmax = vllm_model.encode(example_prompts)
for wo, w in zip(wo_softmax, w_softmax):
@@ -121,7 +117,7 @@ def test_reward_models_using_softmax(
w = torch.tensor(w)
assert not torch.allclose(
wo, w, atol=1e-2), "override_pooler_config softmax is not working"
wo, w, atol=1e-2), "pooler_config softmax is not working"
assert torch.allclose(
softmax(wo), w,
atol=1e-2), "w_softmax should be close to softmax(wo_softmax)."


@@ -207,25 +207,19 @@ def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
model_config = ModelConfig(model_id)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert pooling_config.normalize
assert pooling_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
model_config = ModelConfig(model_id)
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
model_config = ModelConfig(model_id, pooler_config=pooler_config)
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
model_config.override_pooler_config = override_pooler_config
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config)
@pytest.mark.parametrize(


@@ -40,6 +40,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
MultiModalConfig)
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.speculative import SpeculativeConfig
from vllm.config.structured_outputs import StructuredOutputsConfig
@@ -406,13 +407,6 @@ class ModelConfig:
hf_overrides: HfOverrides = field(default_factory=dict)
"""If a dictionary, contains arguments to be forwarded to the Hugging Face
config. If a callable, it is called to update the HuggingFace config."""
pooler_config: Optional["PoolerConfig"] = field(init=False)
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
"""Initialize non-default pooling config or override default pooling config
for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
"""
logits_processor_pattern: Optional[str] = None
"""Optional regex pattern specifying valid logits processor qualified names
that can be passed with the `logits_processors` extra completion argument.
@@ -448,6 +442,14 @@ class ModelConfig:
io_processor_plugin: Optional[str] = None
"""IOProcessor plugin name to load at model startup"""
# Pooler config
pooler_config: Optional[PoolerConfig] = None
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
v0.12.0 or v1.0.0, whichever is sooner."""
# Multimodal config and init vars
multimodal_config: Optional[MultiModalConfig] = None
"""Configuration for multimodal model. If `None`, this will be inferred
@@ -709,7 +711,33 @@ class ModelConfig:
self._architecture = arch
logger.info("Resolved architecture: %s", arch)
self.pooler_config = self._init_pooler_config()
# Init pooler config if needed
if self.runner_type == "pooling":
if self.override_pooler_config is not None:
logger.warning_once(
"`override_pooler_config` is deprecated and will be "
"removed in v0.12.0 or v1.0.0, whichever is sooner. "
"Please use `pooler_config` instead.")
if isinstance(self.override_pooler_config, dict):
self.pooler_config = PoolerConfig(
**self.override_pooler_config)
else:
self.pooler_config = self.override_pooler_config
if self.pooler_config is None:
self.pooler_config = PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
if base_config is not None:
# Only set values that are not overridden by the user
for k, v in base_config.items():
if getattr(self.pooler_config, k) is None:
setattr(self.pooler_config, k, v)
default_pooling_type = self._model_info.default_pooling_type
if self.pooler_config.pooling_type is None:
self.pooler_config.pooling_type = default_pooling_type
self.dtype: torch.dtype = _get_and_verify_dtype(
self.model,
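In user code, the shim above means both spellings behave identically for now, with a warning logged for the old one. A minimal migration sketch (model choice illustrative):

```python
from vllm import LLM
from vllm.config import PoolerConfig

cfg = PoolerConfig(pooling_type="CLS", normalize=True)

# Deprecated spelling: still accepted, but logs a warning and will be
# removed in v0.12.0 or v1.0.0, whichever is sooner.
llm = LLM(model="BAAI/bge-large-en-v1.5", override_pooler_config=cfg)

# New spelling:
llm = LLM(model="BAAI/bge-large-en-v1.5", pooler_config=cfg)
```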
@@ -869,29 +897,6 @@ class ModelConfig:
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
if self.runner_type == "pooling":
if isinstance(self.override_pooler_config, dict):
self.override_pooler_config = PoolerConfig(
**self.override_pooler_config)
pooler_config = self.override_pooler_config or PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
if base_config is not None:
# Only set values that are not overridden by the user
for k, v in base_config.items():
if getattr(pooler_config, k) is None:
setattr(pooler_config, k, v)
default_pooling_type = self._model_info.default_pooling_type
if pooler_config.pooling_type is None:
pooler_config.pooling_type = default_pooling_type
return pooler_config
return None
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
if tokenizer_mode not in get_args(TokenizerMode):
@@ -1833,94 +1838,6 @@ class DeviceConfig:
self.device = torch.device(self.device_type)
@config
@dataclass
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
pooling_type: Optional[str] = None
"""
The pooling method of the pooling model. This should be a key in
[`vllm.model_executor.layers.pooler.PoolingType`][].
"""
## for embeddings models
normalize: Optional[bool] = None
"""
Whether to normalize the embeddings outputs. Defaults to True.
"""
dimensions: Optional[int] = None
"""
Reduce the dimensions of embeddings if model
support matryoshka representation. Defaults to None.
"""
enable_chunked_processing: Optional[bool] = None
"""
Whether to enable chunked processing for long inputs that exceed the model's
maximum position embeddings. When enabled, long inputs will be split into
chunks, processed separately, and then aggregated using weighted averaging.
This allows embedding models to handle arbitrarily long text without CUDA
errors. Defaults to False.
"""
max_embed_len: Optional[int] = None
"""
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_embed_len to be accepted for embedding models.
When an input exceeds max_embed_len, it will be handled according to
the original max_model_len validation logic.
Defaults to None (i.e. set to max_model_len).
"""
## for classification models
activation: Optional[bool] = None
"""
Whether to apply activation function to the classification outputs.
Defaults to True.
"""
logit_bias: Optional[float] = None
"""
If provided, apply classification logit biases. Defaults to None.
"""
## for reward models
softmax: Optional[bool] = None
"""
Whether to apply softmax to the reward outputs.
Defaults to True.
"""
step_tag_id: Optional[int] = None
"""
If set, only the score corresponding to the ``step_tag_id`` in the
generated sentence should be returned. Otherwise, the scores for all tokens
are returned.
"""
returned_token_ids: Optional[list[int]] = None
"""
A list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of ``good_token`` and ``bad_token`` in the
``math-shepherd-mistral-7b-prm`` model.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,

vllm/config/pooler.py (new file, 97 lines)

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any, Optional
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@config
@dataclass
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
pooling_type: Optional[str] = None
"""
The pooling method of the pooling model. This should be a key in
[`vllm.model_executor.layers.pooler.PoolingType`][].
"""
## for embeddings models
normalize: Optional[bool] = None
"""
Whether to normalize the embeddings outputs. Defaults to True.
"""
dimensions: Optional[int] = None
"""
Reduce the dimensions of embeddings if model
support matryoshka representation. Defaults to None.
"""
enable_chunked_processing: Optional[bool] = None
"""
Whether to enable chunked processing for long inputs that exceed the model's
maximum position embeddings. When enabled, long inputs will be split into
chunks, processed separately, and then aggregated using weighted averaging.
This allows embedding models to handle arbitrarily long text without CUDA
errors. Defaults to False.
"""
max_embed_len: Optional[int] = None
"""
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_embed_len to be accepted for embedding models.
When an input exceeds max_embed_len, it will be handled according to
the original max_model_len validation logic.
Defaults to None (i.e. set to max_model_len).
"""
## for classification models
activation: Optional[bool] = None
"""
Whether to apply activation function to the classification outputs.
Defaults to True.
"""
logit_bias: Optional[float] = None
"""
If provided, apply classification logit biases. Defaults to None.
"""
## for reward models
softmax: Optional[bool] = None
"""
Whether to apply softmax to the reward outputs.
Defaults to True.
"""
step_tag_id: Optional[int] = None
"""
If set, only the score corresponding to the ``step_tag_id`` in the
generated sentence should be returned. Otherwise, the scores for all tokens
are returned.
"""
returned_token_ids: Optional[list[int]] = None
"""
A list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of ``good_token`` and ``bad_token`` in the
``math-shepherd-mistral-7b-prm`` model.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
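With the class now in its own module, a config can be built directly from `vllm.config.pooler` (the re-export from `vllm.config` also works). Unset fields stay `None`, so the merge logic in `ModelConfig.__post_init__` can later fill them from the model's own pooling config or registry defaults; a small sketch:

```python
from vllm.config.pooler import PoolerConfig

# Set only the fields we care about; anything left as None is later merged
# from the model's pooling config (e.g. Sentence Transformers) or defaults.
cfg = PoolerConfig(pooling_type="LAST", activation=False)

assert cfg.normalize is None  # untouched fields remain None
print(cfg.compute_hash())  # pooling config never affects the compiled graph
```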


@@ -441,6 +441,7 @@ class EngineArgs:
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
ModelConfig.override_pooler_config
compilation_config: CompilationConfig = \
@@ -579,8 +580,11 @@
help=model_kwargs["hf_token"]["help"])
model_group.add_argument("--hf-overrides",
**model_kwargs["hf_overrides"])
model_group.add_argument("--pooler-config",
**model_kwargs["pooler_config"])
model_group.add_argument("--override-pooler-config",
**model_kwargs["override_pooler_config"])
**model_kwargs["override_pooler_config"],
deprecated=True)
model_group.add_argument("--logits-processor-pattern",
**model_kwargs["logits_processor_pattern"])
model_group.add_argument("--generation-config",
@@ -1031,6 +1035,7 @@
mm_shm_cache_max_object_size_mb=self.
mm_shm_cache_max_object_size_mb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
pooler_config=self.pooler_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config,
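Programmatically, the new engine argument mirrors the CLI flag; a sketch of the equivalence, assuming the standard `EngineArgs` entry point:

```python
from vllm.config import PoolerConfig
from vllm.engine.arg_utils import EngineArgs

# Roughly equivalent to:
#   vllm serve intfloat/multilingual-e5-large \
#       --pooler-config '{"pooling_type": "MEAN", "normalize": true}'
args = EngineArgs(
    model="intfloat/multilingual-e5-large",
    pooler_config=PoolerConfig(pooling_type="MEAN", normalize=True),
)
```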


@@ -151,9 +151,11 @@ class LLM:
multi-modal processor obtained from `AutoProcessor.from_pretrained`.
The available overrides depend on the model that is being run.
For example, for Phi-3-Vision: `{"num_crops": 4}`.
override_pooler_config: Initialize non-default pooling config or
override default pooling config for the pooling model.
e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
argument is deprecated and will be removed in v0.12.0 or v1.0.0,
whichever is sooner.
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
@@ -191,6 +193,7 @@ class LLM:
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
pooler_config: Optional[PoolerConfig] = None,
override_pooler_config: Optional[PoolerConfig] = None,
structured_outputs_config: Optional[Union[dict[
str, Any], StructuredOutputsConfig]] = None,
@@ -288,6 +291,7 @@
hf_token=hf_token,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config,
override_pooler_config=override_pooler_config,
structured_outputs_config=structured_outputs_instance,
compilation_config=compilation_config_instance,