diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index b571a8f866990..4a97438b1bb2c 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -12,8 +12,9 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
-from vllm.config import ModelConfig, PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
@@ -377,7 +378,6 @@ class PoolerClassify(PoolerActivation):
         super().__init__()
 
         if static_num_labels:
-            from vllm.config import get_current_vllm_config
             vllm_config = get_current_vllm_config()
             self.num_labels = getattr(vllm_config.model_config.hf_config,
                                       "num_labels", 0)
@@ -427,8 +427,6 @@ class EmbeddingPoolerHead(PoolerHead):
         super().__init__(activation=PoolerNormalize())
 
         # Load ST projector if available
-        from vllm.config import get_current_vllm_config
-        from vllm.model_executor.models.adapters import _load_st_projector
 
         vllm_config = get_current_vllm_config()
         self.projector: Optional[nn.Module] = _load_st_projector(
@@ -489,7 +487,6 @@ class RewardPoolerHead(PoolerHead):
     def __init__(self) -> None:
         super().__init__(activation=PoolerClassify(static_num_labels=False))
 
-        from vllm.config import get_current_vllm_config
         vllm_config = get_current_vllm_config()
         self.head_dtype = vllm_config.model_config.head_dtype
 
@@ -638,7 +635,6 @@ class ClassifierPooler(Pooler):
     ) -> None:
         super().__init__()
 
-        from vllm.config import get_current_vllm_config
         vllm_config = get_current_vllm_config()
 
         self.pooling = pooling
@@ -730,3 +726,7 @@ class DispatchPooler(Pooler):
             offset += num_items
 
         return PoolerOutput(outputs)
+
+    def extra_repr(self) -> str:
+        s = f"supported_task={self.get_supported_tasks()}"
+        return s
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e8ad9c2fca07c..2e67984cb4327 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3151,6 +3151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         model = cast(VllmModelForPooling, self.get_model())
 
         dummy_pooling_params = PoolingParams(task=task)
+        dummy_pooling_params.verify(task=task, model_config=self.model_config)
 
         to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)