[Model] Automatic conversion of classification and reward models (#11469)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 409475a827
commit 3f3e92e1f2
@@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate") # Name or path of your model
 output = llm.generate("Hello, my name is")
 print(output)
 
-# For pooling models (task={embed,classify,reward}) only
+# For pooling models (task={embed,classify,reward,score}) only
 llm = LLM(model=..., task="embed") # Name or path of your model
 output = llm.encode("Hello, my name is")
 print(output)
@@ -59,7 +59,7 @@ llm = LLM(model=..., revision=..., task=..., trust_remote_code=True)
 output = llm.generate("Hello, my name is")
 print(output)
 
-# For pooling models (task={embed,classify,reward}) only
+# For pooling models (task={embed,classify,reward,score}) only
 output = llm.encode("Hello, my name is")
 print(output)
 ```
@@ -369,14 +369,6 @@ you should explicitly specify the task type to ensure that the model is used in
 
 #### Text Embedding (`--task embed`)
 
-Any text generation model can be converted into an embedding model by passing {code}`--task embed`.
-
-```{note}
-To get the best results, you should use pooling models that are specifically trained as such.
-```
-
-The following table lists those that are tested in vLLM.
-
 ```{eval-rst}
 .. list-table::
   :widths: 25 25 50 5 5
@@ -437,6 +429,10 @@ On the other hand, its 1.5B variant ({code}`Alibaba-NLP/gte-Qwen2-1.5B-instruct`
 despite being described otherwise on its model card.
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
+of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
+
 #### Reward Modeling (`--task reward`)
 
 ```{eval-rst}
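With this change, a checkpoint that only ships a text-generation architecture can be served for embeddings and is converted on the fly. A minimal offline sketch, mirroring the quickstart snippet earlier in this diff (the model name is only an illustrative choice, not taken from the supported-models table):

```python
from vllm import LLM

# Any text-generation checkpoint; if its architecture is not a native pooling
# model, vLLM wraps it via as_embedding_model() (last-token pooling, normalized).
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", task="embed")

output = llm.encode("Hello, my name is")
print(output)
```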
@@ -461,6 +457,9 @@ despite being described otherwise on its model card.
   - ✅︎
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_reward_model`. By default, we return the hidden states of each token directly.
+
 ```{important}
 For process-supervised reward models such as {code}`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
 e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
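For process-supervised reward models, the same override shown in the CLI flag above can be applied offline. A hedged sketch, assuming the `override_pooler_config` engine argument accepts a `PoolerConfig` with these fields (the ids are the docs' placeholders, not real token ids):

```python
from vllm import LLM
from vllm.config import PoolerConfig

llm = LLM(
    model="peiyi9979/math-shepherd-mistral-7b-prm",
    task="reward",
    # Mirrors --override-pooler-config '{"pooling_type": "STEP", ...}' above.
    override_pooler_config=PoolerConfig(
        pooling_type="STEP",
        step_tag_id=123,                # placeholder value from the docs example
        returned_token_ids=[456, 789],  # placeholder values from the docs example
    ),
)
```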
@@ -490,6 +489,9 @@ e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 1
   - ✅︎
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
+
 #### Sentence Pair Scoring (`--task score`)
 
 ```{eval-rst}
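Classification checkpoints whose backbone is a supported `*ForCausalLM` are handled the same way: the backbone is wrapped by `as_classification_model`, which attaches a `score` head and softmaxes the last-token output. A small sketch (the model name is an assumed example):

```python
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")

# For pooling models, encode() returns the pooled output; here that is the
# vector of class probabilities (last-token pooling + softmax).
output = llm.encode("Hello, my name is")
print(output)
```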
@@ -1,7 +1,4 @@
-"""Compare the outputs of HF and vLLM when using greedy sampling.
-
-This test only tests small models. Big models such as 7B should be tested from
-test_big_models.py because it could use a larger instance to run tests.
+"""Compare the classification outputs of HF and vLLM models.
 
 Run `pytest tests/models/test_cls_models.py`.
 """
@@ -1,6 +1,6 @@
-"""Compare the embedding outputs of HF and vLLM models.
+"""Compare the scoring outputs of HF and vLLM models.
 
-Run `pytest tests/models/embedding/language/test_embedding.py`.
+Run `pytest tests/models/embedding/language/test_scoring.py`.
 """
 import math
@@ -6,7 +6,9 @@ import torch.cuda
 from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-from vllm.model_executor.models.adapters import as_embedding_model
+from vllm.model_executor.models.adapters import (as_classification_model,
+                                                 as_embedding_model,
+                                                 as_reward_model)
 from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -29,9 +31,10 @@ def test_registry_imports(model_arch):
             or model_arch in _MULTIMODAL_MODELS):
         assert is_text_generation_model(model_cls)
 
-    # All vLLM models should be convertible to an embedding model
-    embed_model = as_embedding_model(model_cls)
-    assert is_pooling_model(embed_model)
+    # All vLLM models should be convertible to a pooling model
+    assert is_pooling_model(as_classification_model(model_cls))
+    assert is_pooling_model(as_embedding_model(model_cls))
+    assert is_pooling_model(as_reward_model(model_cls))
 
     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)
@@ -7,7 +7,9 @@ from torch import nn
 
 from vllm.config import ModelConfig
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.adapters import as_embedding_model
+from vllm.model_executor.models.adapters import (as_classification_model,
+                                                 as_embedding_model,
+                                                 as_reward_model)
 
 
 @contextlib.contextmanager
@@ -35,8 +37,12 @@ def get_model_architecture(
         architectures = ["QuantMixtralForCausalLM"]
 
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
     if model_config.runner_type == "pooling":
         if model_config.task == "embed":
             model_cls = as_embedding_model(model_cls)
+        elif model_config.task == "classify":
+            model_cls = as_classification_model(model_cls)
+        elif model_config.task == "reward":
+            model_cls = as_reward_model(model_cls)
 
     return model_cls, arch
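The loader now picks the adapter by task, and the registry test earlier in this diff asserts that every registered model class can be wrapped by all three adapters. The same conversion can be reproduced by hand, which also shows the renaming the adapters perform (a sketch using Qwen2 as an example backbone):

```python
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.models.adapters import (as_classification_model,
                                                 as_embedding_model,
                                                 as_reward_model)
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

for convert in (as_embedding_model, as_classification_model, as_reward_model):
    converted = convert(Qwen2ForCausalLM)
    assert is_pooling_model(converted)
    # Prints Qwen2ForEmbedding, Qwen2ForClassification, Qwen2ForReward
    print(converted.__name__)
```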
@@ -1,29 +1,48 @@
 from collections.abc import Iterable
-from typing import Any, TypeVar
+from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
 import torch
 import torch.nn as nn
 
 from .interfaces_base import VllmModelForPooling, is_pooling_model
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.pooler import PoolingType
+
 _T = TypeVar("_T", bound=type[nn.Module])
 
+_GENERATE_SUFFIXES = [
+    "ForCausalLM",
+    "ForConditionalGeneration",
+    "ChatModel",
+    "LMHeadModel",
+]
+
 
-def as_embedding_model(cls: _T) -> _T:
-    """Subclass an existing vLLM model to support embeddings."""
-    # Avoid modifying existing embedding models
-    if is_pooling_model(cls):
-        return cls
-
+def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
+    model_name = orig_model_name
+
+    for generate_suffix in _GENERATE_SUFFIXES:
+        model_name = model_name.removesuffix(generate_suffix)
+
+    return model_name + pooling_suffix
+
+
+def _create_pooling_model_cls(
+    orig_cls: _T,
+    *,
+    default_pooling_type: "PoolingType",
+    default_normalize: bool,
+    default_softmax: bool,
+) -> _T:
     # Lazy import
     from vllm.config import VllmConfig
-    from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
-                                                   PoolingType)
+    from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
     from vllm.model_executor.pooling_metadata import PoolingMetadata
 
     from .utils import AutoWeightsLoader, WeightsMapper
 
-    class ModelForEmbedding(cls, VllmModelForPooling):
+    class ModelForPooling(orig_cls, VllmModelForPooling):
 
         def __init__(
             self,
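The naming helper introduced above strips generation-style suffixes before appending the pooling suffix. A standalone illustration (the helper is copied here so the snippet runs outside vLLM; `str.removesuffix` requires Python 3.9+):

```python
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]

def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    model_name = orig_model_name
    for generate_suffix in _GENERATE_SUFFIXES:
        model_name = model_name.removesuffix(generate_suffix)
    return model_name + pooling_suffix

print(_get_pooling_model_name("LlamaForCausalLM", "ForEmbedding"))       # LlamaForEmbedding
print(_get_pooling_model_name("Qwen2ForCausalLM", "ForClassification"))  # Qwen2ForClassification
```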
@@ -34,7 +53,7 @@ def as_embedding_model(cls: _T) -> _T:
         ) -> None:
             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-            # These are not used in embedding models
+            # These are not used in pooling models
             for attr in ("lm_head", "logits_processor"):
                 if hasattr(self, attr):
                     delattr(self, attr)
@@ -46,9 +65,9 @@ def as_embedding_model(cls: _T) -> _T:
             if not getattr(self, "_pooler", None):
                 self._pooler = Pooler.from_config_with_defaults(
                     pooler_config,
-                    pooling_type=PoolingType.LAST,
-                    normalize=True,
-                    softmax=False,
+                    pooling_type=default_pooling_type,
+                    normalize=default_normalize,
+                    softmax=default_softmax,
                 )
 
         def pooler(
@@ -82,17 +101,148 @@
                     return
 
             # For most other models
-            if hasattr(cls, "load_weights"):
-                cls.load_weights(self, weights)  # type: ignore
+            if hasattr(orig_cls, "load_weights"):
+                orig_cls.load_weights(self, weights)  # type: ignore
             # Fallback
             else:
                 loader = AutoWeightsLoader(self)
                 loader.load_weights(weights)
 
-    ModelForEmbedding.__name__ = cls.__name__ \
-        .removesuffix("ForCausalLM") \
-        .removesuffix("ForConditionalGeneration") \
-        .removesuffix("ChatModel") \
-        .removesuffix("LMHeadModel") + "ForEmbedding"
+    return ModelForPooling  # type: ignore
+
+
+def as_embedding_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support embeddings.
+
+    By default, the embeddings of the whole prompt are extracted from the
+    normalized hidden state corresponding to the last token.
+
+    Note:
+        We assume that no extra layers are added to the original model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing embedding models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.model_executor.layers.pooler import PoolingType
+
+    ModelForEmbedding = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.LAST,
+        default_normalize=True,
+        default_softmax=False,
+    )
+    ModelForEmbedding.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForEmbedding")
 
     return ModelForEmbedding  # type: ignore
+
+
+def as_classification_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support classification.
+
+    By default, the class probabilities are extracted from the softmaxed
+    hidden state corresponding to the last token.
+
+    Note:
+        We assume that the classification head is a single linear layer
+        stored as the attribute `score` of the top-level model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing classification models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.attention import AttentionMetadata
+    from vllm.config import VllmConfig
+    from vllm.model_executor.layers.linear import RowParallelLinear
+    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.sequence import IntermediateTensors
+
+    from .utils import maybe_prefix
+
+    ModelForPooling = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.LAST,
+        default_normalize=False,
+        default_softmax=True,
+    )
+
+    class ModelForClassification(ModelForPooling):
+
+        def __init__(
+            self,
+            *,
+            vllm_config: "VllmConfig",
+            prefix: str = "",
+            **kwargs: Any,
+        ) -> None:
+            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
+
+            config = vllm_config.model_config.hf_config
+            quant_config = vllm_config.quant_config
+
+            self.score = RowParallelLinear(config.hidden_size,
+                                           config.num_labels,
+                                           quant_config=quant_config,
+                                           input_is_parallel=False,
+                                           bias=False,
+                                           prefix=maybe_prefix(
+                                               prefix, "score"))
+
+        def forward(
+            self,
+            input_ids: torch.Tensor,
+            positions: torch.Tensor,
+            kv_caches: list[torch.Tensor],
+            attn_metadata: AttentionMetadata,
+            intermediate_tensors: Optional[IntermediateTensors] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
+        ) -> torch.Tensor:
+            hidden_states = super().forward(input_ids, positions, kv_caches,
+                                            attn_metadata,
+                                            intermediate_tensors,
+                                            inputs_embeds)
+            logits, _ = self.score(hidden_states)
+            return logits
+
+
+    ModelForClassification.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForClassification")
+
+    return ModelForClassification  # type: ignore
+
+
+def as_reward_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support reward modeling.
+
+    By default, we return the hidden states of each token directly.
+
+    Note:
+        We assume that no extra layers are added to the original model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing reward models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.model_executor.layers.pooler import PoolingType
+
+    ModelForReward = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.ALL,
+        default_normalize=False,
+        default_softmax=False,
+    )
+
+    ModelForReward.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForReward")
+
+    return ModelForReward  # type: ignore
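The three public converters differ only in the pooler defaults they pass to `_create_pooling_model_cls`; the values below are copied from the code above, and they remain overridable through the model's pooler config (e.g. `--override-pooler-config`). A small sketch summarizing them:

```python
from vllm.model_executor.layers.pooler import PoolingType

# (pooling_type, normalize, softmax) defaults wired in by each converter
CONVERTER_DEFAULTS = {
    "as_embedding_model": (PoolingType.LAST, True, False),
    "as_classification_model": (PoolingType.LAST, False, True),
    "as_reward_model": (PoolingType.ALL, False, False),
}

for name, (pooling_type, normalize, softmax) in CONVERTER_DEFAULTS.items():
    print(f"{name}: {pooling_type.name} normalize={normalize} softmax={softmax}")
```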
@@ -545,8 +545,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
 
-        # TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
-        # after changing the default pooling method
+        # TODO: Replace this model class with as_embedding_model(
+        # Qwen2ForCausalLM) after changing the default pooling method
         if pooler_config.pooling_type is None:
             logger.warning(
                 "This embedding model will default to last-token pooling in "
@@ -1,104 +0,0 @@
-# Adapted from
-# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
-# Copyright 2024 Kakao Corp. (Kanana-X Team)
-# Copyright 2024 The Qwen team.
-# Copyright 2023 The vLLM team.
-"""Inference-only Qwen2-Classification model compatible with HF weights."""
-from typing import Iterable, List, Optional, Set, Tuple
-
-import torch
-from torch import nn
-
-from vllm.attention import AttentionMetadata
-from vllm.config import VllmConfig
-from vllm.model_executor.layers.linear import RowParallelLinear
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.models.qwen2 import Qwen2Model
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
-
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import AutoWeightsLoader, maybe_prefix
-
-
-class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "gate_up_proj",
-        "down_proj",
-    ]
-    embedding_modules = {}
-    embedding_padding_modules = []
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-        pooler_config = vllm_config.model_config.pooler_config
-
-        self.config = config
-        self.lora_config = lora_config
-
-        self.quant_config = quant_config
-        self.model = Qwen2Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-
-        # hidden_states from Qwen2Model has been reduced,
-        # the input of score layer is not parallelized.
-        self.score = RowParallelLinear(config.hidden_size,
-                                       config.num_labels,
-                                       quant_config=quant_config,
-                                       input_is_parallel=False,
-                                       bias=False,
-                                       prefix=maybe_prefix(prefix, "score"))
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=False,
-            softmax=True)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
-                                   inputs_embeds)
-        logits, _ = self.score(hidden_states)
-        return logits
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(self,
-                                   ignore_unexpected_prefixes=["lm_head."])
-        return loader.load_weights(weights)
@@ -20,11 +20,10 @@ import torch.nn as nn
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
-from .adapters import as_embedding_model
 from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
                          supports_pp)
-from .interfaces_base import is_pooling_model, is_text_generation_model
+from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)
@@ -125,12 +124,13 @@ _EMBEDDING_MODELS = {
     "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
-    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
+    # [Auto-converted (see adapters.py)]
+    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
 }
 
 _CROSS_ENCODER_MODELS = {
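With the registry now pointing `Qwen2ForSequenceClassification` at the plain `Qwen2ForCausalLM` class, architecture resolution plus adapter conversion reproduces what the dedicated `qwen2_cls.py` model used to provide. A sketch of that path (`resolve_model_cls` takes a list of architecture names):

```python
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_classification_model

model_cls, arch = ModelRegistry.resolve_model_cls(["Qwen2ForSequenceClassification"])
print(arch, model_cls.__name__)  # Qwen2ForSequenceClassification Qwen2ForCausalLM

model_cls = as_classification_model(model_cls)
print(model_cls.__name__)        # Qwen2ForClassification
```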
@@ -226,19 +226,10 @@ class _ModelInfo:
 
     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
-        is_pooling_model_ = is_pooling_model(model)
-        if not is_pooling_model_:
-            try:
-                as_embedding_model(model)
-            except Exception:
-                pass
-            else:
-                is_pooling_model_ = True
-
         return _ModelInfo(
             architecture=model.__name__,
             is_text_generation_model=is_text_generation_model(model),
-            is_pooling_model=is_pooling_model_,
+            is_pooling_model=True,  # Can convert any model into a pooling model
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
             supports_pp=supports_pp(model),