[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-07-18 00:05:40 +08:00 committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 247 additions and 345 deletions

View File

@@ -11,11 +11,13 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
class MyGemma2Embedding(nn.Module):
is_pooling_model = True
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -24,7 +26,7 @@ class MyGemma2Embedding(nn.Module):
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
@@ -54,13 +56,6 @@ class MyGemma2Embedding(nn.Module):
# Return all-zero embeddings
return torch.zeros_like(hidden_states)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)

View File

@@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:embedding-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:embedding-pooling-params]
# --8<-- [start:embedding-extra-params]
add_special_tokens: bool = Field(
default=True,
@@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# --8<-- [end:embedding-extra-params]
def to_pooling_params(self):
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions)
class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:chat-embedding-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:chat-embedding-pooling-params]
# --8<-- [start:chat-embedding-extra-params]
add_special_tokens: bool = Field(
default=False,
@@ -1323,8 +1314,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data
def to_pooling_params(self):
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
text_2: Union[list[str], str, ScoreMultiModalParam]
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:score-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:score-pooling-params]
# --8<-- [start:score-extra-params]
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params]
def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(use_cross_encoder=use_cross_encoder,
additional_data=self.additional_data)
return PoolingParams(use_cross_encoder=use_cross_encoder)
class RerankRequest(OpenAIBaseModel):
@@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
top_n: int = Field(default_factory=lambda: 0)
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:rerank-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:rerank-pooling-params]
# --8<-- [start:rerank-extra-params]
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params]
def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(use_cross_encoder=use_cross_encoder,
additional_data=self.additional_data)
return PoolingParams(use_cross_encoder=use_cross_encoder)
class RerankDocument(BaseModel):
@@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
truncate_prompt_tokens: Optional[int] = None
user: Optional[str] = None
# --8<-- [start:classification-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:classification-pooling-params]
# --8<-- [start:classification-extra-params]
priority: int = Field(
default=0,
@@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
# --8<-- [end:classification-extra-params]
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
return PoolingParams()
class ClassificationData(OpenAIBaseModel):

View File

@@ -3,22 +3,25 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Optional, TypeVar, Union
from typing import Callable, Literal, Optional, TypeVar, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from typing_extensions import assert_never
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingTask = Literal["encode", "embed", "classify", "score"]
class PoolingType(IntEnum):
@ -64,6 +67,48 @@ class ResolvedPoolingConfig:
)
class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""
@staticmethod
def from_config_with_defaults(
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[list[int]] = None,
) -> "Pooler":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids,
)
if pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)
return SimplePooler.from_config(resolved_config)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
"""
return None
@abstractmethod
def forward(
self,
hidden_states: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput:
return PoolerOutput(outputs=all_outputs)
class BasePooler(nn.Module):
@abstractmethod
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
class PoolingMethod(nn.Module, ABC):
@staticmethod
@ -130,6 +164,10 @@ class PoolingMethod(nn.Module, ABC):
raise NotImplementedError(f"Unsupported method: {pooling_type}")
@abstractmethod
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
raise NotImplementedError
@abstractmethod
def forward_one(
self,
@ -168,6 +206,14 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
assert_never(task)
def forward_one(
self,
hidden_states: torch.Tensor,
@ -190,6 +236,14 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
assert_never(task)
def forward_one(
self,
hidden_states: torch.Tensor,
@ -208,6 +262,16 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "encode":
return PoolingParams()
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
return None
assert_never(task)
def forward_one(
self,
hidden_states: torch.Tensor,
@ -235,6 +299,14 @@ class AllPool(PoolingMethod):
class MeanPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
assert_never(task)
def forward_one(
self,
hidden_states: torch.Tensor,
@ -345,25 +417,6 @@ class LambdaPoolerActivation(PoolerActivation):
class PoolerHead(nn.Module):
@classmethod
def from_config_with_defaults(
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
) -> "PoolerHead":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=None,
returned_token_ids=None,
)
return cls.from_config(resolved_config)
@classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
if pooler_config.normalize and pooler_config.softmax:
@ -424,21 +477,17 @@ class PoolerHead(nn.Module):
return self.activation(pooled_data)
class SimplePooler(BasePooler):
class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""
@classmethod
def from_config_with_defaults(
def from_config_with_defaults( # type: ignore[override]
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
@ -471,6 +520,9 @@ class SimplePooler(BasePooler):
self.pooling = pooling
self.head = head
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@ -481,7 +533,7 @@ class SimplePooler(BasePooler):
return build_output(pooled_data)
class StepPooler(BasePooler):
class StepPooler(Pooler):
@classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
@ -543,6 +595,16 @@ class StepPooler(BasePooler):
return pooled_data
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "encode":
return PoolingParams(logits_processing_needs_token_ids=True)
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
return None
assert_never(task)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@ -553,32 +615,6 @@ class StepPooler(BasePooler):
return build_output(pooled_data)
class Pooler(nn.Module):
@staticmethod
def from_config_with_defaults(
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[list[int]] = None,
) -> BasePooler:
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids,
)
if pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)
return SimplePooler.from_config(resolved_config)
PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
@ -618,6 +654,18 @@ class ClassifierPooler(nn.Module):
return (self.cross_encoder_act_fn
if use_cross_encoder else self.classification_act_fn)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "encode":
return PoolingParams()
if task == "embed":
return None
if task == "classify":
return PoolingParams()
if task == "score":
return PoolingParams(use_cross_encoder=True)
assert_never(task)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import torch
import torch.nn as nn
@ -42,13 +42,14 @@ def _create_pooling_model_cls(
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
def __init__(
self,
*,
@ -66,27 +67,20 @@ def _create_pooling_model_cls(
delattr(self, attr)
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
if not getattr(self, "pooler", None):
self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking
@ -171,10 +165,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
PoolingType, SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
@ -213,7 +205,7 @@ def as_seq_cls_model(cls: _T) -> _T:
softmax=True,
)
self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
@ -234,13 +226,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return super().forward(input_ids, positions, intermediate_tensors,
inputs_embeds)
def pooler(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)

View File

@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingMethod, PoolingType)
PoolingMethod, PoolingTask,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.pooling_params import PoolingParams
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
return embeddings
class BertPooler(nn.Module):
class BertPooler(Pooler):
def __init__(self, config: BertConfig):
super().__init__()
@ -89,6 +91,9 @@ class BertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@ -319,6 +324,9 @@ class BertOutput(nn.Module):
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self,
@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)
self.pooler = self._build_pooler(pooler_config)
def forward(
self,
@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)
@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
loaded_params = loader.load_weights(weights)
return loaded_params
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: Optional[torch.Tensor],

View File

@ -40,9 +40,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from ..layers.pooler import Pooler, PoolingType
from .interfaces import SupportsPP
@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module):
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@ -339,7 +340,7 @@ class GPT2ForSequenceClassification(nn.Module):
prefix=maybe_prefix(prefix, "gpt2"))
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
@ -349,13 +350,6 @@ class GPT2ForSequenceClassification(nn.Module):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from array import array
from typing import Optional
import torch
import torch.nn as nn
@ -195,6 +194,8 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
- "<|user|>\nPROMPT\n<|assistant|>\n"
"""
is_pooling_model = True
def __init__(
self,
vllm_config: VllmConfig,
@ -214,11 +215,4 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self._pooler = GritLMPooler(vllm_config.model_config)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
self.pooler = GritLMPooler(vllm_config.model_config)

View File

@ -119,13 +119,6 @@ class SupportsMultiModal(Protocol):
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
@overload
def supports_multimodal(
model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
@ -140,10 +133,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
def supports_multimodal(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type):
return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsMultiModal)
return getattr(model, "supports_multimodal", False)
@runtime_checkable
@ -174,13 +164,6 @@ class SupportsScoreTemplate(Protocol):
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsScoreTemplateType(Protocol):
supports_score_template: Literal[True]
@overload
def supports_score_template(
model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]:
@ -195,11 +178,7 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]:
def supports_score_template(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]:
if isinstance(model, type):
return isinstance(model, _SupportsScoreTemplateType)
return isinstance(model, SupportsScoreTemplate)
return getattr(model, "supports_score_template", False)
@runtime_checkable
@ -409,11 +388,6 @@ class HasInnerState(Protocol):
"""
@runtime_checkable
class _HasInnerStateType(Protocol):
has_inner_state: ClassVar[Literal[True]]
@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
...
@ -427,10 +401,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
def has_inner_state(
model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type):
return isinstance(model, _HasInnerStateType)
return isinstance(model, HasInnerState)
return getattr(model, "has_inner_state", False)
@runtime_checkable
@ -446,11 +417,6 @@ class IsAttentionFree(Protocol):
"""
@runtime_checkable
class _IsAttentionFreeType(Protocol):
is_attention_free: ClassVar[Literal[True]]
@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
...
@ -464,10 +430,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
def is_attention_free(
model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
if isinstance(model, type):
return isinstance(model, _IsAttentionFreeType)
return isinstance(model, IsAttentionFree)
return getattr(model, "is_attention_free", False)
@runtime_checkable
@ -502,11 +465,6 @@ class IsHybrid(Protocol):
...
@runtime_checkable
class _IsHybridType(Protocol):
is_hybrid: ClassVar[Literal[True]]
@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
...
@ -520,10 +478,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
def is_hybrid(
model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
if isinstance(model, type):
return isinstance(model, _IsHybridType)
return isinstance(model, IsHybrid)
return getattr(model, "is_hybrid", False)
@runtime_checkable
@ -598,11 +553,6 @@ class HasNoOps(Protocol):
has_noops: ClassVar[Literal[True]] = True
@runtime_checkable
class _HasNoOpsType(Protocol):
has_noops: ClassVar[Literal[True]]
@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
...
@ -616,10 +566,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
def has_noops(
model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
if isinstance(model, type):
return isinstance(model, _HasNoOpsType)
return isinstance(model, HasNoOps)
return getattr(model, "has_noops", False)
@runtime_checkable
@ -643,11 +590,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
def _supports_cross_encoding(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
if isinstance(model, type):
return isinstance(model, SupportsCrossEncoding)
return isinstance(model, SupportsCrossEncoding)
return getattr(model, "supports_cross_encoding", False)
def supports_cross_encoding(
@ -658,8 +601,9 @@ def supports_cross_encoding(
def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler."""
return is_pooling_model(model) and any(
type(module).__name__ == "StepPooler" for module in model.modules())
from vllm.model_executor.layers.pooler import StepPooler
return is_pooling_model(model) and isinstance(model.pooler, StepPooler)
class SupportsQuant:
@ -770,10 +714,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
def supports_transcription(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
if isinstance(model, type):
return isinstance(model, SupportsTranscription)
return isinstance(model, SupportsTranscription)
return getattr(model, "supports_transcription", False)
@runtime_checkable
@ -796,7 +737,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
def supports_v0_only(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
if isinstance(model, type):
return isinstance(model, SupportsV0Only)
return isinstance(model, SupportsV0Only)
return getattr(model, "supports_v0_only", False)

View File

@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
runtime_checkable)
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
import torch
import torch.nn as nn
@ -13,8 +12,7 @@ from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.sampling_metadata import SamplingMetadata
logger = init_logger(__name__)
@ -130,16 +128,20 @@ def is_text_generation_model(
@runtime_checkable
class VllmModelForPooling(VllmModel[T], Protocol[T]):
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
"""The interface required for all pooling models in vLLM."""
def pooler(
self,
hidden_states: T,
pooling_metadata: "PoolingMetadata",
) -> "PoolerOutput":
"""Only called on TP rank 0."""
...
is_pooling_model: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports pooling.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
pooler: "Pooler"
"""The pooler is only called on TP rank 0."""
@overload
@ -158,7 +160,4 @@ def is_pooling_model(
if not is_vllm_model(model):
return False
if isinstance(model, type):
return isinstance(model, VllmModelForPooling)
return isinstance(model, VllmModelForPooling)
return getattr(model, "is_pooling_model", False)

View File

@ -28,9 +28,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
@ -404,6 +403,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
def __init__(
self,
*,
@ -428,7 +429,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
)
pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
@ -446,10 +447,3 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
inputs_embeds)
logits, _ = self.v_head(hidden_states)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

View File

@ -27,9 +27,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
@ -563,6 +562,8 @@ def _is_moe_layer(name: str):
class JambaForSequenceClassification(JambaForCausalLM):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
@ -590,16 +591,9 @@ class JambaForSequenceClassification(JambaForCausalLM):
softmax=False,
)
self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self.score,
act_fn=pooler.head.activation,
)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

View File

@ -13,9 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import (SupportsCrossEncoding, SupportsMultiModal,
SupportsScoreTemplate)
@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
SupportsCrossEncoding,
SupportsMultiModal,
SupportsScoreTemplate):
is_pooling_model = True
weight_mapper = WeightsMapper(
orig_to_new_prefix={
"score.0.": "score.dense.",
@ -95,7 +96,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
self.score = JinaVLScorer(config)
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
@ -137,14 +138,6 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
logits = self.score(hidden_states) - self.LOGIT_BIAS
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.weight_mapper)

View File

@ -13,14 +13,16 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingMethod, PoolingTask,
PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.pooling_params import PoolingParams
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
@ -253,7 +255,7 @@ class ModernBertModel(nn.Module):
return norm_outputs
class ModernBertPooler(BasePooler):
class ModernBertPooler(Pooler):
def __init__(self, config: ModernBertConfig):
super().__init__()
@ -268,6 +270,9 @@ class ModernBertPooler(BasePooler):
eps=config.norm_eps,
bias=config.norm_bias)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@ -281,6 +286,8 @@ class ModernBertPooler(BasePooler):
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@ -288,7 +295,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=ModernBertPooler(config),
classifier=self.classifier,
@ -321,13 +328,6 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
default_weight_loader)
weight_loader(param, loaded_weight)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: Optional[torch.LongTensor],

View File

@ -24,12 +24,13 @@ import torch.nn as nn
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
PoolerIdentity, SimplePooler)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree,
SupportsMultiModal,
SupportsV0Only)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
@ -37,8 +38,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
from vllm.sequence import IntermediateTensors
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
@ -116,7 +116,9 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
SupportsV0Only):
""" Prithvi Masked Autoencoder"""
"""Prithvi Masked Autoencoder"""
is_pooling_model = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@ -162,6 +164,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.")
self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity()))
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -189,7 +193,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
pixel_values, location_coords = (
self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values,
@ -197,13 +200,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return model_output.output
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_list = []

View File

@ -16,8 +16,7 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2Model
@ -25,6 +24,10 @@ from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
is_pooling_model = True
pooler: SimplePooler
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -61,7 +64,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
quant_config=quant_config,
return_bias=False),
)
self._pooler: SimplePooler
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@ -80,13 +82,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
logits = self.score(hidden_states)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self,
@ -96,11 +91,11 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config, prefix=""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
@ -109,11 +104,11 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config, prefix=""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,

View File

@ -15,8 +15,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only
@ -165,6 +164,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model = True
jina_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
'emb_ln': "embeddings.LayerNorm",
@ -188,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config)
self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=CLSPool(),
classifier=self.classifier,
@ -198,13 +198,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: Optional[torch.Tensor],

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
import msgspec
@ -15,24 +15,31 @@ class PoolingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""API parameters for pooling models. This is currently a placeholder.
"""API parameters for pooling models. This
Attributes:
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
dimensions: Optional[int] = None
use_cross_encoder: bool = False
additional_data: Optional[Any] = None
"""Internal use only."""
logits_processing_needs_token_ids: bool = False
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
return PoolingParams(dimensions=self.dimensions,
use_cross_encoder=self.use_cross_encoder,
additional_data=self.additional_data)
return PoolingParams(
dimensions=self.dimensions,
use_cross_encoder=self.use_cross_encoder,
logits_processing_needs_token_ids=self.
logits_processing_needs_token_ids,
)
def verify(self, model_config: "ModelConfig") -> None:
if self.dimensions is not None:
@ -54,10 +61,12 @@ class PoolingParams(
raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str:
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"use_cross_encoder={self.use_cross_encoder}, "
f"additional_metadata={self.additional_data})")
return (
f"PoolingParams("
f"dimensions={self.dimensions}, "
f"use_cross_encoder={self.use_cross_encoder}, "
f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})"
)
def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\