From 5eeef1b90852917b300ed67b98e341eb846ba2e9 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 27 Aug 2025 21:24:09 +0800 Subject: [PATCH] [Model] Explicit `default_pooling_type` interface (#23736) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/bert.py | 4 +-- vllm/model_executor/models/bert_with_rope.py | 5 ++-- vllm/model_executor/models/gritlm.py | 2 +- vllm/model_executor/models/interfaces.py | 19 +------------ vllm/model_executor/models/interfaces_base.py | 28 +++++++++++++++++++ vllm/model_executor/models/internlm2.py | 3 +- vllm/model_executor/models/modernbert.py | 3 +- .../models/prithvi_geospatial_mae.py | 7 +++-- vllm/model_executor/models/qwen2_rm.py | 3 +- vllm/model_executor/models/registry.py | 7 +++-- vllm/model_executor/models/roberta.py | 3 +- 11 files changed, 51 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 22b6c4401213..b34ca5cbe963 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -28,8 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import (SupportsCrossEncoding, SupportsQuant, - default_pooling_type) +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 129450927e56..dcb7e75456cd 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -27,13 +27,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (SupportsQuant, - default_pooling_type) from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from .interfaces import SupportsQuant +from .interfaces_base import default_pooling_type + class BertWithRopeEmbedding(nn.Module): diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 3f6790269ae6..1b3d541c65cf 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from .interfaces import default_pooling_type +from .interfaces_base import default_pooling_type logger = init_logger(__name__) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9415e67924e7..22f005849e86 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Mapping, MutableSequence from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, - TypeVar, Union, overload, runtime_checkable) + Union, overload, runtime_checkable) import numpy as np import torch @@ -641,23 +641,6 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) -_T = TypeVar("_T", bound=type[torch.nn.Module]) - - -def default_pooling_type(pooling_type: str): - """Set default_pooling_type decorator. """ - - def func(model: _T) -> _T: - model.default_pooling_type = pooling_type # type: ignore - return model - - return func - - -def get_default_pooling_type(model: Union[type[object], object]) -> str: - return getattr(model, "default_pooling_type", "LAST") - - class SupportsQuant: """The interface required for all models that support quantization.""" diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 697fa020deb4..19a3ef1a3b80 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ + default_pooling_type: ClassVar[str] = "LAST" + """ + Indicates the + [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][] + to use by default. + + You can use the + [vllm.model_executor.models.interfaces_base.default_pooling_type][] + decorator to conveniently set this field. + """ + pooler: Pooler """The pooler is only called on TP rank 0.""" @@ -165,3 +176,20 @@ def is_pooling_model( return False return getattr(model, "is_pooling_model", False) + + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def default_pooling_type(pooling_type: str): + """Decorator to set `VllmModelForPooling.default_pooling_type`.""" + + def func(model: _T) -> _T: + model.default_pooling_type = pooling_type # type: ignore + return model + + return func + + +def get_default_pooling_type(model: Union[type[object], object]) -> str: + return getattr(model, "default_pooling_type", "LAST") diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d0c4bf5450d6..26bc48ffbd9b 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -31,7 +31,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 72290bf2ee29..477855586128 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type from .utils import WeightsMapper, maybe_prefix diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 59e9f3e8a47b..f46d6375e1f6 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -27,9 +27,6 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import ( - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, - default_pooling_type) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, @@ -43,6 +40,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from .interfaces import (IsAttentionFree, MultiModalEmbeddings, + SupportsMultiModalWithRawInput) +from .interfaces_base import default_pooling_type + def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): # This model receives in input a multi-dimensional tensor representing diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e0a30e04c602..421b43563bad 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -18,7 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c65c58d4a047..196b5f35e1e4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -25,11 +25,12 @@ from vllm.logger import init_logger from vllm.transformers_utils.dynamic_module import ( try_get_class_from_dynamic_module) -from .interfaces import (get_default_pooling_type, has_inner_state, has_noops, - is_attention_free, is_hybrid, supports_cross_encoding, +from .interfaces import (has_inner_state, has_noops, is_attention_free, + is_hybrid, supports_cross_encoding, supports_multimodal, supports_multimodal_raw_input, supports_pp, supports_transcription, supports_v0_only) -from .interfaces_base import is_pooling_model, is_text_generation_model +from .interfaces_base import (get_default_pooling_type, is_pooling_model, + is_text_generation_model) logger = init_logger(__name__) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 49a37342c67f..2bfa51162910 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -22,7 +22,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type class RobertaEmbedding(nn.Module):