mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 06:34:58 +08:00
Move PoolerConfig from config/__init__.py to config/pooler.py (#25181)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
1dfea5f4a9
commit
058525b997
@ -59,7 +59,7 @@ enabling the corresponding APIs:
|
||||
#### Predefined models
|
||||
|
||||
If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`,
|
||||
you can override some of its attributes via the `--override-pooler-config` option.
|
||||
you can override some of its attributes via the `--pooler-config` option.
|
||||
|
||||
#### Converted models
|
||||
|
||||
@ -75,7 +75,7 @@ the pooler assigned to each task has the following attributes by default:
|
||||
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
|
||||
its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
|
||||
|
||||
You can further customize this via the `--override-pooler-config` option,
|
||||
You can further customize this via the `--pooler-config` option,
|
||||
which takes priority over both the model's and Sentence Transformers's defaults.
|
||||
|
||||
## Offline Inference
|
||||
|
||||
@ -457,7 +457,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
|
||||
!!! note
|
||||
`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
|
||||
You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
|
||||
You need to manually set mean pooling by passing `--pooler-config '{"pooling_type": "MEAN"}'`.
|
||||
|
||||
!!! note
|
||||
For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded.
|
||||
@ -552,7 +552,7 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
|
||||
!!! important
|
||||
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
|
||||
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
|
||||
e.g.: `--pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
|
||||
|
||||
#### Token Classification
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ python client.py
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The key parameters for chunked processing are in the `--override-pooler-config`:
|
||||
The key parameters for chunked processing are in the `--pooler-config`:
|
||||
|
||||
```json
|
||||
{
|
||||
|
||||
@ -13,7 +13,7 @@ Prerequisites:
|
||||
|
||||
# MEAN pooling (processes all chunks, recommended for complete coverage)
|
||||
vllm serve intfloat/multilingual-e5-large \
|
||||
--override-pooler-config \
|
||||
--pooler-config \
|
||||
'{"pooling_type": "MEAN", "normalize": true, ' \
|
||||
'"enable_chunked_processing": true, "max_embed_len": 3072000}' \
|
||||
--served-model-name multilingual-e5-large \
|
||||
@ -23,7 +23,7 @@ Prerequisites:
|
||||
|
||||
# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
|
||||
vllm serve BAAI/bge-large-en-v1.5 \
|
||||
--override-pooler-config \
|
||||
--pooler-config \
|
||||
'{"pooling_type": "CLS", "normalize": true, ' \
|
||||
'"enable_chunked_processing": true, "max_embed_len": 1048576}' \
|
||||
--served-model-name bge-large-en-v1.5 \
|
||||
|
||||
@ -103,7 +103,7 @@ POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enab
|
||||
vllm serve "$MODEL_NAME" \
|
||||
--tensor-parallel-size "$GPU_COUNT" \
|
||||
--enforce-eager \
|
||||
--override-pooler-config "$POOLER_CONFIG" \
|
||||
--pooler-config "$POOLER_CONFIG" \
|
||||
--served-model-name ${MODEL_CODE} \
|
||||
--api-key "$API_KEY" \
|
||||
--trust-remote-code \
|
||||
|
||||
@ -216,7 +216,7 @@ def server_with_chunked_processing():
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"512", # Set smaller max_model_len to trigger chunking mechanism
|
||||
'--override-pooler-config',
|
||||
'--pooler-config',
|
||||
('{"pooling_type": "MEAN", "normalize": true, '
|
||||
'"enable_chunked_processing": true, "max_embed_len": 10000}'),
|
||||
"--gpu-memory-utilization",
|
||||
|
||||
@ -58,7 +58,7 @@ def test_models(
|
||||
|
||||
vllm_extra_kwargs = {}
|
||||
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
|
||||
vllm_extra_kwargs["override_pooler_config"] = \
|
||||
vllm_extra_kwargs["pooler_config"] = \
|
||||
PoolerConfig(pooling_type="MEAN", normalize=False)
|
||||
|
||||
max_model_len: Optional[int] = 512
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -99,7 +100,7 @@ def test_gemma_multimodal(
|
||||
convert="classify",
|
||||
load_format="auto",
|
||||
hf_overrides=update_config,
|
||||
override_pooler_config={"pooling_type": "LAST"},
|
||||
pooler_config=PoolerConfig(pooling_type="LAST"),
|
||||
max_model_len=512,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1,
|
||||
|
||||
@ -24,18 +24,18 @@ def test_classify_models_using_activation(
|
||||
dtype: str,
|
||||
) -> None:
|
||||
|
||||
with vllm_runner(model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
override_pooler_config=PoolerConfig(
|
||||
activation=False)) as vllm_model:
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
pooler_config=PoolerConfig(activation=False)) as vllm_model:
|
||||
wo_activation_out = vllm_model.classify(example_prompts)
|
||||
|
||||
with vllm_runner(model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
override_pooler_config=PoolerConfig(
|
||||
activation=True)) as vllm_model:
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
pooler_config=PoolerConfig(activation=True)) as vllm_model:
|
||||
w_activation_out = vllm_model.classify(example_prompts)
|
||||
|
||||
for wo_activation, w_activation in zip(wo_activation_out,
|
||||
@ -43,9 +43,8 @@ def test_classify_models_using_activation(
|
||||
wo_activation = torch.tensor(wo_activation)
|
||||
w_activation = torch.tensor(w_activation)
|
||||
|
||||
assert not torch.allclose(
|
||||
wo_activation, w_activation,
|
||||
atol=1e-2), "override_pooler_config is not working"
|
||||
assert not torch.allclose(wo_activation, w_activation,
|
||||
atol=1e-2), "pooler_config is not working"
|
||||
assert torch.allclose(softmax(wo_activation), w_activation,
|
||||
1e-3 if dtype == "float" else 1e-2)
|
||||
|
||||
@ -65,23 +64,22 @@ def test_embed_models_using_normalize(
|
||||
dtype: str,
|
||||
) -> None:
|
||||
|
||||
with vllm_runner(model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
override_pooler_config=PoolerConfig(
|
||||
normalize=False)) as vllm_model:
|
||||
wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
override_pooler_config=PoolerConfig(normalize=True)) as vllm_model:
|
||||
pooler_config=PoolerConfig(normalize=False)) as vllm_model:
|
||||
wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
|
||||
|
||||
with vllm_runner(model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
pooler_config=PoolerConfig(normalize=True)) as vllm_model:
|
||||
w_normalize = torch.tensor(vllm_model.embed(example_prompts))
|
||||
|
||||
assert not torch.allclose(
|
||||
wo_normalize, w_normalize,
|
||||
atol=1e-2), "override_pooler_config normalize is not working"
|
||||
atol=1e-2), "pooler_config normalize is not working"
|
||||
assert torch.allclose(
|
||||
F.normalize(wo_normalize, p=2, dim=-1), w_normalize,
|
||||
atol=1e-2), "w_normal should be close to normal(wo_normal)."
|
||||
@ -102,18 +100,16 @@ def test_reward_models_using_softmax(
|
||||
dtype: str,
|
||||
) -> None:
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
dtype=dtype,
|
||||
override_pooler_config=PoolerConfig(softmax=False)) as vllm_model:
|
||||
with vllm_runner(model,
|
||||
max_model_len=1024,
|
||||
dtype=dtype,
|
||||
pooler_config=PoolerConfig(softmax=False)) as vllm_model:
|
||||
wo_softmax = vllm_model.encode(example_prompts)
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
dtype=dtype,
|
||||
override_pooler_config=PoolerConfig(softmax=True)) as vllm_model:
|
||||
with vllm_runner(model,
|
||||
max_model_len=1024,
|
||||
dtype=dtype,
|
||||
pooler_config=PoolerConfig(softmax=True)) as vllm_model:
|
||||
w_softmax = vllm_model.encode(example_prompts)
|
||||
|
||||
for wo, w in zip(wo_softmax, w_softmax):
|
||||
@ -121,7 +117,7 @@ def test_reward_models_using_softmax(
|
||||
w = torch.tensor(w)
|
||||
|
||||
assert not torch.allclose(
|
||||
wo, w, atol=1e-2), "override_pooler_config softmax is not working"
|
||||
wo, w, atol=1e-2), "pooler_config softmax is not working"
|
||||
assert torch.allclose(
|
||||
softmax(wo), w,
|
||||
atol=1e-2), "w_softmax should be close to softmax(wo_softmax)."
|
||||
@ -207,25 +207,19 @@ def test_get_pooling_config():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
model_config = ModelConfig(model_id)
|
||||
|
||||
pooling_config = model_config._init_pooler_config()
|
||||
assert pooling_config is not None
|
||||
|
||||
assert pooling_config.normalize
|
||||
assert pooling_config.pooling_type == PoolingType.MEAN.name
|
||||
assert model_config.pooler_config is not None
|
||||
assert model_config.pooler_config.normalize
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config_from_args():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
model_config = ModelConfig(model_id)
|
||||
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
|
||||
model_config = ModelConfig(model_id, pooler_config=pooler_config)
|
||||
|
||||
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
|
||||
model_config.override_pooler_config = override_pooler_config
|
||||
|
||||
pooling_config = model_config._init_pooler_config()
|
||||
assert pooling_config is not None
|
||||
assert asdict(pooling_config) == asdict(override_pooler_config)
|
||||
assert asdict(model_config.pooler_config) == asdict(pooler_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -40,6 +40,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||
MultiModalConfig)
|
||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||
@ -406,13 +407,6 @@ class ModelConfig:
|
||||
hf_overrides: HfOverrides = field(default_factory=dict)
|
||||
"""If a dictionary, contains arguments to be forwarded to the Hugging Face
|
||||
config. If a callable, it is called to update the HuggingFace config."""
|
||||
pooler_config: Optional["PoolerConfig"] = field(init=False)
|
||||
"""Pooler config which controls the behaviour of output pooling in pooling
|
||||
models."""
|
||||
override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
|
||||
"""Initialize non-default pooling config or override default pooling config
|
||||
for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
|
||||
"""
|
||||
logits_processor_pattern: Optional[str] = None
|
||||
"""Optional regex pattern specifying valid logits processor qualified names
|
||||
that can be passed with the `logits_processors` extra completion argument.
|
||||
@ -448,6 +442,14 @@ class ModelConfig:
|
||||
io_processor_plugin: Optional[str] = None
|
||||
"""IOProcessor plugin name to load at model startup"""
|
||||
|
||||
# Pooler config
|
||||
pooler_config: Optional[PoolerConfig] = None
|
||||
"""Pooler config which controls the behaviour of output pooling in pooling
|
||||
models."""
|
||||
override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
|
||||
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
|
||||
v0.12.0 or v1.0.0, whichever is sooner."""
|
||||
|
||||
# Multimodal config and init vars
|
||||
multimodal_config: Optional[MultiModalConfig] = None
|
||||
"""Configuration for multimodal model. If `None`, this will be inferred
|
||||
@ -709,7 +711,33 @@ class ModelConfig:
|
||||
self._architecture = arch
|
||||
logger.info("Resolved architecture: %s", arch)
|
||||
|
||||
self.pooler_config = self._init_pooler_config()
|
||||
# Init pooler config if needed
|
||||
if self.runner_type == "pooling":
|
||||
if self.override_pooler_config is not None:
|
||||
logger.warning_once(
|
||||
"`override_pooler_config` is deprecated and will be "
|
||||
"removed in v0.12.0 or v1.0.0, whichever is sooner. "
|
||||
"Please use `pooler_config` instead.")
|
||||
|
||||
if isinstance(self.override_pooler_config, dict):
|
||||
self.pooler_config = PoolerConfig(
|
||||
**self.override_pooler_config)
|
||||
else:
|
||||
self.pooler_config = self.override_pooler_config
|
||||
|
||||
if self.pooler_config is None:
|
||||
self.pooler_config = PoolerConfig()
|
||||
|
||||
base_config = get_pooling_config(self.model, self.revision)
|
||||
if base_config is not None:
|
||||
# Only set values that are not overridden by the user
|
||||
for k, v in base_config.items():
|
||||
if getattr(self.pooler_config, k) is None:
|
||||
setattr(self.pooler_config, k, v)
|
||||
|
||||
default_pooling_type = self._model_info.default_pooling_type
|
||||
if self.pooler_config.pooling_type is None:
|
||||
self.pooler_config.pooling_type = default_pooling_type
|
||||
|
||||
self.dtype: torch.dtype = _get_and_verify_dtype(
|
||||
self.model,
|
||||
@ -869,29 +897,6 @@ class ModelConfig:
|
||||
return get_sentence_transformer_tokenizer_config(
|
||||
self.model, self.revision)
|
||||
|
||||
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
|
||||
if self.runner_type == "pooling":
|
||||
if isinstance(self.override_pooler_config, dict):
|
||||
self.override_pooler_config = PoolerConfig(
|
||||
**self.override_pooler_config)
|
||||
|
||||
pooler_config = self.override_pooler_config or PoolerConfig()
|
||||
|
||||
base_config = get_pooling_config(self.model, self.revision)
|
||||
if base_config is not None:
|
||||
# Only set values that are not overridden by the user
|
||||
for k, v in base_config.items():
|
||||
if getattr(pooler_config, k) is None:
|
||||
setattr(pooler_config, k, v)
|
||||
|
||||
default_pooling_type = self._model_info.default_pooling_type
|
||||
if pooler_config.pooling_type is None:
|
||||
pooler_config.pooling_type = default_pooling_type
|
||||
|
||||
return pooler_config
|
||||
|
||||
return None
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
|
||||
if tokenizer_mode not in get_args(TokenizerMode):
|
||||
@ -1833,94 +1838,6 @@ class DeviceConfig:
|
||||
self.device = torch.device(self.device_type)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of output pooling in pooling models."""
|
||||
|
||||
pooling_type: Optional[str] = None
|
||||
"""
|
||||
The pooling method of the pooling model. This should be a key in
|
||||
[`vllm.model_executor.layers.pooler.PoolingType`][].
|
||||
"""
|
||||
|
||||
## for embeddings models
|
||||
normalize: Optional[bool] = None
|
||||
"""
|
||||
Whether to normalize the embeddings outputs. Defaults to True.
|
||||
"""
|
||||
dimensions: Optional[int] = None
|
||||
"""
|
||||
Reduce the dimensions of embeddings if model
|
||||
support matryoshka representation. Defaults to None.
|
||||
"""
|
||||
enable_chunked_processing: Optional[bool] = None
|
||||
"""
|
||||
Whether to enable chunked processing for long inputs that exceed the model's
|
||||
maximum position embeddings. When enabled, long inputs will be split into
|
||||
chunks, processed separately, and then aggregated using weighted averaging.
|
||||
This allows embedding models to handle arbitrarily long text without CUDA
|
||||
errors. Defaults to False.
|
||||
"""
|
||||
max_embed_len: Optional[int] = None
|
||||
"""
|
||||
Maximum input length allowed for embedding generation. When set, allows
|
||||
inputs longer than max_embed_len to be accepted for embedding models.
|
||||
When an input exceeds max_embed_len, it will be handled according to
|
||||
the original max_model_len validation logic.
|
||||
Defaults to None (i.e. set to max_model_len).
|
||||
"""
|
||||
|
||||
## for classification models
|
||||
activation: Optional[bool] = None
|
||||
"""
|
||||
Whether to apply activation function to the classification outputs.
|
||||
Defaults to True.
|
||||
"""
|
||||
logit_bias: Optional[float] = None
|
||||
"""
|
||||
If provided, apply classification logit biases. Defaults to None.
|
||||
"""
|
||||
|
||||
## for reward models
|
||||
softmax: Optional[bool] = None
|
||||
"""
|
||||
Whether to apply softmax to the reward outputs.
|
||||
Defaults to True.
|
||||
"""
|
||||
step_tag_id: Optional[int] = None
|
||||
"""
|
||||
If set, only the score corresponding to the ``step_tag_id`` in the
|
||||
generated sentence should be returned. Otherwise, the scores for all tokens
|
||||
are returned.
|
||||
"""
|
||||
returned_token_ids: Optional[list[int]] = None
|
||||
"""
|
||||
A list of indices for the vocabulary dimensions to be extracted,
|
||||
such as the token IDs of ``good_token`` and ``bad_token`` in the
|
||||
``math-shepherd-mistral-7b-prm`` model.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.float16,
|
||||
"float16": torch.float16,
|
||||
|
||||
97
vllm/config/pooler.py
Normal file
97
vllm/config/pooler.py
Normal file
@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of output pooling in pooling models."""
|
||||
|
||||
pooling_type: Optional[str] = None
|
||||
"""
|
||||
The pooling method of the pooling model. This should be a key in
|
||||
[`vllm.model_executor.layers.pooler.PoolingType`][].
|
||||
"""
|
||||
|
||||
## for embeddings models
|
||||
normalize: Optional[bool] = None
|
||||
"""
|
||||
Whether to normalize the embeddings outputs. Defaults to True.
|
||||
"""
|
||||
dimensions: Optional[int] = None
|
||||
"""
|
||||
Reduce the dimensions of embeddings if model
|
||||
support matryoshka representation. Defaults to None.
|
||||
"""
|
||||
enable_chunked_processing: Optional[bool] = None
|
||||
"""
|
||||
Whether to enable chunked processing for long inputs that exceed the model's
|
||||
maximum position embeddings. When enabled, long inputs will be split into
|
||||
chunks, processed separately, and then aggregated using weighted averaging.
|
||||
This allows embedding models to handle arbitrarily long text without CUDA
|
||||
errors. Defaults to False.
|
||||
"""
|
||||
max_embed_len: Optional[int] = None
|
||||
"""
|
||||
Maximum input length allowed for embedding generation. When set, allows
|
||||
inputs longer than max_embed_len to be accepted for embedding models.
|
||||
When an input exceeds max_embed_len, it will be handled according to
|
||||
the original max_model_len validation logic.
|
||||
Defaults to None (i.e. set to max_model_len).
|
||||
"""
|
||||
|
||||
## for classification models
|
||||
activation: Optional[bool] = None
|
||||
"""
|
||||
Whether to apply activation function to the classification outputs.
|
||||
Defaults to True.
|
||||
"""
|
||||
logit_bias: Optional[float] = None
|
||||
"""
|
||||
If provided, apply classification logit biases. Defaults to None.
|
||||
"""
|
||||
|
||||
## for reward models
|
||||
softmax: Optional[bool] = None
|
||||
"""
|
||||
Whether to apply softmax to the reward outputs.
|
||||
Defaults to True.
|
||||
"""
|
||||
step_tag_id: Optional[int] = None
|
||||
"""
|
||||
If set, only the score corresponding to the ``step_tag_id`` in the
|
||||
generated sentence should be returned. Otherwise, the scores for all tokens
|
||||
are returned.
|
||||
"""
|
||||
returned_token_ids: Optional[list[int]] = None
|
||||
"""
|
||||
A list of indices for the vocabulary dimensions to be extracted,
|
||||
such as the token IDs of ``good_token`` and ``bad_token`` in the
|
||||
``math-shepherd-mistral-7b-prm`` model.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
@ -441,6 +441,7 @@ class EngineArgs:
|
||||
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
|
||||
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
|
||||
|
||||
pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
|
||||
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
|
||||
ModelConfig.override_pooler_config
|
||||
compilation_config: CompilationConfig = \
|
||||
@ -579,8 +580,11 @@ class EngineArgs:
|
||||
help=model_kwargs["hf_token"]["help"])
|
||||
model_group.add_argument("--hf-overrides",
|
||||
**model_kwargs["hf_overrides"])
|
||||
model_group.add_argument("--pooler-config",
|
||||
**model_kwargs["pooler_config"])
|
||||
model_group.add_argument("--override-pooler-config",
|
||||
**model_kwargs["override_pooler_config"])
|
||||
**model_kwargs["override_pooler_config"],
|
||||
deprecated=True)
|
||||
model_group.add_argument("--logits-processor-pattern",
|
||||
**model_kwargs["logits_processor_pattern"])
|
||||
model_group.add_argument("--generation-config",
|
||||
@ -1031,6 +1035,7 @@ class EngineArgs:
|
||||
mm_shm_cache_max_object_size_mb=self.
|
||||
mm_shm_cache_max_object_size_mb,
|
||||
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
|
||||
pooler_config=self.pooler_config,
|
||||
override_pooler_config=self.override_pooler_config,
|
||||
logits_processor_pattern=self.logits_processor_pattern,
|
||||
generation_config=self.generation_config,
|
||||
|
||||
@ -151,9 +151,11 @@ class LLM:
|
||||
multi-modal processor obtained from `AutoProcessor.from_pretrained`.
|
||||
The available overrides depend on the model that is being run.
|
||||
For example, for Phi-3-Vision: `{"num_crops": 4}`.
|
||||
override_pooler_config: Initialize non-default pooling config or
|
||||
override default pooling config for the pooling model.
|
||||
e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
|
||||
pooler_config: Initialize non-default pooling config for the pooling
|
||||
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
|
||||
override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
|
||||
argument is deprecated and will be removed in v0.12.0 or v1.0.0,
|
||||
whichever is sooner.
|
||||
compilation_config: Either an integer or a dictionary. If it is an
|
||||
integer, it is used as the level of compilation optimization. If it
|
||||
is a dictionary, it can specify the full compilation configuration.
|
||||
@ -191,6 +193,7 @@ class LLM:
|
||||
hf_token: Optional[Union[bool, str]] = None,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||
pooler_config: Optional[PoolerConfig] = None,
|
||||
override_pooler_config: Optional[PoolerConfig] = None,
|
||||
structured_outputs_config: Optional[Union[dict[
|
||||
str, Any], StructuredOutputsConfig]] = None,
|
||||
@ -288,6 +291,7 @@ class LLM:
|
||||
hf_token=hf_token,
|
||||
hf_overrides=hf_overrides,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
pooler_config=pooler_config,
|
||||
override_pooler_config=override_pooler_config,
|
||||
structured_outputs_config=structured_outputs_instance,
|
||||
compilation_config=compilation_config_instance,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user