diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 0521a22c0702..50982d3d0d0f 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -59,7 +59,7 @@ enabling the corresponding APIs: #### Predefined models If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`, -you can override some of its attributes via the `--override-pooler-config` option. +you can override some of its attributes via the `--pooler-config` option. #### Converted models @@ -75,7 +75,7 @@ the pooler assigned to each task has the following attributes by default: When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults. -You can further customize this via the `--override-pooler-config` option, +You can further customize this via the `--pooler-config` option, which takes priority over both the model's and Sentence Transformers's defaults. ## Offline Inference diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index b67ebcbe3c81..3a6738a27be0 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -457,7 +457,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A !!! note `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. - You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`. + You need to manually set mean pooling by passing `--pooler-config '{"pooling_type": "MEAN"}'`. !!! note For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded. @@ -552,7 +552,7 @@ If your model is not in the above list, we will try to automatically convert the !!! important For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, - e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + e.g.: `--pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. #### Token Classification diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md index 04edc4680ea0..00d3ded3e41c 100644 --- a/examples/online_serving/openai_embedding_long_text/README.md +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -42,7 +42,7 @@ python client.py ### Server Configuration -The key parameters for chunked processing are in the `--override-pooler-config`: +The key parameters for chunked processing are in the `--pooler-config`: ```json { diff --git a/examples/online_serving/openai_embedding_long_text/client.py b/examples/online_serving/openai_embedding_long_text/client.py index 6e9838ac6d8d..4a3674bb3f2a 100644 --- a/examples/online_serving/openai_embedding_long_text/client.py +++ b/examples/online_serving/openai_embedding_long_text/client.py @@ -13,7 +13,7 @@ Prerequisites: # MEAN pooling (processes all chunks, recommended for complete coverage) vllm serve intfloat/multilingual-e5-large \ - --override-pooler-config \ + --pooler-config \ '{"pooling_type": "MEAN", "normalize": true, ' \ '"enable_chunked_processing": true, "max_embed_len": 3072000}' \ --served-model-name multilingual-e5-large \ @@ -23,7 +23,7 @@ Prerequisites: # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks) vllm serve BAAI/bge-large-en-v1.5 \ - --override-pooler-config \ + --pooler-config \ '{"pooling_type": "CLS", "normalize": true, ' \ '"enable_chunked_processing": true, "max_embed_len": 1048576}' \ --served-model-name bge-large-en-v1.5 \ diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh index 56888c8aa0e4..1577de85f7ff 100644 --- a/examples/online_serving/openai_embedding_long_text/service.sh +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -103,7 +103,7 @@ POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enab vllm serve "$MODEL_NAME" \ --tensor-parallel-size "$GPU_COUNT" \ --enforce-eager \ - --override-pooler-config "$POOLER_CONFIG" \ + --pooler-config "$POOLER_CONFIG" \ --served-model-name ${MODEL_CODE} \ --api-key "$API_KEY" \ --trust-remote-code \ diff --git a/tests/entrypoints/pooling/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/openai/test_embedding_long_text.py index 2d3da238d245..ab5f765c28ed 100644 --- a/tests/entrypoints/pooling/openai/test_embedding_long_text.py +++ b/tests/entrypoints/pooling/openai/test_embedding_long_text.py @@ -216,7 +216,7 @@ def server_with_chunked_processing(): "--enforce-eager", "--max-model-len", "512", # Set smaller max_model_len to trigger chunking mechanism - '--override-pooler-config', + '--pooler-config', ('{"pooling_type": "MEAN", "normalize": true, ' '"enable_chunked_processing": true, "max_embed_len": 10000}'), "--gpu-memory-utilization", diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index d61ac08475e3..17513d1bb20d 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -58,7 +58,7 @@ def test_models( vllm_extra_kwargs = {} if model == "ssmits/Qwen2-7B-Instruct-embed-base": - vllm_extra_kwargs["override_pooler_config"] = \ + vllm_extra_kwargs["pooler_config"] = \ PoolerConfig(pooling_type="MEAN", normalize=False) max_model_len: Optional[int] = 512 diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py index 166b953de43e..9814cad48a80 100644 --- a/tests/models/language/pooling/test_mm_classifier_conversion.py +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.config.pooler import PoolerConfig from vllm.platforms import current_platform @@ -99,7 +100,7 @@ def test_gemma_multimodal( convert="classify", load_format="auto", hf_overrides=update_config, - override_pooler_config={"pooling_type": "LAST"}, + pooler_config=PoolerConfig(pooling_type="LAST"), max_model_len=512, enforce_eager=True, tensor_parallel_size=1, diff --git a/tests/models/language/pooling/test_override_pooler_config.py b/tests/models/language/pooling/test_pooler_config_init_behaviour.py similarity index 74% rename from tests/models/language/pooling/test_override_pooler_config.py rename to tests/models/language/pooling/test_pooler_config_init_behaviour.py index 2b1c74652e76..9b3fbd6a6cd0 100644 --- a/tests/models/language/pooling/test_override_pooler_config.py +++ b/tests/models/language/pooling/test_pooler_config_init_behaviour.py @@ -24,18 +24,18 @@ def test_classify_models_using_activation( dtype: str, ) -> None: - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig( - activation=False)) as vllm_model: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(activation=False)) as vllm_model: wo_activation_out = vllm_model.classify(example_prompts) - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig( - activation=True)) as vllm_model: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(activation=True)) as vllm_model: w_activation_out = vllm_model.classify(example_prompts) for wo_activation, w_activation in zip(wo_activation_out, @@ -43,9 +43,8 @@ def test_classify_models_using_activation( wo_activation = torch.tensor(wo_activation) w_activation = torch.tensor(w_activation) - assert not torch.allclose( - wo_activation, w_activation, - atol=1e-2), "override_pooler_config is not working" + assert not torch.allclose(wo_activation, w_activation, + atol=1e-2), "pooler_config is not working" assert torch.allclose(softmax(wo_activation), w_activation, 1e-3 if dtype == "float" else 1e-2) @@ -65,23 +64,22 @@ def test_embed_models_using_normalize( dtype: str, ) -> None: - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig( - normalize=False)) as vllm_model: - wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) - with vllm_runner( model, max_model_len=512, dtype=dtype, - override_pooler_config=PoolerConfig(normalize=True)) as vllm_model: + pooler_config=PoolerConfig(normalize=False)) as vllm_model: + wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=True)) as vllm_model: w_normalize = torch.tensor(vllm_model.embed(example_prompts)) assert not torch.allclose( wo_normalize, w_normalize, - atol=1e-2), "override_pooler_config normalize is not working" + atol=1e-2), "pooler_config normalize is not working" assert torch.allclose( F.normalize(wo_normalize, p=2, dim=-1), w_normalize, atol=1e-2), "w_normal should be close to normal(wo_normal)." @@ -102,18 +100,16 @@ def test_reward_models_using_softmax( dtype: str, ) -> None: - with vllm_runner( - model, - max_model_len=1024, - dtype=dtype, - override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: + with vllm_runner(model, + max_model_len=1024, + dtype=dtype, + pooler_config=PoolerConfig(softmax=False)) as vllm_model: wo_softmax = vllm_model.encode(example_prompts) - with vllm_runner( - model, - max_model_len=1024, - dtype=dtype, - override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: + with vllm_runner(model, + max_model_len=1024, + dtype=dtype, + pooler_config=PoolerConfig(softmax=True)) as vllm_model: w_softmax = vllm_model.encode(example_prompts) for wo, w in zip(wo_softmax, w_softmax): @@ -121,7 +117,7 @@ def test_reward_models_using_softmax( w = torch.tensor(w) assert not torch.allclose( - wo, w, atol=1e-2), "override_pooler_config softmax is not working" + wo, w, atol=1e-2), "pooler_config softmax is not working" assert torch.allclose( softmax(wo), w, atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." diff --git a/tests/test_config.py b/tests/test_config.py index 6e37bdbee59e..0796447c079b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -207,25 +207,19 @@ def test_get_pooling_config(): model_id = "sentence-transformers/all-MiniLM-L12-v2" model_config = ModelConfig(model_id) - pooling_config = model_config._init_pooler_config() - assert pooling_config is not None - - assert pooling_config.normalize - assert pooling_config.pooling_type == PoolingType.MEAN.name + assert model_config.pooler_config is not None + assert model_config.pooler_config.normalize + assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") def test_get_pooling_config_from_args(): model_id = "sentence-transformers/all-MiniLM-L12-v2" - model_config = ModelConfig(model_id) + pooler_config = PoolerConfig(pooling_type="CLS", normalize=True) + model_config = ModelConfig(model_id, pooler_config=pooler_config) - override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True) - model_config.override_pooler_config = override_pooler_config - - pooling_config = model_config._init_pooler_config() - assert pooling_config is not None - assert asdict(pooling_config) == asdict(override_pooler_config) + assert asdict(model_config.pooler_config) == asdict(pooler_config) @pytest.mark.parametrize( diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 25daca00c02d..45504e010d68 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -40,6 +40,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode, MultiModalConfig) from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) +from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.speculative import SpeculativeConfig from vllm.config.structured_outputs import StructuredOutputsConfig @@ -406,13 +407,6 @@ class ModelConfig: hf_overrides: HfOverrides = field(default_factory=dict) """If a dictionary, contains arguments to be forwarded to the Hugging Face config. If a callable, it is called to update the HuggingFace config.""" - pooler_config: Optional["PoolerConfig"] = field(init=False) - """Pooler config which controls the behaviour of output pooling in pooling - models.""" - override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None - """Initialize non-default pooling config or override default pooling config - for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`. - """ logits_processor_pattern: Optional[str] = None """Optional regex pattern specifying valid logits processor qualified names that can be passed with the `logits_processors` extra completion argument. @@ -448,6 +442,14 @@ class ModelConfig: io_processor_plugin: Optional[str] = None """IOProcessor plugin name to load at model startup""" + # Pooler config + pooler_config: Optional[PoolerConfig] = None + """Pooler config which controls the behaviour of output pooling in pooling + models.""" + override_pooler_config: Optional[Union[dict, PoolerConfig]] = None + """[DEPRECATED] Use `pooler_config` instead. This field will be removed in + v0.12.0 or v1.0.0, whichever is sooner.""" + # Multimodal config and init vars multimodal_config: Optional[MultiModalConfig] = None """Configuration for multimodal model. If `None`, this will be inferred @@ -709,7 +711,33 @@ class ModelConfig: self._architecture = arch logger.info("Resolved architecture: %s", arch) - self.pooler_config = self._init_pooler_config() + # Init pooler config if needed + if self.runner_type == "pooling": + if self.override_pooler_config is not None: + logger.warning_once( + "`override_pooler_config` is deprecated and will be " + "removed in v0.12.0 or v1.0.0, whichever is sooner. " + "Please use `pooler_config` instead.") + + if isinstance(self.override_pooler_config, dict): + self.pooler_config = PoolerConfig( + **self.override_pooler_config) + else: + self.pooler_config = self.override_pooler_config + + if self.pooler_config is None: + self.pooler_config = PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + # Only set values that are not overridden by the user + for k, v in base_config.items(): + if getattr(self.pooler_config, k) is None: + setattr(self.pooler_config, k, v) + + default_pooling_type = self._model_info.default_pooling_type + if self.pooler_config.pooling_type is None: + self.pooler_config.pooling_type = default_pooling_type self.dtype: torch.dtype = _get_and_verify_dtype( self.model, @@ -869,29 +897,6 @@ class ModelConfig: return get_sentence_transformer_tokenizer_config( self.model, self.revision) - def _init_pooler_config(self) -> Optional["PoolerConfig"]: - if self.runner_type == "pooling": - if isinstance(self.override_pooler_config, dict): - self.override_pooler_config = PoolerConfig( - **self.override_pooler_config) - - pooler_config = self.override_pooler_config or PoolerConfig() - - base_config = get_pooling_config(self.model, self.revision) - if base_config is not None: - # Only set values that are not overridden by the user - for k, v in base_config.items(): - if getattr(pooler_config, k) is None: - setattr(pooler_config, k, v) - - default_pooling_type = self._model_info.default_pooling_type - if pooler_config.pooling_type is None: - pooler_config.pooling_type = default_pooling_type - - return pooler_config - - return None - def _verify_tokenizer_mode(self) -> None: tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) if tokenizer_mode not in get_args(TokenizerMode): @@ -1833,94 +1838,6 @@ class DeviceConfig: self.device = torch.device(self.device_type) -@config -@dataclass -class PoolerConfig: - """Controls the behavior of output pooling in pooling models.""" - - pooling_type: Optional[str] = None - """ - The pooling method of the pooling model. This should be a key in - [`vllm.model_executor.layers.pooler.PoolingType`][]. - """ - - ## for embeddings models - normalize: Optional[bool] = None - """ - Whether to normalize the embeddings outputs. Defaults to True. - """ - dimensions: Optional[int] = None - """ - Reduce the dimensions of embeddings if model - support matryoshka representation. Defaults to None. - """ - enable_chunked_processing: Optional[bool] = None - """ - Whether to enable chunked processing for long inputs that exceed the model's - maximum position embeddings. When enabled, long inputs will be split into - chunks, processed separately, and then aggregated using weighted averaging. - This allows embedding models to handle arbitrarily long text without CUDA - errors. Defaults to False. - """ - max_embed_len: Optional[int] = None - """ - Maximum input length allowed for embedding generation. When set, allows - inputs longer than max_embed_len to be accepted for embedding models. - When an input exceeds max_embed_len, it will be handled according to - the original max_model_len validation logic. - Defaults to None (i.e. set to max_model_len). - """ - - ## for classification models - activation: Optional[bool] = None - """ - Whether to apply activation function to the classification outputs. - Defaults to True. - """ - logit_bias: Optional[float] = None - """ - If provided, apply classification logit biases. Defaults to None. - """ - - ## for reward models - softmax: Optional[bool] = None - """ - Whether to apply softmax to the reward outputs. - Defaults to True. - """ - step_tag_id: Optional[int] = None - """ - If set, only the score corresponding to the ``step_tag_id`` in the - generated sentence should be returned. Otherwise, the scores for all tokens - are returned. - """ - returned_token_ids: Optional[list[int]] = None - """ - A list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of ``good_token`` and ``bad_token`` in the - ``math-shepherd-mistral-7b-prm`` model. - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py new file mode 100644 index 000000000000..85b5a1ace85f --- /dev/null +++ b/vllm/config/pooler.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import Any, Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class PoolerConfig: + """Controls the behavior of output pooling in pooling models.""" + + pooling_type: Optional[str] = None + """ + The pooling method of the pooling model. This should be a key in + [`vllm.model_executor.layers.pooler.PoolingType`][]. + """ + + ## for embeddings models + normalize: Optional[bool] = None + """ + Whether to normalize the embeddings outputs. Defaults to True. + """ + dimensions: Optional[int] = None + """ + Reduce the dimensions of embeddings if model + support matryoshka representation. Defaults to None. + """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_embed_len to be accepted for embedding models. + When an input exceeds max_embed_len, it will be handled according to + the original max_model_len validation logic. + Defaults to None (i.e. set to max_model_len). + """ + + ## for classification models + activation: Optional[bool] = None + """ + Whether to apply activation function to the classification outputs. + Defaults to True. + """ + logit_bias: Optional[float] = None + """ + If provided, apply classification logit biases. Defaults to None. + """ + + ## for reward models + softmax: Optional[bool] = None + """ + Whether to apply softmax to the reward outputs. + Defaults to True. + """ + step_tag_id: Optional[int] = None + """ + If set, only the score corresponding to the ``step_tag_id`` in the + generated sentence should be returned. Otherwise, the scores for all tokens + are returned. + """ + returned_token_ids: Optional[list[int]] = None + """ + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the + ``math-shepherd-mistral-7b-prm`` model. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 63282c425350..27462b8fa0da 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -441,6 +441,7 @@ class EngineArgs: scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls + pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ ModelConfig.override_pooler_config compilation_config: CompilationConfig = \ @@ -579,8 +580,11 @@ class EngineArgs: help=model_kwargs["hf_token"]["help"]) model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) + model_group.add_argument("--pooler-config", + **model_kwargs["pooler_config"]) model_group.add_argument("--override-pooler-config", - **model_kwargs["override_pooler_config"]) + **model_kwargs["override_pooler_config"], + deprecated=True) model_group.add_argument("--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]) model_group.add_argument("--generation-config", @@ -1031,6 +1035,7 @@ class EngineArgs: mm_shm_cache_max_object_size_mb=self. mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, + pooler_config=self.pooler_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, generation_config=self.generation_config, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index df6b16c73d6e..e21bfce0ab08 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -151,9 +151,11 @@ class LLM: multi-modal processor obtained from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. - override_pooler_config: Initialize non-default pooling config or - override default pooling config for the pooling model. - e.g. `PoolerConfig(pooling_type="mean", normalize=False)`. + pooler_config: Initialize non-default pooling config for the pooling + model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`. + override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This + argument is deprecated and will be removed in v0.12.0 or v1.0.0, + whichever is sooner. compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the level of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. @@ -191,6 +193,7 @@ class LLM: hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, + pooler_config: Optional[PoolerConfig] = None, override_pooler_config: Optional[PoolerConfig] = None, structured_outputs_config: Optional[Union[dict[ str, Any], StructuredOutputsConfig]] = None, @@ -288,6 +291,7 @@ class LLM: hf_token=hf_token, hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, + pooler_config=pooler_config, override_pooler_config=override_pooler_config, structured_outputs_config=structured_outputs_instance, compilation_config=compilation_config_instance,