mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 23:37:52 +08:00
[Model] Improve Pooling Model (#25149)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
cc935fdd7e
commit
37970105fe
@@ -12,8 +12,9 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import ModelConfig, PoolerConfig
|
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.models.adapters import _load_st_projector
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
@@ -377,7 +378,6 @@ class PoolerClassify(PoolerActivation):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if static_num_labels:
|
if static_num_labels:
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.num_labels = getattr(vllm_config.model_config.hf_config,
|
self.num_labels = getattr(vllm_config.model_config.hf_config,
|
||||||
"num_labels", 0)
|
"num_labels", 0)
|
||||||
@@ -427,8 +427,6 @@ class EmbeddingPoolerHead(PoolerHead):
|
|||||||
super().__init__(activation=PoolerNormalize())
|
super().__init__(activation=PoolerNormalize())
|
||||||
|
|
||||||
# Load ST projector if available
|
# Load ST projector if available
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
from vllm.model_executor.models.adapters import _load_st_projector
|
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.projector: Optional[nn.Module] = _load_st_projector(
|
self.projector: Optional[nn.Module] = _load_st_projector(
|
||||||
@@ -489,7 +487,6 @@ class RewardPoolerHead(PoolerHead):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
||||||
|
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.head_dtype = vllm_config.model_config.head_dtype
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
|
|
||||||
@@ -638,7 +635,6 @@ class ClassifierPooler(Pooler):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
|
|
||||||
self.pooling = pooling
|
self.pooling = pooling
|
||||||
@@ -730,3 +726,7 @@ class DispatchPooler(Pooler):
|
|||||||
offset += num_items
|
offset += num_items
|
||||||
|
|
||||||
return PoolerOutput(outputs)
|
return PoolerOutput(outputs)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s = f"supported_task={self.get_supported_tasks()}"
|
||||||
|
return s
|
||||||
|
|||||||
@@ -3151,6 +3151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
model = cast(VllmModelForPooling, self.get_model())
|
model = cast(VllmModelForPooling, self.get_model())
|
||||||
dummy_pooling_params = PoolingParams(task=task)
|
dummy_pooling_params = PoolingParams(task=task)
|
||||||
|
dummy_pooling_params.verify(task=task, model_config=self.model_config)
|
||||||
to_update = model.pooler.get_pooling_updates(task)
|
to_update = model.pooler.get_pooling_updates(task)
|
||||||
to_update.apply(dummy_pooling_params)
|
to_update.apply(dummy_pooling_params)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user