# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast

import torch
import torch.nn as nn

from vllm.model_executor.models.config import VerifyAndUpdateConfig

from .interfaces_base import VllmModelForPooling, is_pooling_model

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import PoolingType

_T = TypeVar("_T", bound=type[nn.Module])

_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]


def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    model_name = orig_model_name

    for generate_suffix in _GENERATE_SUFFIXES:
        model_name = model_name.removesuffix(generate_suffix)

    return model_name + pooling_suffix


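# Illustrative examples of the mapping above (the class names are assumptions
# for illustration only):
#   _get_pooling_model_name("LlamaForCausalLM", "ForEmbedding")
#       -> "LlamaForEmbedding"
#   _get_pooling_model_name("Qwen3ForCausalLM", "ForSequenceClassification")
#       -> "Qwen3ForSequenceClassification"

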
def _create_pooling_model_cls(
    orig_cls: _T,
    *,
    default_pooling_type: "PoolingType",
    default_normalize: bool,
    default_softmax: bool,
) -> _T:
    # Lazy import
    from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
    from vllm.model_executor.pooling_metadata import PoolingMetadata

    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForPooling(orig_cls, VllmModelForPooling):

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # These are not used in pooling models
            for attr in ("lm_head", "logits_processor"):
                if hasattr(self, attr):
                    delattr(self, attr)

            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            # If the model already defines a pooler instance, don't overwrite it
            if not getattr(self, "_pooler", None):
                self._pooler = Pooler.from_config_with_defaults(
                    pooler_config,
                    pooling_type=default_pooling_type,
                    normalize=default_normalize,
                    softmax=default_softmax,
                )

        def pooler(
            self,
            hidden_states: torch.Tensor,
            pooling_metadata: PoolingMetadata,
        ) -> PoolerOutput:
            return self._pooler(hidden_states, pooling_metadata)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            # TODO: Support uninitialized params tracking

            # We have deleted this attribute, so don't load it
            weights = ((name, data) for name, data in weights
                       if not name.startswith("lm_head."))

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and there are no other inner modules with parameters,
            # we support loading from both `*Model` and `*ForCausalLM`
            if hasattr(self, "model") and hasattr(self.model, "load_weights"):
                # Whether only `self.model` contains parameters
                model_is_only_param = all(
                    name == "model" or next(child.parameters(), None) is None
                    for name, child in self.named_children())

                if model_is_only_param:
                    mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
                    weights = mapper.apply(weights)

                    loaded_params = self.model.load_weights(weights)
                    loaded_params = {f"model.{name}" for name in loaded_params}
                    return loaded_params

            # For most other models
            if hasattr(orig_cls, "load_weights"):
                return orig_cls.load_weights(self, weights)  # type: ignore
            # Fallback
            else:
                loader = AutoWeightsLoader(self)
                return loader.load_weights(weights)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support embeddings.

    By default, the embeddings of the whole prompt are extracted from the
    normalized hidden state corresponding to the last token.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing embedding models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import PoolingType

    ModelForEmbedding = _create_pooling_model_cls(
        cls,
        default_pooling_type=PoolingType.LAST,
        default_normalize=True,
        default_softmax=False,
    )
    ModelForEmbedding.__name__ = \
        _get_pooling_model_name(cls.__name__, "ForEmbedding")

    return ModelForEmbedding  # type: ignore


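# A minimal usage sketch (assuming `LlamaForCausalLM` as the wrapped
# architecture; any supported generative model class works the same way):
#
#     from vllm.model_executor.models.llama import LlamaForCausalLM
#
#     LlamaForEmbedding = as_embedding_model(LlamaForCausalLM)
#     assert LlamaForEmbedding.__name__ == "LlamaForEmbedding"
#
# The resulting class pools the normalized last-token hidden state and is
# instantiated through the usual vLLM model-loading path.

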
def as_seq_cls_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support classify and score tasks.

    By default, the class probabilities are extracted from the softmaxed
    hidden state corresponding to the last token.

    Note:
        We assume that the classification head is a single linear layer
        stored as the attribute `score` of the top-level model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing classification models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import RowParallelLinear
    from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
    from vllm.model_executor.pooling_metadata import PoolingMetadata
    from vllm.sequence import IntermediateTensors

    from .utils import maybe_prefix

    ModelForPooling = _create_pooling_model_cls(
        cls,
        default_pooling_type=PoolingType.LAST,
        default_normalize=False,
        default_softmax=True,
    )

    class ModelForSequenceClassification(ModelForPooling,
                                         SupportsCrossEncoding):

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            config = vllm_config.model_config.hf_config
            quant_config = vllm_config.quant_config

            self.vllm_config = vllm_config
            self.task = vllm_config.model_config.task
            self.pooling_type = (
                vllm_config.model_config.pooler_config.pooling_type)

            self.score = RowParallelLinear(config.hidden_size,
                                           config.num_labels,
                                           quant_config=quant_config,
                                           input_is_parallel=False,
                                           bias=False,
                                           prefix=maybe_prefix(
                                               prefix, "score"))

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            return super().forward(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)

        def pooler(
            self,
            hidden_states: Union[torch.Tensor, list[torch.Tensor]],
            pooling_metadata: PoolingMetadata,
        ) -> PoolerOutput:

            def get_logits(hidden_states):
                if isinstance(hidden_states, list):
                    logits = [self.score(state)[0] for state in hidden_states]
                else:
                    logits, _ = self.score(hidden_states)
                return logits

            if self.pooling_type == PoolingType.ALL:
                logits = get_logits(hidden_states)
                return self._pooler(logits, pooling_metadata)
            else:
                hidden_states = self._pooler.extract_states(
                    hidden_states, pooling_metadata)
                logits = get_logits(hidden_states)
                pooled_data = self._pooler.head(logits, pooling_metadata)

                pooled_outputs = [
                    self._pooler.build_output(data) for data in pooled_data
                ]
                return PoolerOutput(outputs=pooled_outputs)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            tokens = getattr(self.config, "classifier_from_token", None)
            method = getattr(self.config, "method", None)

            if tokens is None and method is None:
                return super().load_weights(weights)
            else:
                # Online convert ForCausalLM into
                # ForSequenceClassification model.
                return seq_cls_model_loader(self, weights)

    ModelForSequenceClassification.__name__ = \
        _get_pooling_model_name(cls.__name__, "ForSequenceClassification")

    return ModelForSequenceClassification  # type: ignore


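# A minimal usage sketch (assuming `Qwen3ForCausalLM` as the wrapped
# architecture, matching the Qwen3-Reranker conversion listed in
# `seq_cls_model_loader` below):
#
#     from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
#
#     Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)
#
# The wrapper adds a `score` linear head mapping hidden_size -> num_labels;
# when the HF config carries `classifier_from_token`/`method`, that head is
# initialized from the original lm_head rows via `seq_cls_model_loader`.

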
def as_reward_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support reward modeling.

    By default, we return the hidden states of each token directly.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing reward models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import PoolingType

    ModelForReward = _create_pooling_model_cls(
        cls,
        default_pooling_type=PoolingType.ALL,
        default_normalize=False,
        default_softmax=False,
    )

    ModelForReward.__name__ = \
        _get_pooling_model_name(cls.__name__, "ForReward")

    return ModelForReward  # type: ignore


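# Unlike `as_embedding_model`, the reward wrapper uses `PoolingType.ALL`, so
# the pooler returns the hidden state of every prompt token instead of a
# single normalized vector. A minimal sketch (class names assumed for
# illustration): `LlamaForReward = as_reward_model(LlamaForCausalLM)`.

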
class SequenceClassificationConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config
        method = getattr(config, "method", None)
        tokens = getattr(config, "classifier_from_token", None)

        if method is None:
            return

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            assert len(tokens) == 2
            config.num_labels = 1
        else:
            config.num_labels = len(tokens)

        # `llm as reranker` models default to not using pad_token
        use_pad_token = getattr(config, "use_pad_token", False)
        config.use_pad_token = use_pad_token


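# Illustrative example of the hf_config fields this hook reads (the concrete
# values are assumptions for illustration; real checkpoints or overrides may
# set them differently):
#
#     classifier_from_token = ["no", "yes"]
#     method = "from_2_way_softmax"   # -> num_labels collapses to 1
#
# For `from_2_way_softmax` exactly two tokens are required; any other
# supported method keeps num_labels = len(classifier_from_token).

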
def load_weights_using_from_2_way_softmax(
        model, weights: Iterable[tuple[str, torch.Tensor]]):
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.models.utils import AutoWeightsLoader

    model_config = model.vllm_config.model_config
    tokens = getattr(model.config, "classifier_from_token", [])
    tokens = cast(list[str], tokens)
    assert len(tokens) == 2

    device = model.score.weight.device

    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    from vllm.transformers_utils.tokenizer import get_tokenizer
    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
    weight = model.lm_head.weight.data[true_id].to(device).to(
        torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
            torch.float32)
    model.score.weight.data.copy_(weight)

    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


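# Why a single difference row reproduces the original 2-way softmax (a worked
# sketch, not executed code): with lm_head rows w_false, w_true and last
# hidden state h,
#
#     softmax([w_false @ h, w_true @ h])[true]
#         = 1 / (1 + exp(-(w_true @ h - w_false @ h)))
#         = sigmoid((w_true - w_false) @ h)
#
# so copying `w_true - w_false` into the single-label `score` head yields the
# reranker's "yes" probability from a plain sigmoid over one logit.

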
def load_weights_no_post_processing(model,
                                    weights: Iterable[tuple[str,
                                                            torch.Tensor]]):
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.models.utils import AutoWeightsLoader

    model_config = model.vllm_config.model_config
    tokens = getattr(model.config, "classifier_from_token", [])
    tokens = cast(list[str], tokens)
    assert len(tokens) > 0

    device = model.score.weight.device

    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    from vllm.transformers_utils.tokenizer import get_tokenizer
    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    score_weight = model.lm_head.weight.data[token_ids].to(device)
    model.score.weight.data.copy_(score_weight)

    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


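# In the `no_post_processing` case each row of the `score` head is copied
# verbatim from the lm_head row of the corresponding classifier token, so the
# classification logits are effectively the causal-LM logits restricted to
# those tokens; no softmax reparameterization is applied.

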
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
    "no_post_processing": load_weights_no_post_processing,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    # Online convert ForCausalLM into ForSequenceClassification model.
    # - from_2_way_softmax:
    #   - Qwen3ForCausalLM
    #     - Qwen3-Reranker
    #   - Qwen2ForCausalLM
    #     - mxbai-rerank-v2
    # - no_post_processing:
    #   - GemmaForCausalLM
    #     - bge-reranker-v2-gemma
    config = model.vllm_config.model_config.hf_config
    method = getattr(config, "method", None)
    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    return SEQ_CLS_LOAD_METHODS[method](model, weights)