[Model] Automatic conversion of classification and reward models (#11469)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-25 02:22:22 +08:00 committed by GitHub
parent 409475a827
commit 3f3e92e1f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 206 additions and 161 deletions

View File

@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate") # Name or path of your model
output = llm.generate("Hello, my name is")
print(output)
# For pooling models (task={embed,classify,reward}) only
# For pooling models (task={embed,classify,reward,score}) only
llm = LLM(model=..., task="embed") # Name or path of your model
output = llm.encode("Hello, my name is")
print(output)
@ -59,7 +59,7 @@ llm = LLM(model=..., revision=..., task=..., trust_remote_code=True)
output = llm.generate("Hello, my name is")
print(output)
# For pooling models (task={embed,classify,reward}) only
# For pooling models (task={embed,classify,reward,score}) only
output = llm.encode("Hello, my name is")
print(output)
```
@ -369,14 +369,6 @@ you should explicitly specify the task type to ensure that the model is used in
#### Text Embedding (`--task embed`)
Any text generation model can be converted into an embedding model by passing {code}`--task embed`.
```{note}
To get the best results, you should use pooling models that are specifically trained as such.
```
The following table lists those that are tested in vLLM.
```{eval-rst}
.. list-table::
:widths: 25 25 50 5 5
@ -437,6 +429,10 @@ On the other hand, its 1.5B variant ({code}`Alibaba-NLP/gte-Qwen2-1.5B-instruct`
despite being described otherwise on its model card.
```
If your model is not in the above list, we will try to automatically convert the model using
:func:`vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
#### Reward Modeling (`--task reward`)
```{eval-rst}
@ -461,6 +457,9 @@ despite being described otherwise on its model card.
- ✅︎
```
If your model is not in the above list, we will try to automatically convert the model using
:func:`vllm.model_executor.models.adapters.as_reward_model`. By default, we return the hidden states of each token directly.
```{important}
For process-supervised reward models such as {code}`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
@ -490,6 +489,9 @@ e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 1
- ✅︎
```
If your model is not in the above list, we will try to automatically convert the model using
:func:`vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
#### Sentence Pair Scoring (`--task score`)
```{eval-rst}

View File

@ -1,7 +1,4 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
"""Compare the classification outputs of HF and vLLM models.
Run `pytest tests/models/test_cls_models.py`.
"""

View File

@ -1,6 +1,6 @@
"""Compare the embedding outputs of HF and vLLM models.
"""Compare the scoring outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_embedding.py`.
Run `pytest tests/models/embedding/language/test_scoring.py`.
"""
import math

View File

@ -6,7 +6,9 @@ import torch.cuda
from vllm.model_executor.models import (is_pooling_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
@ -29,9 +31,10 @@ def test_registry_imports(model_arch):
or model_arch in _MULTIMODAL_MODELS):
assert is_text_generation_model(model_cls)
# All vLLM models should be convertible to an embedding model
embed_model = as_embedding_model(model_cls)
assert is_pooling_model(embed_model)
# All vLLM models should be convertible to a pooling model
assert is_pooling_model(as_classification_model(model_cls))
assert is_pooling_model(as_embedding_model(model_cls))
assert is_pooling_model(as_reward_model(model_cls))
if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls)

View File

@ -7,7 +7,9 @@ from torch import nn
from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
@contextlib.contextmanager
@ -35,8 +37,12 @@ def get_model_architecture(
architectures = ["QuantMixtralForCausalLM"]
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.runner_type == "pooling":
if model_config.task == "embed":
model_cls = as_embedding_model(model_cls)
elif model_config.task == "classify":
model_cls = as_classification_model(model_cls)
elif model_config.task == "reward":
model_cls = as_reward_model(model_cls)
return model_cls, arch

View File

@ -1,29 +1,48 @@
from collections.abc import Iterable
from typing import Any, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeVar
import torch
import torch.nn as nn
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.model_executor.layers.pooler import PoolingType
_T = TypeVar("_T", bound=type[nn.Module])
_GENERATE_SUFFIXES = [
"ForCausalLM",
"ForConditionalGeneration",
"ChatModel",
"LMHeadModel",
]
def as_embedding_model(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if is_pooling_model(cls):
return cls
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
model_name = orig_model_name
for generate_suffix in _GENERATE_SUFFIXES:
model_name = model_name.removesuffix(generate_suffix)
return model_name + pooling_suffix
def _create_pooling_model_cls(
orig_cls: _T,
*,
default_pooling_type: "PoolingType",
default_normalize: bool,
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
PoolingType)
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForEmbedding(cls, VllmModelForPooling):
class ModelForPooling(orig_cls, VllmModelForPooling):
def __init__(
self,
@ -34,7 +53,7 @@ def as_embedding_model(cls: _T) -> _T:
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# These are not used in embedding models
# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
@ -46,9 +65,9 @@ def as_embedding_model(cls: _T) -> _T:
if not getattr(self, "_pooler", None):
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
def pooler(
@ -82,17 +101,148 @@ def as_embedding_model(cls: _T) -> _T:
return
# For most other models
if hasattr(cls, "load_weights"):
cls.load_weights(self, weights) # type: ignore
if hasattr(orig_cls, "load_weights"):
orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
ModelForEmbedding.__name__ = cls.__name__ \
.removesuffix("ForCausalLM") \
.removesuffix("ForConditionalGeneration") \
.removesuffix("ChatModel") \
.removesuffix("LMHeadModel") + "ForEmbedding"
return ModelForPooling # type: ignore
def as_embedding_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support embeddings.
By default, the embeddings of the whole prompt are extracted from the
normalized hidden state corresponding to the last token.
Note:
We assume that no extra layers are added to the original model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing embedding models
if is_pooling_model(cls):
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
ModelForEmbedding = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=True,
default_softmax=False,
)
ModelForEmbedding.__name__ = \
_get_pooling_model_name(cls.__name__, "ForEmbedding")
return ModelForEmbedding # type: ignore
def as_classification_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support classification.
By default, the class probabilities are extracted from the softmaxed
hidden state corresponding to the last token.
Note:
We assume that the classification head is a single linear layer
stored as the attribute `score` of the top-level model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing classification models
if is_pooling_model(cls):
return cls
# Lazy import
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolingType
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
ModelForPooling = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=False,
default_softmax=True,
)
class ModelForClassification(ModelForPooling):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(
prefix, "score"))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = super().forward(input_ids, positions, kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds)
logits, _ = self.score(hidden_states)
return logits
ModelForClassification.__name__ = \
_get_pooling_model_name(cls.__name__, "ForClassification")
return ModelForClassification # type: ignore
def as_reward_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support reward modeling.
By default, we return the hidden states of each token directly.
Note:
We assume that no extra layers are added to the original model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing reward models
if is_pooling_model(cls):
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
ModelForReward = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.ALL,
default_normalize=False,
default_softmax=False,
)
ModelForReward.__name__ = \
_get_pooling_model_name(cls.__name__, "ForReward")
return ModelForReward # type: ignore

View File

@ -545,8 +545,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
# after changing the default pooling method
# TODO: Replace this model class with as_embedding_model(
# Qwen2ForCausalLM) after changing the default pooling method
if pooler_config.pooling_type is None:
logger.warning(
"This embedding model will default to last-token pooling in "

View File

@ -1,104 +0,0 @@
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 Kakao Corp. (Kanana-X Team)
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-Classification model compatible with HF weights."""
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
# hidden_states from Qwen2Model has been reduced,
# the input of score layer is not parallelized.
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(prefix, "score"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
logits, _ = self.score(hidden_states)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
return loader.load_weights(weights)

View File

@ -20,11 +20,10 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
supports_pp)
from .interfaces_base import is_pooling_model, is_text_generation_model
from .interfaces_base import is_text_generation_model
logger = init_logger(__name__)
@ -125,12 +124,13 @@ _EMBEDDING_MODELS = {
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# [Auto-converted (see adapters.py)]
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
}
_CROSS_ENCODER_MODELS = {
@ -226,19 +226,10 @@ class _ModelInfo:
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_pooling_model_ = is_pooling_model(model)
if not is_pooling_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_pooling_model_ = True
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model_,
is_pooling_model=True, # Can convert any model into a pooling model
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),