Move PoolerConfig from config/__init__.py to config/pooler.py (#25181)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 1dfea5f4a9
commit 058525b997
@@ -59,7 +59,7 @@ enabling the corresponding APIs:
 #### Predefined models
 
 If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`,
-you can override some of its attributes via the `--override-pooler-config` option.
+you can override some of its attributes via the `--pooler-config` option.
 
 #### Converted models
 
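For illustration, a minimal offline sketch of the Python equivalent of the renamed option, using the new `pooler_config` argument introduced by this commit (the model name is only an example):

```python
from vllm import LLM
from vllm.config.pooler import PoolerConfig

# Equivalent to `--pooler-config '{"pooling_type": "MEAN", "normalize": false}'`
# on the command line; fields left unset keep the model's defaults.
llm = LLM(
    model="sentence-transformers/all-MiniLM-L12-v2",  # example model
    pooler_config=PoolerConfig(pooling_type="MEAN", normalize=False),
)
```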
@@ -75,7 +75,7 @@ the pooler assigned to each task has the following attributes by default:
 When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
 its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
 
-You can further customize this via the `--override-pooler-config` option,
+You can further customize this via the `--pooler-config` option,
 which takes priority over both the model's and Sentence Transformers's defaults.
 
 ## Offline Inference
@@ -457,7 +457,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
 
 !!! note
     `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
-    You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
+    You need to manually set mean pooling by passing `--pooler-config '{"pooling_type": "MEAN"}'`.
 
 !!! note
     For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded.
@@ -552,7 +552,7 @@ If your model is not in the above list, we will try to automatically convert the
 
 !!! important
     For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
-    e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
+    e.g.: `--pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
 
 #### Token Classification
 
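For intuition, a toy sketch of what the STEP options select, based only on the field docstrings in this commit (shapes and token IDs are illustrative, matching the example above):

```python
import torch

# Toy illustration of step_tag_id / returned_token_ids from the example above.
logits = torch.randn(7, 1000)                    # (num_tokens, vocab_size)
token_ids = torch.tensor([5, 123, 9, 123, 2, 123, 1])

step_scores = logits[token_ids == 123]   # keep positions tagged step_tag_id
prm_scores = step_scores[:, [456, 789]]  # extract only returned_token_ids dims
print(prm_scores.shape)                  # torch.Size([3, 2])
```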
@@ -42,7 +42,7 @@ python client.py
 
 ### Server Configuration
 
-The key parameters for chunked processing are in the `--override-pooler-config`:
+The key parameters for chunked processing are in the `--pooler-config`:
 
 ```json
 {
@@ -13,7 +13,7 @@ Prerequisites:
 
 # MEAN pooling (processes all chunks, recommended for complete coverage)
 vllm serve intfloat/multilingual-e5-large \
-  --override-pooler-config \
+  --pooler-config \
   '{"pooling_type": "MEAN", "normalize": true, ' \
   '"enable_chunked_processing": true, "max_embed_len": 3072000}' \
   --served-model-name multilingual-e5-large \
@@ -23,7 +23,7 @@ Prerequisites:
 
 # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
 vllm serve BAAI/bge-large-en-v1.5 \
-  --override-pooler-config \
+  --pooler-config \
   '{"pooling_type": "CLS", "normalize": true, ' \
   '"enable_chunked_processing": true, "max_embed_len": 1048576}' \
   --served-model-name bge-large-en-v1.5 \
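For completeness, a minimal client sketch for the first server above, assuming the default port 8000 and vLLM's OpenAI-compatible `/v1/embeddings` endpoint:

```python
# Minimal client sketch (assumes `pip install openai` and default port 8000).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.embeddings.create(
    model="multilingual-e5-large",          # the --served-model-name above
    input="A very long document... " * 2000,  # chunked server-side
)
print(len(resp.data[0].embedding))
```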
@@ -103,7 +103,7 @@ POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enab
 vllm serve "$MODEL_NAME" \
   --tensor-parallel-size "$GPU_COUNT" \
   --enforce-eager \
-  --override-pooler-config "$POOLER_CONFIG" \
+  --pooler-config "$POOLER_CONFIG" \
   --served-model-name ${MODEL_CODE} \
   --api-key "$API_KEY" \
   --trust-remote-code \
@@ -216,7 +216,7 @@ def server_with_chunked_processing():
         "--enforce-eager",
         "--max-model-len",
         "512",  # Set smaller max_model_len to trigger chunking mechanism
-        '--override-pooler-config',
+        '--pooler-config',
         ('{"pooling_type": "MEAN", "normalize": true, '
          '"enable_chunked_processing": true, "max_embed_len": 10000}'),
         "--gpu-memory-utilization",
@@ -58,7 +58,7 @@ def test_models(
 
     vllm_extra_kwargs = {}
     if model == "ssmits/Qwen2-7B-Instruct-embed-base":
-        vllm_extra_kwargs["override_pooler_config"] = \
+        vllm_extra_kwargs["pooler_config"] = \
             PoolerConfig(pooling_type="MEAN", normalize=False)
 
     max_model_len: Optional[int] = 512
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from vllm.config.pooler import PoolerConfig
 from vllm.platforms import current_platform
 
 
@@ -99,7 +100,7 @@ def test_gemma_multimodal(
         convert="classify",
         load_format="auto",
         hf_overrides=update_config,
-        override_pooler_config={"pooling_type": "LAST"},
+        pooler_config=PoolerConfig(pooling_type="LAST"),
         max_model_len=512,
         enforce_eager=True,
         tensor_parallel_size=1,
@@ -24,18 +24,18 @@ def test_classify_models_using_activation(
     dtype: str,
 ) -> None:
 
-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     override_pooler_config=PoolerConfig(
-                         activation=False)) as vllm_model:
+    with vllm_runner(
+            model,
+            max_model_len=512,
+            dtype=dtype,
+            pooler_config=PoolerConfig(activation=False)) as vllm_model:
         wo_activation_out = vllm_model.classify(example_prompts)
 
-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     override_pooler_config=PoolerConfig(
-                         activation=True)) as vllm_model:
+    with vllm_runner(
+            model,
+            max_model_len=512,
+            dtype=dtype,
+            pooler_config=PoolerConfig(activation=True)) as vllm_model:
         w_activation_out = vllm_model.classify(example_prompts)
 
     for wo_activation, w_activation in zip(wo_activation_out,
@@ -43,9 +43,8 @@ def test_classify_models_using_activation(
         wo_activation = torch.tensor(wo_activation)
         w_activation = torch.tensor(w_activation)
 
-        assert not torch.allclose(
-            wo_activation, w_activation,
-            atol=1e-2), "override_pooler_config is not working"
+        assert not torch.allclose(wo_activation, w_activation,
+                                  atol=1e-2), "pooler_config is not working"
         assert torch.allclose(softmax(wo_activation), w_activation,
                               1e-3 if dtype == "float" else 1e-2)
 
@@ -65,23 +64,22 @@ def test_embed_models_using_normalize(
     dtype: str,
 ) -> None:
 
-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     override_pooler_config=PoolerConfig(
-                         normalize=False)) as vllm_model:
-        wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
-
     with vllm_runner(
             model,
             max_model_len=512,
             dtype=dtype,
-            override_pooler_config=PoolerConfig(normalize=True)) as vllm_model:
+            pooler_config=PoolerConfig(normalize=False)) as vllm_model:
+        wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
+
+    with vllm_runner(model,
+                     max_model_len=512,
+                     dtype=dtype,
+                     pooler_config=PoolerConfig(normalize=True)) as vllm_model:
         w_normalize = torch.tensor(vllm_model.embed(example_prompts))
 
     assert not torch.allclose(
         wo_normalize, w_normalize,
-        atol=1e-2), "override_pooler_config normalize is not working"
+        atol=1e-2), "pooler_config normalize is not working"
     assert torch.allclose(
         F.normalize(wo_normalize, p=2, dim=-1), w_normalize,
         atol=1e-2), "w_normal should be close to normal(wo_normal)."
@@ -102,18 +100,16 @@ def test_reward_models_using_softmax(
     dtype: str,
 ) -> None:
 
-    with vllm_runner(
-            model,
-            max_model_len=1024,
-            dtype=dtype,
-            override_pooler_config=PoolerConfig(softmax=False)) as vllm_model:
+    with vllm_runner(model,
+                     max_model_len=1024,
+                     dtype=dtype,
+                     pooler_config=PoolerConfig(softmax=False)) as vllm_model:
         wo_softmax = vllm_model.encode(example_prompts)
 
-    with vllm_runner(
-            model,
-            max_model_len=1024,
-            dtype=dtype,
-            override_pooler_config=PoolerConfig(softmax=True)) as vllm_model:
+    with vllm_runner(model,
+                     max_model_len=1024,
+                     dtype=dtype,
+                     pooler_config=PoolerConfig(softmax=True)) as vllm_model:
         w_softmax = vllm_model.encode(example_prompts)
 
     for wo, w in zip(wo_softmax, w_softmax):
@@ -121,7 +117,7 @@ def test_reward_models_using_softmax(
         w = torch.tensor(w)
 
         assert not torch.allclose(
-            wo, w, atol=1e-2), "override_pooler_config softmax is not working"
+            wo, w, atol=1e-2), "pooler_config softmax is not working"
         assert torch.allclose(
             softmax(wo), w,
             atol=1e-2), "w_softmax should be close to softmax(wo_softmax)."
@@ -207,25 +207,19 @@ def test_get_pooling_config():
     model_id = "sentence-transformers/all-MiniLM-L12-v2"
     model_config = ModelConfig(model_id)
 
-    pooling_config = model_config._init_pooler_config()
-    assert pooling_config is not None
-
-    assert pooling_config.normalize
-    assert pooling_config.pooling_type == PoolingType.MEAN.name
+    assert model_config.pooler_config is not None
+    assert model_config.pooler_config.normalize
+    assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
 
 
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_pooling_config_from_args():
     model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    model_config = ModelConfig(model_id)
-
-    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
-    model_config.override_pooler_config = override_pooler_config
-
-    pooling_config = model_config._init_pooler_config()
-    assert pooling_config is not None
-    assert asdict(pooling_config) == asdict(override_pooler_config)
+    pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
+    model_config = ModelConfig(model_id, pooler_config=pooler_config)
+
+    assert asdict(model_config.pooler_config) == asdict(pooler_config)
 
 
 @pytest.mark.parametrize(
@@ -40,6 +40,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                     MultiModalConfig)
 from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
                                   ParallelConfig)
+from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
 from vllm.config.speculative import SpeculativeConfig
 from vllm.config.structured_outputs import StructuredOutputsConfig
@@ -406,13 +407,6 @@ class ModelConfig:
     hf_overrides: HfOverrides = field(default_factory=dict)
     """If a dictionary, contains arguments to be forwarded to the Hugging Face
     config. If a callable, it is called to update the HuggingFace config."""
-    pooler_config: Optional["PoolerConfig"] = field(init=False)
-    """Pooler config which controls the behaviour of output pooling in pooling
-    models."""
-    override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
-    """Initialize non-default pooling config or override default pooling config
-    for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
-    """
     logits_processor_pattern: Optional[str] = None
     """Optional regex pattern specifying valid logits processor qualified names
     that can be passed with the `logits_processors` extra completion argument.
@@ -448,6 +442,14 @@ class ModelConfig:
     io_processor_plugin: Optional[str] = None
     """IOProcessor plugin name to load at model startup"""
 
+    # Pooler config
+    pooler_config: Optional[PoolerConfig] = None
+    """Pooler config which controls the behaviour of output pooling in pooling
+    models."""
+    override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
+    """[DEPRECATED] Use `pooler_config` instead. This field will be removed in
+    v0.12.0 or v1.0.0, whichever is sooner."""
+
 # Multimodal config and init vars
 multimodal_config: Optional[MultiModalConfig] = None
 """Configuration for multimodal model. If `None`, this will be inferred
@@ -709,7 +711,33 @@ class ModelConfig:
         self._architecture = arch
         logger.info("Resolved architecture: %s", arch)
 
-        self.pooler_config = self._init_pooler_config()
+        # Init pooler config if needed
+        if self.runner_type == "pooling":
+            if self.override_pooler_config is not None:
+                logger.warning_once(
+                    "`override_pooler_config` is deprecated and will be "
+                    "removed in v0.12.0 or v1.0.0, whichever is sooner. "
+                    "Please use `pooler_config` instead.")
+
+                if isinstance(self.override_pooler_config, dict):
+                    self.pooler_config = PoolerConfig(
+                        **self.override_pooler_config)
+                else:
+                    self.pooler_config = self.override_pooler_config
+
+            if self.pooler_config is None:
+                self.pooler_config = PoolerConfig()
+
+            base_config = get_pooling_config(self.model, self.revision)
+            if base_config is not None:
+                # Only set values that are not overridden by the user
+                for k, v in base_config.items():
+                    if getattr(self.pooler_config, k) is None:
+                        setattr(self.pooler_config, k, v)
+
+            default_pooling_type = self._model_info.default_pooling_type
+            if self.pooler_config.pooling_type is None:
+                self.pooler_config.pooling_type = default_pooling_type
 
         self.dtype: torch.dtype = _get_and_verify_dtype(
             self.model,
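As a plain-Python restatement of the precedence implemented above: an explicitly passed `pooler_config` (or the deprecated `override_pooler_config`) wins, then the model's Sentence Transformers config, then the architecture's default pooling type. The helper below is hypothetical, not part of this commit:

```python
from typing import Optional

# Hypothetical, simplified restatement of the resolution order above.
def resolve_pooling_type(user_value: Optional[str],
                         modules_json_value: Optional[str],
                         architecture_default: str) -> str:
    if user_value is not None:          # --pooler-config / deprecated override
        return user_value
    if modules_json_value is not None:  # from get_pooling_config(model, rev)
        return modules_json_value
    return architecture_default         # _model_info.default_pooling_type

assert resolve_pooling_type("CLS", "MEAN", "LAST") == "CLS"
assert resolve_pooling_type(None, "MEAN", "LAST") == "MEAN"
assert resolve_pooling_type(None, None, "LAST") == "LAST"
```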
@@ -869,29 +897,6 @@ class ModelConfig:
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)
 
-    def _init_pooler_config(self) -> Optional["PoolerConfig"]:
-        if self.runner_type == "pooling":
-            if isinstance(self.override_pooler_config, dict):
-                self.override_pooler_config = PoolerConfig(
-                    **self.override_pooler_config)
-
-            pooler_config = self.override_pooler_config or PoolerConfig()
-
-            base_config = get_pooling_config(self.model, self.revision)
-            if base_config is not None:
-                # Only set values that are not overridden by the user
-                for k, v in base_config.items():
-                    if getattr(pooler_config, k) is None:
-                        setattr(pooler_config, k, v)
-
-            default_pooling_type = self._model_info.default_pooling_type
-            if pooler_config.pooling_type is None:
-                pooler_config.pooling_type = default_pooling_type
-
-            return pooler_config
-
-        return None
-
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
         if tokenizer_mode not in get_args(TokenizerMode):
@@ -1833,94 +1838,6 @@ class DeviceConfig:
         self.device = torch.device(self.device_type)
 
 
-@config
-@dataclass
-class PoolerConfig:
-    """Controls the behavior of output pooling in pooling models."""
-
-    pooling_type: Optional[str] = None
-    """
-    The pooling method of the pooling model. This should be a key in
-    [`vllm.model_executor.layers.pooler.PoolingType`][].
-    """
-
-    ## for embeddings models
-    normalize: Optional[bool] = None
-    """
-    Whether to normalize the embeddings outputs. Defaults to True.
-    """
-    dimensions: Optional[int] = None
-    """
-    Reduce the dimensions of embeddings if model
-    support matryoshka representation. Defaults to None.
-    """
-    enable_chunked_processing: Optional[bool] = None
-    """
-    Whether to enable chunked processing for long inputs that exceed the model's
-    maximum position embeddings. When enabled, long inputs will be split into
-    chunks, processed separately, and then aggregated using weighted averaging.
-    This allows embedding models to handle arbitrarily long text without CUDA
-    errors. Defaults to False.
-    """
-    max_embed_len: Optional[int] = None
-    """
-    Maximum input length allowed for embedding generation. When set, allows
-    inputs longer than max_embed_len to be accepted for embedding models.
-    When an input exceeds max_embed_len, it will be handled according to
-    the original max_model_len validation logic.
-    Defaults to None (i.e. set to max_model_len).
-    """
-
-    ## for classification models
-    activation: Optional[bool] = None
-    """
-    Whether to apply activation function to the classification outputs.
-    Defaults to True.
-    """
-    logit_bias: Optional[float] = None
-    """
-    If provided, apply classification logit biases. Defaults to None.
-    """
-
-    ## for reward models
-    softmax: Optional[bool] = None
-    """
-    Whether to apply softmax to the reward outputs.
-    Defaults to True.
-    """
-    step_tag_id: Optional[int] = None
-    """
-    If set, only the score corresponding to the ``step_tag_id`` in the
-    generated sentence should be returned. Otherwise, the scores for all tokens
-    are returned.
-    """
-    returned_token_ids: Optional[list[int]] = None
-    """
-    A list of indices for the vocabulary dimensions to be extracted,
-    such as the token IDs of ``good_token`` and ``bad_token`` in the
-    ``math-shepherd-mistral-7b-prm`` model.
-    """
-
-    def compute_hash(self) -> str:
-        """
-        WARNING: Whenever a new field is added to this config,
-        ensure that it is included in the factors list if
-        it affects the computation graph.
-
-        Provide a hash that uniquely identifies all the configs
-        that affect the structure of the computation
-        graph from input ids/embeddings to the final hidden states,
-        excluding anything before input ids/embeddings and after
-        the final hidden states.
-        """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
-        return hash_str
-
-
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
vllm/config/pooler.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from typing import Any, Optional
+
+from pydantic.dataclasses import dataclass
+
+from vllm.config.utils import config
+
+
+@config
+@dataclass
+class PoolerConfig:
+    """Controls the behavior of output pooling in pooling models."""
+
+    pooling_type: Optional[str] = None
+    """
+    The pooling method of the pooling model. This should be a key in
+    [`vllm.model_executor.layers.pooler.PoolingType`][].
+    """
+
+    ## for embeddings models
+    normalize: Optional[bool] = None
+    """
+    Whether to normalize the embeddings outputs. Defaults to True.
+    """
+    dimensions: Optional[int] = None
+    """
+    Reduce the dimensions of embeddings if model
+    support matryoshka representation. Defaults to None.
+    """
+    enable_chunked_processing: Optional[bool] = None
+    """
+    Whether to enable chunked processing for long inputs that exceed the model's
+    maximum position embeddings. When enabled, long inputs will be split into
+    chunks, processed separately, and then aggregated using weighted averaging.
+    This allows embedding models to handle arbitrarily long text without CUDA
+    errors. Defaults to False.
+    """
+    max_embed_len: Optional[int] = None
+    """
+    Maximum input length allowed for embedding generation. When set, allows
+    inputs longer than max_embed_len to be accepted for embedding models.
+    When an input exceeds max_embed_len, it will be handled according to
+    the original max_model_len validation logic.
+    Defaults to None (i.e. set to max_model_len).
+    """
+
+    ## for classification models
+    activation: Optional[bool] = None
+    """
+    Whether to apply activation function to the classification outputs.
+    Defaults to True.
+    """
+    logit_bias: Optional[float] = None
+    """
+    If provided, apply classification logit biases. Defaults to None.
+    """
+
+    ## for reward models
+    softmax: Optional[bool] = None
+    """
+    Whether to apply softmax to the reward outputs.
+    Defaults to True.
+    """
+    step_tag_id: Optional[int] = None
+    """
+    If set, only the score corresponding to the ``step_tag_id`` in the
+    generated sentence should be returned. Otherwise, the scores for all tokens
+    are returned.
+    """
+    returned_token_ids: Optional[list[int]] = None
+    """
+    A list of indices for the vocabulary dimensions to be extracted,
+    such as the token IDs of ``good_token`` and ``bad_token`` in the
+    ``math-shepherd-mistral-7b-prm`` model.
+    """
+
+    def compute_hash(self) -> str:
+        """
+        WARNING: Whenever a new field is added to this config,
+        ensure that it is included in the factors list if
+        it affects the computation graph.
+
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        # no factors to consider.
+        # this config will not affect the computation graph.
+        factors: list[Any] = []
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()
+        return hash_str
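A quick usage sketch of the relocated class (import path per this commit); fields left unset stay `None` so the later defaulting layers in `ModelConfig` can fill them in:

```python
from vllm.config.pooler import PoolerConfig

cfg = PoolerConfig(pooling_type="CLS")
print(cfg.normalize)       # None: may still be filled from the model's config
print(cfg.compute_hash())  # stable; this config never affects the graph hash
```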
@@ -441,6 +441,7 @@ class EngineArgs:
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
     scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
 
+    pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
     override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
         ModelConfig.override_pooler_config
     compilation_config: CompilationConfig = \
@@ -579,8 +580,11 @@ class EngineArgs:
                                  help=model_kwargs["hf_token"]["help"])
         model_group.add_argument("--hf-overrides",
                                  **model_kwargs["hf_overrides"])
+        model_group.add_argument("--pooler-config",
+                                 **model_kwargs["pooler_config"])
         model_group.add_argument("--override-pooler-config",
-                                 **model_kwargs["override_pooler_config"])
+                                 **model_kwargs["override_pooler_config"],
+                                 deprecated=True)
         model_group.add_argument("--logits-processor-pattern",
                                  **model_kwargs["logits_processor_pattern"])
         model_group.add_argument("--generation-config",
@@ -1031,6 +1035,7 @@ class EngineArgs:
             mm_shm_cache_max_object_size_mb=self.
             mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=self.mm_encoder_tp_mode,
+            pooler_config=self.pooler_config,
             override_pooler_config=self.override_pooler_config,
             logits_processor_pattern=self.logits_processor_pattern,
             generation_config=self.generation_config,
@@ -151,9 +151,11 @@ class LLM:
             multi-modal processor obtained from `AutoProcessor.from_pretrained`.
             The available overrides depend on the model that is being run.
             For example, for Phi-3-Vision: `{"num_crops": 4}`.
-        override_pooler_config: Initialize non-default pooling config or
-            override default pooling config for the pooling model.
-            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
+        pooler_config: Initialize non-default pooling config for the pooling
+            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
+        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
+            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
+            whichever is sooner.
         compilation_config: Either an integer or a dictionary. If it is an
             integer, it is used as the level of compilation optimization. If it
             is a dictionary, it can specify the full compilation configuration.
@@ -191,6 +193,7 @@ class LLM:
         hf_token: Optional[Union[bool, str]] = None,
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
+        pooler_config: Optional[PoolerConfig] = None,
         override_pooler_config: Optional[PoolerConfig] = None,
         structured_outputs_config: Optional[Union[dict[
             str, Any], StructuredOutputsConfig]] = None,
@@ -288,6 +291,7 @@ class LLM:
             hf_token=hf_token,
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
+            pooler_config=pooler_config,
             override_pooler_config=override_pooler_config,
             structured_outputs_config=structured_outputs_instance,
             compilation_config=compilation_config_instance,
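Taken together, a minimal migration sketch for `LLM` callers (model name is only an example; the deprecated keyword keeps working, with a warning, until v0.12.0 or v1.0.0, whichever is sooner):

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Before (deprecated, emits a warning):
# llm = LLM(model="BAAI/bge-large-en-v1.5",
#           override_pooler_config=PoolerConfig(pooling_type="CLS"))

# After:
llm = LLM(model="BAAI/bge-large-en-v1.5",
          pooler_config=PoolerConfig(pooling_type="CLS"))
```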