[Model][1/N] Automatic conversion of CrossEncoding model (#20012)
Signed-off-by: wang.yuqi <noooop@126.com>
parent e110930680
commit cd4cfee689
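Summary note: this commit replaces the per-model `config_verify` methods in `BertWithRope` subclasses (and the ad-hoc Qwen3 reranker conversion in `load_weights_from_original_qwen3_reranker`) with declarative per-architecture classes in a new `vllm/model_executor/models/config.py`, dispatched from `VllmConfig.__post_init__` via `try_verify_and_update_config`. For original Qwen3 Reranker checkpoints the `classifier_from_token` / `num_labels` conversion therefore happens automatically at config time. A rough usage sketch (the model name and overrides follow the `qwen3_reranker.py` example referenced in the diff; treat the details as illustrative, not verified here):

# Sketch: loading the original Qwen3 Reranker as a scoring model.
# The hf_overrides keys mirror the ones consumed by
# Qwen3ForSequenceClassificationConfig below; the checkpoint name is illustrative.
from vllm import LLM

llm = LLM(model="Qwen/Qwen3-Reranker-0.6B",
          task="score",
          hf_overrides={
              "architectures": ["Qwen3ForSequenceClassification"],
              "classifier_from_token": ["no", "yes"],
              "is_original_qwen3_reranker": True,
          })
scores = llm.score("What is the capital of France?",
                   ["Paris is the capital of France.",
                    "The mitochondria is the powerhouse of the cell."])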
@@ -43,7 +43,7 @@ class VllmMtebEncoder(mteb.Encoder):
         # issues by randomizing the order.
         r = self.rng.permutation(len(sentences))
         sentences = [sentences[i] for i in r]
-        outputs = self.model.encode(sentences, use_tqdm=False)
+        outputs = self.model.embed(sentences, use_tqdm=False)
         embeds = np.array(outputs)
         embeds = embeds[np.argsort(r)]
         return embeds
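The only functional change in this hunk is the switch from the test runner's encode() to its embed() helper; the shuffle/unshuffle around it is unchanged. As a quick standalone illustration of that order-restoring logic (not vLLM code; the fake outputs list stands in for model.embed(shuffled)):

import numpy as np

rng = np.random.default_rng(0)
sentences = ["a", "b", "c"]
r = rng.permutation(len(sentences))              # randomized evaluation order
shuffled = [sentences[i] for i in r]
outputs = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]   # stand-in for model.embed(shuffled)
embeds = np.array(outputs)[np.argsort(r)]        # rows restored to the caller's original order
assert embeds.shape == (3, 2)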
@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner,
     with vllm_runner(model_info.name,
                      task="score",
                      max_model_len=None,
+                     max_num_seqs=8,
                      **vllm_extra_kwargs) as vllm_model:

+        model_config = vllm_model.model.llm_engine.model_config
+
         if model_info.architecture:
-            assert (model_info.architecture
-                    in vllm_model.model.llm_engine.model_config.architectures)
+            assert (model_info.architecture in model_config.architectures)
+            assert model_config.hf_config.num_labels == 1

         vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
                                           tasks=MTEB_RERANK_TASKS,
                                           languages=MTEB_RERANK_LANGS)
-        vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
+        vllm_dtype = model_config.dtype

     with hf_runner(model_info.name, is_cross_encoder=True,
                    dtype="float32") as hf_model:
@@ -569,6 +569,10 @@ class ModelConfig:
         else:
             self.truncation_side = "right"

+        model_info, arch = self.registry.inspect_model_cls(self.architectures)
+        self._model_info = model_info
+        self._architecture = arch
+
         self.pooler_config = self._init_pooler_config()

         self.dtype = _get_and_verify_dtype(
@@ -660,8 +664,18 @@ class ModelConfig:

     @property
     def architectures(self) -> list[str]:
+        # architectures in the model config.
         return getattr(self.hf_config, "architectures", [])

+    @property
+    def architecture(self) -> str:
+        # The architecture vllm actually used.
+        return self._architecture
+
+    @property
+    def model_info(self) -> dict[str, Any]:
+        return self._model_info
+
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
         """Pull model/tokenizer from S3 to temporary directory when needed.
@@ -4450,6 +4464,9 @@ class VllmConfig:
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
+
+        self.try_verify_and_update_config()
+
         if self.model_config is not None:
             self.model_config.verify_async_output_proc(self.parallel_config,
                                                        self.speculative_config,
@@ -4694,11 +4711,21 @@ class VllmConfig:
             batch_size_capture_list)

     def recalculate_max_model_len(self, max_model_len: int):
+        # Can only be called in try_verify_and_update_config
         model_config = self.model_config
         max_model_len = model_config.get_and_verify_max_len(max_model_len)
         self.model_config.max_model_len = max_model_len
         self.scheduler_config.max_model_len = max_model_len
-        self.compute_hash()
+
+    def try_verify_and_update_config(self):
+        architecture = getattr(self.model_config, "architecture", None)
+        if architecture is None:
+            return
+
+        from vllm.model_executor.models.config import MODELS_CONFIG_MAP
+        cls = MODELS_CONFIG_MAP.get(architecture, None)
+        if cls is not None:
+            cls.verify_and_update_config(self)

     def __str__(self):
         return (
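try_verify_and_update_config is the new dispatch point: it looks up the resolved architecture in MODELS_CONFIG_MAP (added below in vllm/model_executor/models/config.py) and lets a per-architecture class rewrite the config in place. A minimal standalone sketch of the same registry-dispatch pattern (toy types, not the vLLM classes):

# Toy reproduction of the pattern used above: per-architecture config fixups
# registered in a map and applied in place when an entry exists.
class ToyConfig:
    architecture = "NomicBertModel"
    hidden_act = "swiglu"

class VerifyAndUpdateConfig:
    @staticmethod
    def verify_and_update_config(cfg: "ToyConfig") -> None:
        raise NotImplementedError

class NomicBertModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(cfg: "ToyConfig") -> None:
        # normalize checkpoint-specific names the way the real hook does
        cfg.hidden_act = "silu" if cfg.hidden_act == "swiglu" else cfg.hidden_act

MODELS_CONFIG_MAP = {"NomicBertModel": NomicBertModelConfig}

def try_verify_and_update_config(cfg: ToyConfig) -> None:
    cls = MODELS_CONFIG_MAP.get(cfg.architecture, None)
    if cls is not None:
        cls.verify_and_update_config(cfg)

try_verify_and_update_config(ToyConfig())   # mutates cfg in place when a hook exists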
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
-from copy import deepcopy
 from typing import Optional

 import torch
@@ -12,7 +11,6 @@ from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                    get_act_fn)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -30,8 +28,6 @@ from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.sequence import IntermediateTensors

-logger = init_logger(__name__)
-

 class BertWithRopeEmbedding(nn.Module):

@@ -408,7 +404,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.vllm_config = vllm_config
-        self.config = self.config_verify(vllm_config)
+        self.config = vllm_config.model_config.hf_config
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
             vllm_config=vllm_config,
@@ -416,9 +412,6 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder")

-    def config_verify(self, vllm_config):
-        raise NotImplementedError
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
         "norm2": "mlp_ln",
     })

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NomicBertConfig"
-        assert config.activation_function in ["swiglu", "gelu"]
-        config.position_embedding_type = getattr(config,
-                                                 "position_embedding_type",
-                                                 "rope")
-
-        if config.activation_function == "swiglu":
-            config.hidden_act = "silu"
-        else:
-            config.hidden_act = config.activation_function
-
-        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
-                config.qkv_proj_bias)
-        config.bias = config.qkv_proj_bias
-
-        assert config.rotary_emb_scale_base is None
-        assert not config.rotary_emb_interleaved
-
-        config.layer_norm_eps = config.layer_norm_epsilon
-        config.intermediate_size = config.n_inner
-        config.hidden_size = config.n_embd
-        config.num_hidden_layers = config.n_layer
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        rotary_emb_dim = head_dim * config.rotary_emb_fraction
-        max_trained_positions = getattr(config, "max_trained_positions", 2048)
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": rotary_emb_dim,
-            "max_position": max_trained_positions,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-
-        # we ignore config.rotary_scaling_factor so that for datasets shorter
-        # than max_trained_positions 2048, the results are consistent
-        # with SentenceTransformer.
-        # The context extension uses vllm style rope_theta and rope_scaling.
-        # See #17785 #18755
-        if (not vllm_config.model_config.hf_overrides
-                and vllm_config.model_config.original_max_model_len is None):
-            # Default
-            # Reset max_model_len to max_trained_positions.
-            # nomic-embed-text-v2-moe the length is set to 512
-            # by sentence_bert_config.json.
-            max_model_len_before = vllm_config.model_config.max_model_len
-            max_model_len = min(vllm_config.model_config.max_model_len,
-                                max_trained_positions)
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-            logger.warning(
-                "Nomic context extension is disabled. "
-                "Changing max_model_len from %s to %s. "
-                "To enable context extension, see: "
-                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
-                max_model_len_before, vllm_config.model_config.max_model_len)
-        else:
-            # We need to re-verify max_model_len to avoid lengths
-            # greater than position_embedding.
-            model_config = vllm_config.model_config
-            hf_text_config = model_config.hf_text_config
-
-            if isinstance(model_config.hf_overrides, dict):
-                # hf_overrides_kw
-                max_model_len = model_config.hf_overrides.get(
-                    "max_model_len", vllm_config.model_config.max_model_len)
-            else:
-                # hf_overrides_fn
-                # This might be overridden by sentence_bert_config.json.
-                max_model_len = vllm_config.model_config.max_model_len
-
-            # reset hf_text_config for recalculate_max_model_len.
-            if hasattr(hf_text_config, "max_model_len"):
-                delattr(hf_text_config, "max_model_len")
-            hf_text_config.max_position_embeddings = max_trained_positions
-            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
-
-            # The priority of sentence_bert_config.json is higher
-            # than max_position_embeddings
-            encoder_config = deepcopy(model_config.encoder_config)
-            encoder_config.pop("max_seq_length", None)
-            model_config.encoder_config = encoder_config
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-        return config
-

 class GteNewModel(BertWithRope):
     # for https://huggingface.co/Alibaba-NLP/new-impl
@@ -600,24 +504,6 @@ class GteNewModel(BertWithRope):
             layer.mlp.gate_up_proj.bias = None
             layer.mlp.gate_up_proj.skip_bias_add = True

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NewConfig"
-        assert config.hidden_act == "gelu"
-
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-
     def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
         n = "mlp.up_gate_proj"
         for name, weight in weights:
@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
         "attention.o_proj": "attn.out_proj",
     })

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "GteConfig"
-        assert config.hidden_act == "gelu"
-
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-

 class JinaRobertaModel(BertWithRope):
     # for https://huggingface.co/jinaai/jina-embeddings-v3
@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
         "norm2": "mlp_ln",
     })

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "XLMRobertaFlashConfig"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-
     def forward(
         self,
         input_ids: torch.Tensor,
vllm/model_executor/models/config.py (new file, 200 lines)
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from copy import deepcopy
+from typing import TYPE_CHECKING
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+logger = init_logger(__name__)
+
+
+class VerifyAndUpdateConfig:
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        raise NotImplementedError
+
+
+class GteNewModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "NewConfig"
+        assert config.hidden_act == "gelu"
+
+        config.hidden_act = "geglu"
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+            "max_position": config.max_position_embeddings,
+            "base": config.rope_theta,
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+
+
+class JinaRobertaModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        if config.position_embedding_type == "rotary":
+            assert config.__class__.__name__ == "XLMRobertaFlashConfig"
+
+            head_dim = config.hidden_size // config.num_attention_heads
+            config.rotary_kwargs = {
+                "head_size": head_dim,
+                "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+                "max_position": config.max_position_embeddings,
+                "base": getattr(config, "rope_theta", config.rotary_emb_base),
+                "rope_scaling": getattr(config, "rope_scaling", None)
+            }
+
+
+class NomicBertModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "NomicBertConfig"
+        assert config.activation_function in ["swiglu", "gelu"]
+        config.position_embedding_type = getattr(config,
+                                                 "position_embedding_type",
+                                                 "rope")
+
+        if config.activation_function == "swiglu":
+            config.hidden_act = "silu"
+        else:
+            config.hidden_act = config.activation_function
+
+        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
+                config.qkv_proj_bias)
+        config.bias = config.qkv_proj_bias
+
+        assert config.rotary_emb_scale_base is None
+        assert not config.rotary_emb_interleaved
+
+        config.layer_norm_eps = config.layer_norm_epsilon
+        config.intermediate_size = config.n_inner
+        config.hidden_size = config.n_embd
+        config.num_hidden_layers = config.n_layer
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        rotary_emb_dim = head_dim * config.rotary_emb_fraction
+        max_trained_positions = getattr(config, "max_trained_positions", 2048)
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": rotary_emb_dim,
+            "max_position": max_trained_positions,
+            "base": getattr(config, "rope_theta", config.rotary_emb_base),
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+
+        # we ignore config.rotary_scaling_factor so that for datasets shorter
+        # than max_trained_positions 2048, the results are consistent
+        # with SentenceTransformer.
+        # The context extension uses vllm style rope_theta and rope_scaling.
+        # See #17785 #18755
+        if (not vllm_config.model_config.hf_overrides
+                and vllm_config.model_config.original_max_model_len is None):
+            # Default
+            # Reset max_model_len to max_trained_positions.
+            # nomic-embed-text-v2-moe the length is set to 512
+            # by sentence_bert_config.json.
+            max_model_len_before = vllm_config.model_config.max_model_len
+            max_model_len = min(vllm_config.model_config.max_model_len,
+                                max_trained_positions)
+
+            vllm_config.recalculate_max_model_len(max_model_len)
+            logger.warning(
+                "Nomic context extension is disabled. "
+                "Changing max_model_len from %s to %s. "
+                "To enable context extension, see: "
+                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
+                max_model_len_before, vllm_config.model_config.max_model_len)
+        else:
+            # We need to re-verify max_model_len to avoid lengths
+            # greater than position_embedding.
+            model_config = vllm_config.model_config
+            hf_text_config = model_config.hf_text_config
+
+            if isinstance(model_config.hf_overrides, dict):
+                # hf_overrides_kw
+                max_model_len = model_config.hf_overrides.get(
+                    "max_model_len", vllm_config.model_config.max_model_len)
+            else:
+                # hf_overrides_fn
+                # This might be overridden by sentence_bert_config.json.
+                max_model_len = vllm_config.model_config.max_model_len
+
+            # reset hf_text_config for recalculate_max_model_len.
+            if hasattr(hf_text_config, "max_model_len"):
+                delattr(hf_text_config, "max_model_len")
+            hf_text_config.max_position_embeddings = max_trained_positions
+            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
+
+            # The priority of sentence_bert_config.json is higher
+            # than max_position_embeddings
+            encoder_config = deepcopy(model_config.encoder_config)
+            encoder_config.pop("max_seq_length", None)
+            model_config.encoder_config = encoder_config
+
+            vllm_config.recalculate_max_model_len(max_model_len)
+
+
+class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        is_original_qwen3_reranker = getattr(config,
+                                             "is_original_qwen3_reranker",
+                                             False)
+
+        if not is_original_qwen3_reranker:
+            return
+
+        tokens = getattr(config, "classifier_from_token", None)
+        assert tokens is not None and len(tokens) == 2, \
+            ("Try loading the original Qwen3 Reranker?, see: "
+             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
+        config.num_labels = 1
+
+
+class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "GteConfig"
+        assert config.hidden_act == "gelu"
+
+        config.hidden_act = "geglu"
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+            "max_position": config.max_position_embeddings,
+            "base": config.rope_theta,
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+
+
+MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
+    "GteModel": SnowflakeGteNewModelConfig,
+    "GteNewModel": GteNewModelConfig,
+    "NomicBertModel": NomicBertModelConfig,
+    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
+    "XLMRobertaModel": JinaRobertaModelConfig,
+}
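New architectures opt in by subclassing VerifyAndUpdateConfig and adding an entry to MODELS_CONFIG_MAP; anything not in the map is left untouched. A hypothetical follow-up entry might look like this (the MyEncoderModel name and its checks are made up for illustration, not part of this commit):

class MyEncoderModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config
        # e.g. map a checkpoint-specific field onto the names vLLM expects
        config.hidden_act = getattr(config, "activation", "gelu")

MODELS_CONFIG_MAP["MyEncoderModel"] = MyEncoderModelConfig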
@@ -400,22 +400,10 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,

     def load_weights_from_original_qwen3_reranker(
             self, weights: Iterable[tuple[str, torch.Tensor]]):
-        tokens = getattr(self.config, "classifier_from_token", None)
-        assert tokens is not None and len(tokens) == 2, \
-            ("Try loading the original Qwen3 Reranker?, see: "
-             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
-
-        self.config.num_labels = 1
         model_config = self.vllm_config.model_config
+        tokens = getattr(self.config, "classifier_from_token", None)
         device = self.score.weight.device
-        self.score = RowParallelLinear(self.config.hidden_size,
-                                       self.config.num_labels,
-                                       quant_config=self.quant_config,
-                                       input_is_parallel=False,
-                                       bias=False,
-                                       prefix=maybe_prefix(
-                                           self.prefix, "score")).to(device)
-
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
@@ -443,5 +431,6 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
         self.score.weight.data.copy_(weight)

         del self.lm_head
-        loaded_weights.add("classifier.weight")
+        loaded_weights.add("score.weight")
         loaded_weights.discard("lm_head.weight")
+        return loaded_weights