diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 304a9e987ee03..20f423cc7603d 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -25,11 +25,11 @@ import torch.nn as nn from transformers import BatchFeature from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, - PoolerIdentity, SimplePooler) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput) + IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, + default_pooling_type) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -142,6 +142,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ) +@default_pooling_type("All") @MULTIMODAL_REGISTRY.register_processor( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, @@ -198,7 +199,11 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, "Only SemanticSegmentationTask is supported for now " "by PrithviGeospatialMAE.") - self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity())) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) def _parse_and_validate_multimodal_data( self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: