[Model][1/N] Automatic conversion of CrossEncoding model (#20012)

Signed-off-by: wang.yuqi <noooop@126.com>
Authored by wang.yuqi on 2025-06-27 12:10:04 +08:00, committed by GitHub
parent e110930680
commit cd4cfee689
5 changed files with 239 additions and 167 deletions
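In outline: this commit moves per-architecture config fixups out of the model classes (the config_verify methods deleted below) into a registry keyed by the resolved architecture name, which VllmConfig consults once in __post_init__, before any model is constructed. A condensed sketch of the new control flow, using only names that appear in the diff (not the verbatim implementation):

    # Condensed sketch of the flow this commit introduces.
    from vllm.model_executor.models.config import MODELS_CONFIG_MAP

    def try_verify_and_update_config(vllm_config):
        # ModelConfig.architecture caches the architecture vLLM resolved
        # through its model registry, e.g. "NomicBertModel".
        architecture = getattr(vllm_config.model_config, "architecture", None)
        if architecture is None:
            return
        cls = MODELS_CONFIG_MAP.get(architecture, None)
        if cls is not None:
            # Mutates vllm_config in place: patches hf_config fields,
            # rotary_kwargs, num_labels, max_model_len, etc.
            cls.verify_and_update_config(vllm_config)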

View File

@@ -43,7 +43,7 @@ class VllmMtebEncoder(mteb.Encoder):
         # issues by randomizing the order.
         r = self.rng.permutation(len(sentences))
         sentences = [sentences[i] for i in r]
-        outputs = self.model.encode(sentences, use_tqdm=False)
+        outputs = self.model.embed(sentences, use_tqdm=False)
         embeds = np.array(outputs)
         embeds = embeds[np.argsort(r)]
         return embeds
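For context, model.encode is vLLM's generic pooling entry point; the embedding-specific model.embed used here returns one vector per prompt, which is what the MTEB encoder needs. A minimal usage sketch (the model name is illustrative):

    from vllm import LLM

    llm = LLM(model="BAAI/bge-m3", task="embed")  # illustrative checkpoint
    outputs = llm.embed(["What is the capital of France?"])
    vector = outputs[0].outputs.embedding  # one list[float] per prompt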
@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner,
     with vllm_runner(model_info.name,
                      task="score",
                      max_model_len=None,
+                     max_num_seqs=8,
                      **vllm_extra_kwargs) as vllm_model:
 
+        model_config = vllm_model.model.llm_engine.model_config
+
         if model_info.architecture:
-            assert (model_info.architecture
-                    in vllm_model.model.llm_engine.model_config.architectures)
+            assert (model_info.architecture in model_config.architectures)
+        assert model_config.hf_config.num_labels == 1
 
         vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
                                           tasks=MTEB_RERANK_TASKS,
                                           languages=MTEB_RERANK_LANGS)
-        vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
+        vllm_dtype = model_config.dtype
 
     with hf_runner(model_info.name, is_cross_encoder=True,
                    dtype="float32") as hf_model:

View File

@@ -569,6 +569,10 @@ class ModelConfig:
         else:
             self.truncation_side = "right"
 
+        model_info, arch = self.registry.inspect_model_cls(self.architectures)
+        self._model_info = model_info
+        self._architecture = arch
+
         self.pooler_config = self._init_pooler_config()
 
         self.dtype = _get_and_verify_dtype(
@@ -660,8 +664,18 @@ class ModelConfig:
     @property
     def architectures(self) -> list[str]:
         # architectures in the model config.
        return getattr(self.hf_config, "architectures", [])
 
+    @property
+    def architecture(self) -> str:
+        # The architecture vllm actually used.
+        return self._architecture
+
+    @property
+    def model_info(self) -> dict[str, Any]:
+        return self._model_info
+
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
         """Pull model/tokenizer from S3 to temporary directory when needed.
@@ -4450,6 +4464,9 @@ class VllmConfig:
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
+
+        self.try_verify_and_update_config()
+
         if self.model_config is not None:
             self.model_config.verify_async_output_proc(self.parallel_config,
                                                        self.speculative_config,
@@ -4694,11 +4711,21 @@ class VllmConfig:
             batch_size_capture_list)
 
     def recalculate_max_model_len(self, max_model_len: int):
+        # Can only be called in try_verify_and_update_config
         model_config = self.model_config
         max_model_len = model_config.get_and_verify_max_len(max_model_len)
         self.model_config.max_model_len = max_model_len
         self.scheduler_config.max_model_len = max_model_len
+        self.compute_hash()
+
+    def try_verify_and_update_config(self):
+        architecture = getattr(self.model_config, "architecture", None)
+        if architecture is None:
+            return
+
+        from vllm.model_executor.models.config import MODELS_CONFIG_MAP
+        cls = MODELS_CONFIG_MAP.get(architecture, None)
+        if cls is not None:
+            cls.verify_and_update_config(self)
 
     def __str__(self):
         return (
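With this hook in place, teaching vLLM about another architecture's config quirks no longer touches the model code: add a VerifyAndUpdateConfig subclass and one MODELS_CONFIG_MAP entry in vllm/model_executor/models/config.py (the new file below). A hedged sketch of such an extension; MyBertModel and its field are hypothetical:

    from vllm.model_executor.models.config import (MODELS_CONFIG_MAP,
                                                   VerifyAndUpdateConfig)

    class MyBertModelConfig(VerifyAndUpdateConfig):  # hypothetical example

        @staticmethod
        def verify_and_update_config(vllm_config: "VllmConfig") -> None:
            config = vllm_config.model_config.hf_config
            # Normalize a nonstandard HF field before the model is built.
            config.hidden_act = getattr(config, "activation_function", "gelu")

    MODELS_CONFIG_MAP["MyBertModel"] = MyBertModelConfig  # hypothetical entry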

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
-from copy import deepcopy
 from typing import Optional
 
 import torch
@@ -12,7 +11,6 @@ from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                    get_act_fn)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -30,8 +28,6 @@ from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.sequence import IntermediateTensors
 
-logger = init_logger(__name__)
-
 
 class BertWithRopeEmbedding(nn.Module):
@@ -408,7 +404,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.vllm_config = vllm_config
-        self.config = self.config_verify(vllm_config)
+        self.config = vllm_config.model_config.hf_config
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
             vllm_config=vllm_config,
@@ -416,9 +412,6 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder")
 
-    def config_verify(self, vllm_config):
-        raise NotImplementedError
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
         "norm2": "mlp_ln",
     })
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NomicBertConfig"
-        assert config.activation_function in ["swiglu", "gelu"]
-        config.position_embedding_type = getattr(config,
-                                                 "position_embedding_type",
-                                                 "rope")
-
-        if config.activation_function == "swiglu":
-            config.hidden_act = "silu"
-        else:
-            config.hidden_act = config.activation_function
-
-        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
-                config.qkv_proj_bias)
-        config.bias = config.qkv_proj_bias
-
-        assert config.rotary_emb_scale_base is None
-        assert not config.rotary_emb_interleaved
-
-        config.layer_norm_eps = config.layer_norm_epsilon
-        config.intermediate_size = config.n_inner
-        config.hidden_size = config.n_embd
-        config.num_hidden_layers = config.n_layer
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        rotary_emb_dim = head_dim * config.rotary_emb_fraction
-        max_trained_positions = getattr(config, "max_trained_positions", 2048)
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": rotary_emb_dim,
-            "max_position": max_trained_positions,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-
-        # we ignore config.rotary_scaling_factor so that for datasets shorter
-        # than max_trained_positions 2048, the results are consistent
-        # with SentenceTransformer.
-        # The context extension uses vllm style rope_theta and rope_scaling.
-        # See #17785 #18755
-        if (not vllm_config.model_config.hf_overrides
-                and vllm_config.model_config.original_max_model_len is None):
-            # Default
-            # Reset max_model_len to max_trained_positions.
-            # nomic-embed-text-v2-moe the length is set to 512
-            # by sentence_bert_config.json.
-            max_model_len_before = vllm_config.model_config.max_model_len
-            max_model_len = min(vllm_config.model_config.max_model_len,
-                                max_trained_positions)
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-            logger.warning(
-                "Nomic context extension is disabled. "
-                "Changing max_model_len from %s to %s. "
-                "To enable context extension, see: "
-                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
-                max_model_len_before, vllm_config.model_config.max_model_len)
-        else:
-            # We need to re-verify max_model_len to avoid lengths
-            # greater than position_embedding.
-            model_config = vllm_config.model_config
-            hf_text_config = model_config.hf_text_config
-
-            if isinstance(model_config.hf_overrides, dict):
-                # hf_overrides_kw
-                max_model_len = model_config.hf_overrides.get(
-                    "max_model_len", vllm_config.model_config.max_model_len)
-            else:
-                # hf_overrides_fn
-                # This might be overridden by sentence_bert_config.json.
-                max_model_len = vllm_config.model_config.max_model_len
-
-            # reset hf_text_config for recalculate_max_model_len.
-            if hasattr(hf_text_config, "max_model_len"):
-                delattr(hf_text_config, "max_model_len")
-            hf_text_config.max_position_embeddings = max_trained_positions
-            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
-
-            # The priority of sentence_bert_config.json is higher
-            # than max_position_embeddings
-            encoder_config = deepcopy(model_config.encoder_config)
-            encoder_config.pop("max_seq_length", None)
-            model_config.encoder_config = encoder_config
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-
-        return config
-
 
 class GteNewModel(BertWithRope):
     # for https://huggingface.co/Alibaba-NLP/new-impl
@@ -600,24 +504,6 @@ class GteNewModel(BertWithRope):
             layer.mlp.gate_up_proj.bias = None
             layer.mlp.gate_up_proj.skip_bias_add = True
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NewConfig"
-        assert config.hidden_act == "gelu"
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-
-        return config
-
     def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
         n = "mlp.up_gate_proj"
         for name, weight in weights:
@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
         "attention.o_proj": "attn.out_proj",
     })
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "GteConfig"
-        assert config.hidden_act == "gelu"
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-
-        return config
-
 
 class JinaRobertaModel(BertWithRope):
     # for https://huggingface.co/jinaai/jina-embeddings-v3
@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
         "norm2": "mlp_ln",
     })
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "XLMRobertaFlashConfig"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-
-        return config
-
     def forward(
         self,
         input_ids: torch.Tensor,
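The rotary_kwargs dict that the config classes now attach to hf_config is what BertWithRopeEncoder feeds to its rotary embedding layers; its keys deliberately mirror the parameters of vLLM's get_rope helper. A hedged consumption-side sketch (values are illustrative, and this is not the exact call site):

    from vllm.model_executor.layers.rotary_embedding import get_rope

    rotary_kwargs = {
        "head_size": 64,        # hidden_size // num_attention_heads
        "rotary_dim": 64,
        "max_position": 2048,
        "base": 10000,
        "rope_scaling": None,
    }
    rotary_emb = get_rope(**rotary_kwargs)  # keys line up with get_rope()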

View File

@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from copy import deepcopy
+from typing import TYPE_CHECKING
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+logger = init_logger(__name__)
+
+
+class VerifyAndUpdateConfig:
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        raise NotImplementedError
+
+
+class GteNewModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "NewConfig"
+        assert config.hidden_act == "gelu"
+        config.hidden_act = "geglu"
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+            "max_position": config.max_position_embeddings,
+            "base": config.rope_theta,
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+
+
+class JinaRobertaModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        if config.position_embedding_type == "rotary":
+            assert config.__class__.__name__ == "XLMRobertaFlashConfig"
+
+            head_dim = config.hidden_size // config.num_attention_heads
+            config.rotary_kwargs = {
+                "head_size": head_dim,
+                "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+                "max_position": config.max_position_embeddings,
+                "base": getattr(config, "rope_theta", config.rotary_emb_base),
+                "rope_scaling": getattr(config, "rope_scaling", None)
+            }
+
+
+class NomicBertModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "NomicBertConfig"
+        assert config.activation_function in ["swiglu", "gelu"]
+        config.position_embedding_type = getattr(config,
+                                                 "position_embedding_type",
+                                                 "rope")
+
+        if config.activation_function == "swiglu":
+            config.hidden_act = "silu"
+        else:
+            config.hidden_act = config.activation_function
+
+        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
+                config.qkv_proj_bias)
+        config.bias = config.qkv_proj_bias
+
+        assert config.rotary_emb_scale_base is None
+        assert not config.rotary_emb_interleaved
+
+        config.layer_norm_eps = config.layer_norm_epsilon
+        config.intermediate_size = config.n_inner
+        config.hidden_size = config.n_embd
+        config.num_hidden_layers = config.n_layer
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        rotary_emb_dim = head_dim * config.rotary_emb_fraction
+        max_trained_positions = getattr(config, "max_trained_positions", 2048)
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": rotary_emb_dim,
+            "max_position": max_trained_positions,
+            "base": getattr(config, "rope_theta", config.rotary_emb_base),
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+
+        # we ignore config.rotary_scaling_factor so that for datasets shorter
+        # than max_trained_positions 2048, the results are consistent
+        # with SentenceTransformer.
+        # The context extension uses vllm style rope_theta and rope_scaling.
+        # See #17785 #18755
+        if (not vllm_config.model_config.hf_overrides
+                and vllm_config.model_config.original_max_model_len is None):
+            # Default
+            # Reset max_model_len to max_trained_positions.
+            # nomic-embed-text-v2-moe the length is set to 512
+            # by sentence_bert_config.json.
+            max_model_len_before = vllm_config.model_config.max_model_len
+            max_model_len = min(vllm_config.model_config.max_model_len,
+                                max_trained_positions)
+
+            vllm_config.recalculate_max_model_len(max_model_len)
+            logger.warning(
+                "Nomic context extension is disabled. "
+                "Changing max_model_len from %s to %s. "
+                "To enable context extension, see: "
+                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
+                max_model_len_before, vllm_config.model_config.max_model_len)
+        else:
+            # We need to re-verify max_model_len to avoid lengths
+            # greater than position_embedding.
+            model_config = vllm_config.model_config
+            hf_text_config = model_config.hf_text_config
+
+            if isinstance(model_config.hf_overrides, dict):
+                # hf_overrides_kw
+                max_model_len = model_config.hf_overrides.get(
+                    "max_model_len", vllm_config.model_config.max_model_len)
+            else:
+                # hf_overrides_fn
+                # This might be overridden by sentence_bert_config.json.
+                max_model_len = vllm_config.model_config.max_model_len
+
+            # reset hf_text_config for recalculate_max_model_len.
+            if hasattr(hf_text_config, "max_model_len"):
+                delattr(hf_text_config, "max_model_len")
+            hf_text_config.max_position_embeddings = max_trained_positions
+            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
+
+            # The priority of sentence_bert_config.json is higher
+            # than max_position_embeddings
+            encoder_config = deepcopy(model_config.encoder_config)
+            encoder_config.pop("max_seq_length", None)
+            model_config.encoder_config = encoder_config
+
+            vllm_config.recalculate_max_model_len(max_model_len)
+
+
+class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        is_original_qwen3_reranker = getattr(config,
+                                             "is_original_qwen3_reranker",
+                                             False)
+        if not is_original_qwen3_reranker:
+            return
+
+        tokens = getattr(config, "classifier_from_token", None)
+        assert tokens is not None and len(tokens) == 2, \
+            ("Try loading the original Qwen3 Reranker?, see: "
+             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
+        config.num_labels = 1
+
+
+class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "GteConfig"
+        assert config.hidden_act == "gelu"
+        config.hidden_act = "geglu"
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+            "max_position": config.max_position_embeddings,
+            "base": config.rope_theta,
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+
+
+MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
+    "GteModel": SnowflakeGteNewModelConfig,
+    "GteNewModel": GteNewModelConfig,
+    "NomicBertModel": NomicBertModelConfig,
+    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
+    "XLMRobertaModel": JinaRobertaModelConfig,
+}
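Qwen3ForSequenceClassificationConfig is what lets the original causal-LM Qwen3 reranker checkpoints load as classifiers without a manual conversion step. A hedged sketch of the offline usage this enables; the hf_overrides keys are taken from the code above, and the checkpoint name is illustrative:

    from vllm import LLM

    llm = LLM(model="Qwen/Qwen3-Reranker-0.6B",  # illustrative checkpoint
              task="score",
              hf_overrides={
                  "architectures": ["Qwen3ForSequenceClassification"],
                  "classifier_from_token": ["no", "yes"],
                  "is_original_qwen3_reranker": True,
              })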

View File

@@ -400,22 +400,10 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
     def load_weights_from_original_qwen3_reranker(
             self, weights: Iterable[tuple[str, torch.Tensor]]):
-        tokens = getattr(self.config, "classifier_from_token", None)
-        assert tokens is not None and len(tokens) == 2, \
-            ("Try loading the original Qwen3 Reranker?, see: "
-             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
-        self.config.num_labels = 1
-
-        model_config = self.vllm_config.model_config
-
+        tokens = getattr(self.config, "classifier_from_token", None)
         device = self.score.weight.device
-        self.score = RowParallelLinear(self.config.hidden_size,
-                                       self.config.num_labels,
-                                       quant_config=self.quant_config,
-                                       input_is_parallel=False,
-                                       bias=False,
-                                       prefix=maybe_prefix(
-                                           self.prefix, "score")).to(device)
 
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
@@ -443,5 +431,6 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
         self.score.weight.data.copy_(weight)
 
         del self.lm_head
-        loaded_weights.add("classifier.weight")
+        loaded_weights.add("score.weight")
+        loaded_weights.discard("lm_head.weight")
         return loaded_weights
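For intuition on the weight surgery above: with two classifier tokens, a num_labels == 1 score head can reproduce the original reranker's P("yes") exactly, because a softmax over the two token logits equals a sigmoid of their difference. A self-contained sketch of that identity (not the verbatim vLLM code):

    import torch

    hidden_size, vocab_size = 8, 16
    lm_head = torch.randn(vocab_size, hidden_size)  # tied LM-head weights
    no_id, yes_id = 3, 7                            # ids for ["no", "yes"]

    h = torch.randn(hidden_size)                    # final hidden state
    logits = lm_head @ h
    p_yes = torch.softmax(torch.stack([logits[no_id], logits[yes_id]]), -1)[1]

    # Equivalent single-row score head: sigmoid((w_yes - w_no) @ h)
    score_weight = lm_head[yes_id] - lm_head[no_id]
    assert torch.allclose(p_yes, torch.sigmoid(score_weight @ h))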