[New Model]: support GTE NewModel (#17986)

wang.yuqi 2025-05-14 16:31:31 +08:00 committed by GitHub
parent e7ef61c1f0
commit 63ad622233
11 changed files with 279 additions and 32 deletions

View File

@@ -701,12 +701,22 @@ Specified using `--task embed`.
* ✅︎
* ✅︎
- * `GteModel`
* Arctic-Embed-2.0-M
* `Snowflake/snowflake-arctic-embed-m-v2.0`.
*
*
- * `GteNewModel`
* mGTE-TRM (see note)
* `Alibaba-NLP/gte-multilingual-base`, etc.
*
*
- * `ModernBertModel`
* ModernBERT-based
* `Alibaba-NLP/gte-modernbert-base`, etc.
*
*
- * `NomicBertModel`
* Nomic BERT
* `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
*
*
@@ -749,6 +759,10 @@ See [relevant issue on HF Transformers](https://github.com/huggingface/transform
`jinaai/jina-embeddings-v3` supports multiple tasks through LoRA, while vLLM currently supports only the text-matching task by merging the LoRA weights.
:::
:::{note}
The second-generation GTE model (mGTE-TRM) is named `NewModel`. Since the name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture.
:::
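For example, the same override can be passed programmatically instead of on the command line. A minimal sketch, assuming a matching vLLM install (the prompt text is illustrative):

```python
from vllm import LLM

# Offline equivalent of --hf-overrides '{"architectures": ["GteNewModel"]}'.
llm = LLM(
    model="Alibaba-NLP/gte-multilingual-base",
    task="embed",
    hf_overrides={"architectures": ["GteNewModel"]},
)
outputs = llm.embed(["vLLM supports mGTE-TRM via GteNewModel."])
```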
If your model is not in the above list, we will try to automatically convert the model using
{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
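The conversion can also be applied directly. A minimal sketch, assuming a vLLM install that ships both names used below:

```python
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

# Wraps the causal LM with a pooler that returns the normalized hidden
# state of the last token, as described above.
Qwen2EmbeddingModel = as_embedding_model(Qwen2ForCausalLM)
```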

View File

@ -7,6 +7,7 @@ import numpy as np
import pytest
from tests.models.utils import EmbedModelInfo
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
# Most models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
@@ -77,16 +78,22 @@ def run_mteb_embed_task_st(model_name, tasks):
return run_mteb_embed_task(model, tasks)
def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
def mteb_test_embed_models(hf_runner,
vllm_runner,
model_info: EmbedModelInfo,
vllm_extra_kwargs=None):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
vllm_extra_kwargs = vllm_extra_kwargs or {}
with vllm_runner(model_info.name,
task="embed",
max_model_len=None,
dtype=model_info.dtype) as vllm_model:
dtype=model_info.dtype,
**vllm_extra_kwargs) as vllm_model:
if model_info.architecture:
assert (model_info.architecture
@@ -99,9 +106,9 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype)
with hf_runner(model_info.name,
is_sentence_transformer=True,
dtype=model_dtype) as hf_model:
# Note the comma: "with A and B" would enter only the second context manager.
with set_default_torch_dtype(model_dtype), hf_runner(
model_info.name, is_sentence_transformer=True,
dtype=model_dtype) as hf_model:
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
print("VLLM:", vllm_dtype, vllm_main_score)

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test
MODELS = [
########## BertModel
EmbedModelInfo("thenlper/gte-large",
architecture="BertModel",
dtype="float32",
enable_test=True),
EmbedModelInfo("thenlper/gte-base",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-large-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-base-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
########## NewModel
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
enable_test=True),
########## Qwen2ForCausalLM
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-7B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=False),
########## ModernBertModel
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
]
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
example_prompts) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@@ -23,6 +23,7 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@@ -33,6 +34,9 @@ def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,

View File

@@ -46,6 +46,7 @@ def test_models_mteb(
vllm_runner,
model_info: EmbedModelInfo,
) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@@ -60,6 +61,9 @@ def test_models_correctness(
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,

View File

@@ -256,11 +256,17 @@ _EMBEDDING_EXAMPLE_MODELS = {
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
"GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
trust_remote_code=True,
hf_overrides={"architectures":
["GteNewModel"]}),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),

View File

@@ -354,7 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
"gelu_and_mul": lambda: GeluAndMul(),
"geglu": lambda: GeluAndMul(),
})
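For context, `GeluAndMul` (which the new `geglu` key aliases) is the gated activation used by the GTE MLPs: it splits the last dimension into gate and up halves and multiplies `gelu(gate)` by `up`. A minimal sketch of that semantics, assuming a vLLM install:

```python
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.activation import GeluAndMul

act = GeluAndMul()
x = torch.randn(2, 8)  # last dim holds [gate | up]
d = x.shape[-1] // 2
assert torch.allclose(act(x), F.gelu(x[..., :d]) * x[..., d:])
```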

View File

@@ -456,6 +456,40 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
return self._scaling_factor_to_offset
class NTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with fixed and mixed NTK scaling.
https://kexue.fm/archives/9706 """
def __init__(self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
mixed_b: Optional[float] = None) -> None:
self.scaling_factor = scaling_factor
self.mixed_b = mixed_b
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
inv_freq = super()._compute_inv_freq(base)
if self.mixed_b is None:
inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim)
else:
a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
2)**self.mixed_b
lambda_1_m = (a * torch.arange(
1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp()
inv_freq = inv_freq / lambda_1_m
return inv_freq
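For reference, here is my reading of what `_compute_inv_freq` above implements, following the linked post, with $\beta$ the base, $k$ the `scaling_factor`, $d$ the `rotary_dim`, $b$ the `mixed_b` exponent, and $\theta_i = \beta^{-2i/d}$ the unscaled RoPE frequencies:

```latex
% Fixed NTK (mixed_b is None): enlarge the base to beta*k, then divide
% every frequency once more by k^{2/d}:
\theta_i' = \frac{(\beta k)^{-2i/d}}{k^{2/d}}
          = \theta_i \, k^{-(2i+2)/d}, \qquad i = 0, \dots, d/2 - 1
% Mixed NTK (mixed_b = b): the exponent ramps up to 1, so the lowest
% frequency is scaled by exactly k, as in position interpolation:
\theta_i' = \theta_i \, k^{-\left(i/(d/2)\right)^{b}}, \qquad i = 1, \dots, d/2
```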
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
@@ -1765,6 +1799,14 @@ def get_rope(
max_position, base,
is_neox_style,
scaling_factor, dtype)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
mixed_b = rope_scaling.get('mixed_b', None)
rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype,
mixed_b)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
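A minimal sketch of how the new `"ntk"` branch is reached from an HF-style `rope_scaling` dict (assumptions: a vLLM checkout of this version; the `factor`/`mixed_b` values are illustrative rather than taken from any real checkpoint, and the key carrying the scaling type may be spelled `"type"` or `"rope_type"` depending on the vLLM release):

```python
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=64,
    rotary_dim=64,
    max_position=8192,
    base=10000,
    is_neox_style=True,
    rope_scaling={"rope_type": "ntk", "factor": 8.0, "mixed_b": 0.625},
)
positions = torch.arange(4)
# Shapes: (num_tokens, num_heads * head_size) with a single head here.
query, key = torch.randn(2, 4, 64).unbind(0)
query, key = rope(positions, query, key)
```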

View File

@@ -32,11 +32,18 @@ class BertWithRopeEmbedding(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
assert config.type_vocab_size > 0
if config.position_embedding_type not in ["rope", "rotary"]:
raise ValueError("Only 'rotary' ('rope') "
"position_embedding_type is supported")
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
if config.type_vocab_size > 0:
self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
else:
self.token_type_embeddings = None
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
@@ -47,13 +54,17 @@
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
embeddings = inputs_embeds
if self.token_type_embeddings is not None:
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings += token_type_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
@@ -321,7 +332,7 @@ class BertWithRopeBlock(nn.Module):
if moe:
self.mlp = NomicMoELayer(config=config, )
else:
if config.hidden_act in ["silu", "gelu_and_mul"]:
if config.hidden_act in ["silu", "geglu"]:
self.mlp = BertWithRopeGatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@@ -390,6 +401,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
self.config = self.config_verify(vllm_config)
self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder(
@@ -420,7 +432,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
torch.Tensor]]) -> Set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
if self.config.hidden_act in ["silu", "gelu_and_mul"]:
if self.config.hidden_act in ["silu", "geglu"]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
@@ -458,6 +470,8 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
class NomicBertModel(BertWithRope):
# for https://huggingface.co/nomic-ai/nomic-bert-2048
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm",
@@ -475,6 +489,9 @@ class NomicBertModel(BertWithRope):
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"]
config.position_embedding_type = getattr(config,
"position_embedding_type",
"rope")
if config.activation_function == "swiglu":
config.hidden_act = "silu"
@@ -512,10 +529,13 @@ class NomicBertModel(BertWithRope):
return config
class GteModel(BertWithRope):
class GteNewModel(BertWithRope):
# for https://huggingface.co/Alibaba-NLP/new-impl
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"layer": 'layers',
"new.": "",
"layer": "layers",
"attention.qkv_proj": "attn.qkv_proj",
"attention.o_proj": "attn.out_proj",
})
@@ -523,7 +543,7 @@ class GteModel(BertWithRope):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# GteModel only gate_up_proj does not have bias.
# In GteNewModel, only gate_up_proj has no bias.
# Hack method learned from vllm/model_executor/models/glm.py
for layer in self.encoder.layers:
layer.mlp.gate_up_proj.bias = None
@@ -532,12 +552,10 @@ class GteModel(BertWithRope):
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.position_embedding_type == "rope"
assert config.__class__.__name__ == "NewConfig"
assert config.hidden_act == "gelu"
config.position_embedding_type = "rotary"
config.hidden_act = "gelu_and_mul"
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
@@ -559,13 +577,52 @@ class GteModel(BertWithRope):
else:
yield name, weight
def ignore_unnecessary_layers(self,
weights: Iterable[Tuple[str, torch.Tensor]]):
for name, weight in weights:
if name.startswith("classifier"):
continue
yield name, weight
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
weights = self.ignore_unnecessary_layers(weights)
weights = self.split_up_gate_proj(weights)
return super().load_weights(weights)
class SnowflakeGteNewModel(GteNewModel):
# for Snowflake/snowflake-arctic-embed-m-v2.0
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"layer": "layers",
"attention.qkv_proj": "attn.qkv_proj",
"attention.o_proj": "attn.out_proj",
})
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.hidden_act == "gelu"
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config
class JinaRobertaModel(BertWithRope):
# for https://huggingface.co/jinaai/jina-embeddings-v3
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm",
@@ -579,6 +636,9 @@ class JinaRobertaModel(BertWithRope):
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "XLMRobertaFlashConfig"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
@@ -611,6 +671,7 @@ class JinaRobertaModel(BertWithRope):
# This is a temporary solution until we have a better way to handle
scaling = self.config.lora_alpha / self.config.lora_rank
device = self.vllm_config.device_config.device
weights = {name: weight for name, weight in weights}
@@ -628,13 +689,13 @@ class JinaRobertaModel(BertWithRope):
weight_name = name[:-len(o)]
if "embeddings" in weight_name:
B = weights[weight_name + a][i].cuda().float()
A = weights[weight_name + b][i].cuda().float()
B = weights[weight_name + a][i].to(device).float()
A = weights[weight_name + b][i].to(device).float()
else:
B = weights[weight_name + b][i].cuda().float()
A = weights[weight_name + a][i].cuda().float()
B = weights[weight_name + b][i].to(device).float()
A = weights[weight_name + a][i].to(device).float()
weight = (weights[weight_name + o].cuda() +
weight = (weights[weight_name + o].to(device) +
torch.matmul(B, A).view(shape) * scaling)
weight = weight.cpu().to(dtype)

View File

@@ -230,9 +230,12 @@ class ModernBertModel(nn.Module):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
positions: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
position_ids = positions if positions is not None else position_ids
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:

View File

@@ -127,7 +127,8 @@ _EMBEDDING_MODELS = {
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert_with_rope", "GteModel"),
"GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
"GteNewModel": ("bert_with_rope", "GteNewModel"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaModel": ("llama", "LlamaForCausalLM"),
@@ -137,6 +138,7 @@ _EMBEDDING_MODELS = {
if arch == "LlamaForCausalLM"
},
"MistralModel": ("llama", "LlamaForCausalLM"),
"ModernBertModel": ("modernbert", "ModernBertModel"),
"NomicBertModel": ("bert_with_rope", "NomicBertModel"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),