mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:44:58 +08:00
[Model] Update pooling model interface (#21058)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
9fb2d22032
commit
90bd2ab6e3
@ -11,11 +11,13 @@ from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class MyGemma2Embedding(nn.Module):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -24,7 +26,7 @@ class MyGemma2Embedding(nn.Module):
|
||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
vllm_config.model_config.pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
@ -54,13 +56,6 @@ class MyGemma2Embedding(nn.Module):
|
||||
# Return all-zero embeddings
|
||||
return torch.zeros_like(hidden_states)
|
||||
|
||||
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Pool ``hidden_states`` by delegating to ``self._pooler``."""
    pooled_output = self._pooler(hidden_states, pooling_metadata)
    return pooled_output
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
|
||||
@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
user: Optional[str] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# --8<-- [start:embedding-pooling-params]
|
||||
additional_data: Optional[Any] = None
|
||||
# --8<-- [end:embedding-pooling-params]
|
||||
|
||||
# --8<-- [start:embedding-extra-params]
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
# --8<-- [end:embedding-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(dimensions=self.dimensions,
|
||||
additional_data=self.additional_data)
|
||||
return PoolingParams(dimensions=self.dimensions)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
user: Optional[str] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# --8<-- [start:chat-embedding-pooling-params]
|
||||
additional_data: Optional[Any] = None
|
||||
# --8<-- [end:chat-embedding-pooling-params]
|
||||
|
||||
# --8<-- [start:chat-embedding-extra-params]
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
@ -1323,8 +1314,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
return data
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(dimensions=self.dimensions,
|
||||
additional_data=self.additional_data)
|
||||
return PoolingParams(dimensions=self.dimensions)
|
||||
|
||||
|
||||
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
||||
@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
text_2: Union[list[str], str, ScoreMultiModalParam]
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# --8<-- [start:score-pooling-params]
|
||||
additional_data: Optional[Any] = None
|
||||
# --8<-- [end:score-pooling-params]
|
||||
|
||||
# --8<-- [start:score-extra-params]
|
||||
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||
@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
# --8<-- [end:score-extra-params]
|
||||
|
||||
def to_pooling_params(self, *, use_cross_encoder: bool = False):
|
||||
return PoolingParams(use_cross_encoder=use_cross_encoder,
|
||||
additional_data=self.additional_data)
|
||||
return PoolingParams(use_cross_encoder=use_cross_encoder)
|
||||
|
||||
|
||||
class RerankRequest(OpenAIBaseModel):
|
||||
@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# --8<-- [start:rerank-pooling-params]
|
||||
additional_data: Optional[Any] = None
|
||||
# --8<-- [end:rerank-pooling-params]
|
||||
|
||||
# --8<-- [start:rerank-extra-params]
|
||||
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||
@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
|
||||
# --8<-- [end:rerank-extra-params]
|
||||
|
||||
def to_pooling_params(self, *, use_cross_encoder: bool = False):
|
||||
return PoolingParams(use_cross_encoder=use_cross_encoder,
|
||||
additional_data=self.additional_data)
|
||||
return PoolingParams(use_cross_encoder=use_cross_encoder)
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
|
||||
truncate_prompt_tokens: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
# --8<-- [start:classification-pooling-params]
|
||||
additional_data: Optional[Any] = None
|
||||
# --8<-- [end:classification-pooling-params]
|
||||
|
||||
# --8<-- [start:classification-extra-params]
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
|
||||
# --8<-- [end:classification-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
return PoolingParams()
|
||||
|
||||
|
||||
class ClassificationData(OpenAIBaseModel):
|
||||
|
||||
@ -3,22 +3,25 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from typing import Callable, Literal, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig
|
||||
from vllm.model_executor.pooling_metadata import ( # noqa: E501
|
||||
PoolingMetadata as V0PoolingMetadata)
|
||||
from vllm.model_executor.pooling_metadata import PoolingTensors
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
||||
|
||||
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
|
||||
PoolingTask = Literal["encode", "embed", "classify", "score"]
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
@ -64,6 +67,48 @@ class ResolvedPoolingConfig:
|
||||
)
|
||||
|
||||
|
||||
class Pooler(nn.Module, ABC):
    """Abstract base for every pooler used by a pooling model in vLLM."""

    @staticmethod
    def from_config_with_defaults(
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> "Pooler":
        """Build a concrete pooler from ``pooler_config`` and fallbacks.

        All arguments are forwarded to
        `ResolvedPoolingConfig.from_config_with_defaults`. The resolved
        config is handed to `StepPooler.from_config` when ``pooling_type``
        is `PoolingType.STEP`, and to `SimplePooler.from_config` otherwise.
        """
        settings = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
            step_tag_id=step_tag_id,
            returned_token_ids=returned_token_ids,
        )

        # The dispatch looks at the `pooling_type` argument itself.
        if pooling_type != PoolingType.STEP:
            return SimplePooler.from_config(settings)

        return StepPooler.from_config(settings)

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """
        Return the pooling parameters to use for ``task``,
        or `None` when the task is not supported.

        This base implementation always returns `None`.
        """
        return None

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[list[torch.Tensor], torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states``; concrete poolers must implement this."""
        raise NotImplementedError
|
||||
|
||||
|
||||
def get_prompt_lens(
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput:
|
||||
return PoolerOutput(outputs=all_outputs)
|
||||
|
||||
|
||||
class BasePooler(nn.Module, ABC):
    """Common base class for poolers that exposes an abstract ``forward``.

    `ABC` is mixed in so that the `@abstractmethod` marker on ``forward`` is
    actually enforced: without an `ABCMeta`-derived metaclass the decorator
    has no effect and an incomplete subclass could still be instantiated.
    """

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states``; concrete subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class PoolingMethod(nn.Module, ABC):
|
||||
|
||||
@staticmethod
|
||||
@ -130,6 +164,10 @@ class PoolingMethod(nn.Module, ABC):
|
||||
|
||||
raise NotImplementedError(f"Unsupported method: {pooling_type}")
|
||||
|
||||
@abstractmethod
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return the pooling parameters for ``task``.

    Abstract: this base implementation only raises `NotImplementedError`.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward_one(
|
||||
self,
|
||||
@ -168,6 +206,14 @@ class PoolingMethod(nn.Module, ABC):
|
||||
|
||||
class CLSPool(PoolingMethod):
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return default `PoolingParams` for each of the four pooling tasks."""
    # Every literal is compared on its own (no `in` test) to keep mypy happy.
    if task == "encode" or task == "embed":
        return PoolingParams()
    if task == "classify" or task == "score":
        return PoolingParams()

    # Any other value is handed to `assert_never`.
    assert_never(task)
|
||||
|
||||
def forward_one(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -190,6 +236,14 @@ class CLSPool(PoolingMethod):
|
||||
|
||||
class LastPool(PoolingMethod):
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return default `PoolingParams`; all four tasks are accepted."""
    # Written as separate equality checks so that mypy stays happy.
    if task == "score" or task == "classify":
        return PoolingParams()
    if task == "embed" or task == "encode":
        return PoolingParams()

    # Reached only for a value outside the four handled literals.
    assert_never(task)
|
||||
|
||||
def forward_one(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -208,6 +262,16 @@ class LastPool(PoolingMethod):
|
||||
|
||||
class AllPool(PoolingMethod):
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return default `PoolingParams` for "encode" and `None` otherwise."""
    # The equalities are split up to keep mypy happy
    if task == "embed" or task == "classify" or task == "score":
        # These three tasks yield no parameters from this pooling method.
        return None

    if task == "encode":
        return PoolingParams()

    assert_never(task)
|
||||
|
||||
def forward_one(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -235,6 +299,14 @@ class AllPool(PoolingMethod):
|
||||
|
||||
class MeanPool(PoolingMethod):
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return default `PoolingParams` for every one of the four tasks."""
    # One equality per literal; an `in` test is avoided to keep mypy happy.
    if task == "encode":
        return PoolingParams()
    if task == "embed":
        return PoolingParams()
    if task == "classify" or task == "score":
        return PoolingParams()

    assert_never(task)
|
||||
|
||||
def forward_one(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -345,25 +417,6 @@ class LambdaPoolerActivation(PoolerActivation):
|
||||
|
||||
class PoolerHead(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_config_with_defaults(
|
||||
cls,
|
||||
pooler_config: PoolerConfig,
|
||||
pooling_type: PoolingType,
|
||||
normalize: bool,
|
||||
softmax: bool,
|
||||
) -> "PoolerHead":
|
||||
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
|
||||
pooler_config=pooler_config,
|
||||
pooling_type=pooling_type,
|
||||
normalize=normalize,
|
||||
softmax=softmax,
|
||||
step_tag_id=None,
|
||||
returned_token_ids=None,
|
||||
)
|
||||
|
||||
return cls.from_config(resolved_config)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
|
||||
if pooler_config.normalize and pooler_config.softmax:
|
||||
@ -424,21 +477,17 @@ class PoolerHead(nn.Module):
|
||||
return self.activation(pooled_data)
|
||||
|
||||
|
||||
class SimplePooler(BasePooler):
|
||||
class SimplePooler(Pooler):
|
||||
"""A layer that pools specific information from hidden states.
|
||||
|
||||
This layer does the following:
|
||||
1. Extracts specific tokens or aggregates data based on pooling method.
|
||||
2. Normalizes output if specified.
|
||||
3. Returns structured results as `PoolerOutput`.
|
||||
|
||||
Attributes:
|
||||
pooling_type: The type of pooling to use.
|
||||
normalize: Whether to normalize the pooled data.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config_with_defaults(
|
||||
def from_config_with_defaults( # type: ignore[override]
|
||||
cls,
|
||||
pooler_config: PoolerConfig,
|
||||
pooling_type: PoolingType,
|
||||
@ -471,6 +520,9 @@ class SimplePooler(BasePooler):
|
||||
self.pooling = pooling
|
||||
self.head = head
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Forward the request to ``self.pooling`` and return its answer."""
    pooling_method = self.pooling
    return pooling_method.get_pooling_params(task)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
@ -481,7 +533,7 @@ class SimplePooler(BasePooler):
|
||||
return build_output(pooled_data)
|
||||
|
||||
|
||||
class StepPooler(BasePooler):
|
||||
class StepPooler(Pooler):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
|
||||
@ -543,6 +595,16 @@ class StepPooler(BasePooler):
|
||||
|
||||
return pooled_data
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return parameters for "encode" only; other tasks yield `None`.

    The "encode" parameters are created with
    ``logits_processing_needs_token_ids=True``.
    """
    # The equalities are split up to keep mypy happy
    if task == "embed" or task == "classify" or task == "score":
        return None

    if task == "encode":
        return PoolingParams(logits_processing_needs_token_ids=True)

    assert_never(task)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
@ -553,32 +615,6 @@ class StepPooler(BasePooler):
|
||||
return build_output(pooled_data)
|
||||
|
||||
|
||||
class Pooler(nn.Module):
    """Factory holder that builds the pooler matching a pooling type."""

    @staticmethod
    def from_config_with_defaults(
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> BasePooler:
        """Resolve the pooling config and build the corresponding pooler.

        Returns the result of `StepPooler.from_config` when ``pooling_type``
        is `PoolingType.STEP`, else that of `SimplePooler.from_config`.
        """
        settings = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
            step_tag_id=step_tag_id,
            returned_token_ids=returned_token_ids,
        )

        # Dispatch on the `pooling_type` argument passed by the caller.
        if pooling_type != PoolingType.STEP:
            return SimplePooler.from_config(settings)

        return StepPooler.from_config(settings)
|
||||
|
||||
|
||||
PoolingFn = Callable[
|
||||
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
|
||||
Union[torch.Tensor, list[torch.Tensor]]]
|
||||
@ -618,6 +654,18 @@ class ClassifierPooler(nn.Module):
|
||||
return (self.cross_encoder_act_fn
|
||||
if use_cross_encoder else self.classification_act_fn)
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Map a pooling task to its parameters.

    "encode" and "classify" get default `PoolingParams`, "score" gets
    ``PoolingParams(use_cross_encoder=True)`` and "embed" yields `None`.
    """
    if task == "embed":
        return None
    if task == "score":
        return PoolingParams(use_cross_encoder=True)
    if task == "encode" or task == "classify":
        return PoolingParams()

    assert_never(task)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -42,13 +42,14 @@ def _create_pooling_model_cls(
|
||||
default_softmax: bool,
|
||||
) -> _T:
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
|
||||
class ModelForPooling(orig_cls, VllmModelForPooling):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -66,27 +67,20 @@ def _create_pooling_model_cls(
|
||||
delattr(self, attr)
|
||||
|
||||
# If the model already defines a pooler instance, don't overwrite it
|
||||
if not getattr(self, "_pooler", None):
|
||||
if not getattr(self, "pooler", None):
|
||||
self._init_pooler(vllm_config, prefix=prefix)
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=default_pooling_type,
|
||||
normalize=default_normalize,
|
||||
softmax=default_softmax,
|
||||
)
|
||||
|
||||
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Run ``self._pooler`` on the hidden states and return its output."""
    pooler_module = self._pooler
    return pooler_module(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
# TODO: Support uninitialized params tracking
|
||||
|
||||
@ -171,10 +165,8 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||
PoolerOutput, PoolingType,
|
||||
SimplePooler)
|
||||
PoolingType, SimplePooler)
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import maybe_prefix
|
||||
@ -213,7 +205,7 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
softmax=True,
|
||||
)
|
||||
|
||||
self._pooler = ClassifierPooler(
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=pooler.pooling,
|
||||
classifier=self._classifier,
|
||||
@ -234,13 +226,6 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
return super().forward(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
def pooler(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Hand both arguments to ``self._pooler`` and return what it yields."""
    result = self._pooler(hidden_states, pooling_metadata)
    return result
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
tokens = getattr(self.config, "classifier_from_token", None)
|
||||
method = getattr(self.config, "method", None)
|
||||
|
||||
@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
||||
PoolingMethod, PoolingType)
|
||||
PoolingMethod, PoolingTask,
|
||||
PoolingType)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
class BertPooler(Pooler):
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
@ -89,6 +91,9 @@ class BertPooler(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return whatever ``self.pooling`` reports for ``task``."""
    delegate = self.pooling
    return delegate.get_pooling_params(task)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
@ -319,6 +324,9 @@ class BertOutput(nn.Module):
|
||||
|
||||
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
||||
|
||||
def __init__(self,
|
||||
@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.model = self._build_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = self._build_pooler(pooler_config)
|
||||
self.pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
|
||||
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Apply the pooler stored in ``self._pooler`` to ``hidden_states``."""
    apply_pooler = self._pooler
    return apply_pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights_list = list(weights)
|
||||
|
||||
@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
embedding_class=BertEmbedding,
|
||||
add_pooling_layer=True)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self._pooler = ClassifierPooler(
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
loaded_params = loader.load_weights(weights)
|
||||
return loaded_params
|
||||
|
||||
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Delegate pooling of ``hidden_states`` to ``self._pooler``."""
    outputs = self._pooler(hidden_states, pooling_metadata)
    return outputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
|
||||
@ -40,9 +40,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from ..layers.pooler import Pooler, PoolingType
|
||||
from .interfaces import SupportsPP
|
||||
@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module):
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -339,7 +340,7 @@ class GPT2ForSequenceClassification(nn.Module):
|
||||
prefix=maybe_prefix(prefix, "gpt2"))
|
||||
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=False,
|
||||
@ -349,13 +350,6 @@ class GPT2ForSequenceClassification(nn.Module):
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Return the output of ``self._pooler`` for the given inputs."""
    wrapped_pooler = self._pooler
    return wrapped_pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from array import array
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -195,6 +194,8 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
|
||||
- "<|user|>\nPROMPT\n<|assistant|>\n"
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@ -214,11 +215,4 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
|
||||
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
self._pooler = GritLMPooler(vllm_config.model_config)
|
||||
|
||||
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Call ``self._pooler`` with both arguments and return the result."""
    pooled = self._pooler(hidden_states, pooling_metadata)
    return pooled
|
||||
self.pooler = GritLMPooler(vllm_config.model_config)
|
||||
|
||||
@ -119,13 +119,6 @@ class SupportsMultiModal(Protocol):
|
||||
...
|
||||
|
||||
|
||||
# We can't use runtime_checkable with ClassVar for issubclass checks
|
||||
# so we need to treat the class as an instance and use isinstance instead
|
||||
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Runtime-checkable protocol declaring ``supports_multimodal``."""

    supports_multimodal: Literal[True]
|
||||
|
||||
|
||||
@overload
|
||||
def supports_multimodal(
|
||||
model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
|
||||
@ -140,10 +133,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
|
||||
def supports_multimodal(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _SupportsMultiModalType)
|
||||
|
||||
return isinstance(model, SupportsMultiModal)
|
||||
return getattr(model, "supports_multimodal", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -174,13 +164,6 @@ class SupportsScoreTemplate(Protocol):
|
||||
...
|
||||
|
||||
|
||||
# We can't use runtime_checkable with ClassVar for issubclass checks
|
||||
# so we need to treat the class as an instance and use isinstance instead
|
||||
@runtime_checkable
class _SupportsScoreTemplateType(Protocol):
    """Runtime-checkable protocol declaring ``supports_score_template``."""

    supports_score_template: Literal[True]
|
||||
|
||||
|
||||
@overload
|
||||
def supports_score_template(
|
||||
model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]:
|
||||
@ -195,11 +178,7 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]:
|
||||
def supports_score_template(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]:
|
||||
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _SupportsScoreTemplateType)
|
||||
|
||||
return isinstance(model, SupportsScoreTemplate)
|
||||
return getattr(model, "supports_score_template", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -409,11 +388,6 @@ class HasInnerState(Protocol):
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
class _HasInnerStateType(Protocol):
    """Runtime-checkable protocol declaring a ``has_inner_state`` flag."""

    has_inner_state: ClassVar[Literal[True]]
|
||||
|
||||
|
||||
@overload
|
||||
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
|
||||
...
|
||||
@ -427,10 +401,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
|
||||
def has_inner_state(
|
||||
model: Union[type[object], object]
|
||||
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _HasInnerStateType)
|
||||
|
||||
return isinstance(model, HasInnerState)
|
||||
return getattr(model, "has_inner_state", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -446,11 +417,6 @@ class IsAttentionFree(Protocol):
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
class _IsAttentionFreeType(Protocol):
    """Runtime-checkable protocol declaring an ``is_attention_free`` flag."""

    is_attention_free: ClassVar[Literal[True]]
|
||||
|
||||
|
||||
@overload
|
||||
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
|
||||
...
|
||||
@ -464,10 +430,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
|
||||
def is_attention_free(
|
||||
model: Union[type[object], object]
|
||||
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _IsAttentionFreeType)
|
||||
|
||||
return isinstance(model, IsAttentionFree)
|
||||
return getattr(model, "is_attention_free", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -502,11 +465,6 @@ class IsHybrid(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
class _IsHybridType(Protocol):
    """Runtime-checkable protocol declaring an ``is_hybrid`` flag."""

    is_hybrid: ClassVar[Literal[True]]
|
||||
|
||||
|
||||
@overload
|
||||
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
|
||||
...
|
||||
@ -520,10 +478,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
|
||||
def is_hybrid(
|
||||
model: Union[type[object], object]
|
||||
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _IsHybridType)
|
||||
|
||||
return isinstance(model, IsHybrid)
|
||||
return getattr(model, "is_hybrid", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -598,11 +553,6 @@ class HasNoOps(Protocol):
|
||||
has_noops: ClassVar[Literal[True]] = True
|
||||
|
||||
|
||||
@runtime_checkable
class _HasNoOpsType(Protocol):
    """Runtime-checkable protocol declaring a ``has_noops`` flag."""

    has_noops: ClassVar[Literal[True]]
|
||||
|
||||
|
||||
@overload
|
||||
def has_noops(model: object) -> TypeIs[HasNoOps]:
|
||||
...
|
||||
@ -616,10 +566,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
|
||||
def has_noops(
|
||||
model: Union[type[object], object]
|
||||
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _HasNoOpsType)
|
||||
|
||||
return isinstance(model, HasNoOps)
|
||||
return getattr(model, "has_noops", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -643,11 +590,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
|
||||
def _supports_cross_encoding(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
|
||||
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, SupportsCrossEncoding)
|
||||
|
||||
return isinstance(model, SupportsCrossEncoding)
|
||||
return getattr(model, "supports_cross_encoding", False)
|
||||
|
||||
|
||||
def supports_cross_encoding(
|
||||
@ -658,8 +601,9 @@ def supports_cross_encoding(
|
||||
|
||||
def has_step_pooler(model: Union[type[object], object]) -> bool:
|
||||
"""Check if the model uses step pooler."""
|
||||
return is_pooling_model(model) and any(
|
||||
type(module).__name__ == "StepPooler" for module in model.modules())
|
||||
from vllm.model_executor.layers.pooler import StepPooler
|
||||
|
||||
return is_pooling_model(model) and isinstance(model.pooler, StepPooler)
|
||||
|
||||
|
||||
class SupportsQuant:
|
||||
@ -770,10 +714,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
|
||||
def supports_transcription(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, SupportsTranscription)
|
||||
|
||||
return isinstance(model, SupportsTranscription)
|
||||
return getattr(model, "supports_transcription", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -796,7 +737,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
|
||||
def supports_v0_only(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, SupportsV0Only)
|
||||
|
||||
return isinstance(model, SupportsV0Only)
|
||||
return getattr(model, "supports_v0_only", False)
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
|
||||
runtime_checkable)
|
||||
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||
Union, overload, runtime_checkable)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -13,8 +12,7 @@ from vllm.utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import PoolerOutput
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -130,16 +128,20 @@ def is_text_generation_model(
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class VllmModelForPooling(VllmModel[T], Protocol[T]):
|
||||
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
||||
"""The interface required for all pooling models in vLLM."""
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: T,
|
||||
pooling_metadata: "PoolingMetadata",
|
||||
) -> "PoolerOutput":
|
||||
"""Only called on TP rank 0."""
|
||||
...
|
||||
is_pooling_model: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model supports pooling.
|
||||
|
||||
Note:
|
||||
There is no need to redefine this flag if this class is in the
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
pooler: "Pooler"
|
||||
"""The pooler is only called on TP rank 0."""
|
||||
|
||||
|
||||
@overload
|
||||
@ -158,7 +160,4 @@ def is_pooling_model(
|
||||
if not is_vllm_model(model):
|
||||
return False
|
||||
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, VllmModelForPooling)
|
||||
|
||||
return isinstance(model, VllmModelForPooling)
|
||||
return getattr(model, "is_pooling_model", False)
|
||||
|
||||
@ -28,9 +28,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
@ -404,6 +403,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
|
||||
class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -428,7 +429,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.ALL,
|
||||
normalize=False,
|
||||
@ -446,10 +447,3 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
inputs_embeds)
|
||||
logits, _ = self.v_head(hidden_states)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
@ -27,9 +27,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||
@ -563,6 +562,8 @@ def _is_moe_layer(name: str):
|
||||
|
||||
class JambaForSequenceClassification(JambaForCausalLM):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
@ -590,16 +591,9 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
||||
softmax=False,
|
||||
)
|
||||
|
||||
self._pooler = ClassifierPooler(
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=pooler.pooling,
|
||||
classifier=self.score,
|
||||
act_fn=pooler.head.activation,
|
||||
)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
@ -13,9 +13,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import (SupportsCrossEncoding, SupportsMultiModal,
|
||||
SupportsScoreTemplate)
|
||||
@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
||||
SupportsCrossEncoding,
|
||||
SupportsMultiModal,
|
||||
SupportsScoreTemplate):
|
||||
|
||||
is_pooling_model = True
|
||||
weight_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"score.0.": "score.dense.",
|
||||
@ -95,7 +96,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
||||
|
||||
self.score = JinaVLScorer(config)
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=False,
|
||||
@ -137,14 +138,6 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
||||
logits = self.score(hidden_states) - self.LOGIT_BIAS
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.weight_mapper)
|
||||
|
||||
@ -13,14 +13,16 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
|
||||
PoolingMethod, PoolingType)
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
||||
PoolingMethod, PoolingTask,
|
||||
PoolingType)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||
from .utils import WeightsMapper, maybe_prefix
|
||||
@ -253,7 +255,7 @@ class ModernBertModel(nn.Module):
|
||||
return norm_outputs
|
||||
|
||||
|
||||
class ModernBertPooler(BasePooler):
|
||||
class ModernBertPooler(Pooler):
|
||||
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__()
|
||||
@ -268,6 +270,9 @@ class ModernBertPooler(BasePooler):
|
||||
eps=config.norm_eps,
|
||||
bias=config.norm_bias)
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
||||
return self.pooling.get_pooling_params(task)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
@ -281,6 +286,8 @@ class ModernBertPooler(BasePooler):
|
||||
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
SupportsCrossEncoding):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -288,7 +295,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
self.model = ModernBertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "modernbert"))
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self._pooler = ClassifierPooler(
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=ModernBertPooler(config),
|
||||
classifier=self.classifier,
|
||||
@ -321,13 +328,6 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor],
|
||||
|
||||
@ -24,12 +24,13 @@ import torch.nn as nn
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
|
||||
PoolerIdentity, SimplePooler)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
SupportsMultiModal,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
@ -37,8 +38,7 @@ from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
PoolingSequenceGroupOutput)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
|
||||
@ -116,7 +116,9 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
""" Prithvi Masked Autoencoder"""
|
||||
"""Prithvi Masked Autoencoder"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
@ -162,6 +164,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
"Only SemanticSegmentationTask is supported for now "
|
||||
"by PrithviGeospatialMAE.")
|
||||
|
||||
self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity()))
|
||||
|
||||
def _parse_and_validate_multimodal_data(
|
||||
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
@ -189,7 +193,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
|
||||
pixel_values, location_coords = (
|
||||
self._parse_and_validate_multimodal_data(**kwargs))
|
||||
model_output = self.model(pixel_values,
|
||||
@ -197,13 +200,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
|
||||
return model_output.output
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_list = []
|
||||
|
||||
@ -16,8 +16,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2Model
|
||||
@ -25,6 +24,10 @@ from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
is_pooling_model = True
|
||||
pooler: SimplePooler
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -61,7 +64,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
quant_config=quant_config,
|
||||
return_bias=False),
|
||||
)
|
||||
self._pooler: SimplePooler
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -80,13 +82,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
logits = self.score(hidden_states)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self,
|
||||
@ -96,11 +91,11 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
|
||||
def __init__(self, *, vllm_config, prefix=""):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
vllm_config.model_config.hf_config.num_labels = 1
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.ALL,
|
||||
normalize=False,
|
||||
@ -109,11 +104,11 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
|
||||
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||
|
||||
def __init__(self, *, vllm_config, prefix=""):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
vllm_config.model_config.hf_config.num_labels = 2
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.STEP,
|
||||
normalize=False,
|
||||
|
||||
@ -15,8 +15,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||
@ -165,6 +164,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
jina_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
'emb_ln': "embeddings.LayerNorm",
|
||||
@ -188,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
add_pooling_layer=False)
|
||||
self.classifier = RobertaClassificationHead(config)
|
||||
|
||||
self._pooler = ClassifierPooler(
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
@ -198,13 +198,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import msgspec
|
||||
|
||||
@ -15,24 +15,31 @@ class PoolingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""API parameters for pooling models. This is currently a placeholder.
|
||||
"""API parameters for pooling models. This
|
||||
|
||||
Attributes:
|
||||
dimensions: Reduce the dimensions of embeddings
|
||||
if model support matryoshka representation.
|
||||
additional_data: Any additional data needed for pooling.
|
||||
"""
|
||||
|
||||
dimensions: Optional[int] = None
|
||||
|
||||
use_cross_encoder: bool = False
|
||||
additional_data: Optional[Any] = None
|
||||
"""Internal use only."""
|
||||
|
||||
logits_processing_needs_token_ids: bool = False
|
||||
"""Internal use only."""
|
||||
|
||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
def clone(self) -> "PoolingParams":
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
return PoolingParams(dimensions=self.dimensions,
|
||||
use_cross_encoder=self.use_cross_encoder,
|
||||
additional_data=self.additional_data)
|
||||
return PoolingParams(
|
||||
dimensions=self.dimensions,
|
||||
use_cross_encoder=self.use_cross_encoder,
|
||||
logits_processing_needs_token_ids=self.
|
||||
logits_processing_needs_token_ids,
|
||||
)
|
||||
|
||||
def verify(self, model_config: "ModelConfig") -> None:
|
||||
if self.dimensions is not None:
|
||||
@ -54,10 +61,12 @@ class PoolingParams(
|
||||
raise ValueError("Dimensions must be greater than 0")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"PoolingParams("
|
||||
f"dimensions={self.dimensions}, "
|
||||
f"use_cross_encoder={self.use_cross_encoder}, "
|
||||
f"additional_metadata={self.additional_data})")
|
||||
return (
|
||||
f"PoolingParams("
|
||||
f"dimensions={self.dimensions}, "
|
||||
f"use_cross_encoder={self.use_cross_encoder}, "
|
||||
f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})"
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user