[New Model]: support GTE NewModel (#17986)

wang.yuqi 2025-05-14 16:31:31 +08:00 committed by GitHub
parent e7ef61c1f0
commit 63ad622233
11 changed files with 279 additions and 32 deletions

View File

@@ -701,12 +701,22 @@ Specified using `--task embed`.
   * ✅︎
   * ✅︎
 - * `GteModel`
-  * GteModel
+  * Arctic-Embed-2.0-M
   * `Snowflake/snowflake-arctic-embed-m-v2.0`.
   *
   *
+- * `GteNewModel`
+  * mGTE-TRM (see note)
+  * `Alibaba-NLP/gte-multilingual-base`, etc.
+  *
+  *
+- * `ModernBertModel`
+  * ModernBERT-based
+  * `Alibaba-NLP/gte-modernbert-base`, etc.
+  *
+  *
 - * `NomicBertModel`
-  * NomicBertModel
+  * Nomic BERT
   * `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
   *
   *
@@ -749,6 +759,10 @@ See [relevant issue on HF Transformers](https://github.com/huggingface/transform
 `jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights.
 :::
+
+:::{note}
+The second-generation GTE model (mGTE-TRM) is named `NewModel`. Since the name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture.
+:::
 
 If your model is not in the above list, we will try to automatically convert the model using
 {func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
 of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
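For context, a minimal offline sketch of the override described in the note above; it assumes a vLLM build where `LLM` accepts `task="embed"` and `hf_overrides`, and the exact output attributes may differ by version:

from vllm import LLM

# mGTE-TRM reports the generic architecture name "NewModel", so we override
# it to "GteNewModel" (same effect as --hf-overrides on the CLI).
llm = LLM(
    model="Alibaba-NLP/gte-multilingual-base",
    task="embed",
    hf_overrides={"architectures": ["GteNewModel"]},
    trust_remote_code=True,
)

prompts = ["vLLM is a fast inference engine.",
           "GTE models produce sentence embeddings."]
outputs = llm.embed(prompts)  # assumed API; some builds expose llm.encode() instead
for prompt, output in zip(prompts, outputs):
    print(prompt, "->", len(output.outputs.embedding), "dims")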

View File

@@ -7,6 +7,7 @@ import numpy as np
 import pytest
 
 from tests.models.utils import EmbedModelInfo
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 
 # Most models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -77,16 +78,22 @@ def run_mteb_embed_task_st(model_name, tasks):
     return run_mteb_embed_task(model, tasks)
 
 
-def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
+def mteb_test_embed_models(hf_runner,
+                           vllm_runner,
+                           model_info: EmbedModelInfo,
+                           vllm_extra_kwargs=None):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
         pytest.skip("Skipping test.")
 
+    vllm_extra_kwargs = vllm_extra_kwargs or {}
+
     with vllm_runner(model_info.name,
                      task="embed",
                      max_model_len=None,
-                     dtype=model_info.dtype) as vllm_model:
+                     dtype=model_info.dtype,
+                     **vllm_extra_kwargs) as vllm_model:
 
         if model_info.architecture:
             assert (model_info.architecture
@@ -99,9 +106,9 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
         vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
         vllm_dtype)
 
-    with hf_runner(model_info.name,
-                   is_sentence_transformer=True,
-                   dtype=model_dtype) as hf_model:
+    with set_default_torch_dtype(model_dtype) and hf_runner(
+            model_info.name, is_sentence_transformer=True,
+            dtype=model_dtype) as hf_model:
         st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
 
     print("VLLM:", vllm_dtype, vllm_main_score)

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test
MODELS = [
########## BertModel
EmbedModelInfo("thenlper/gte-large",
architecture="BertModel",
dtype="float32",
enable_test=True),
EmbedModelInfo("thenlper/gte-base",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-large-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-base-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
########### NewModel
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
enable_test=True),
########### Qwen2ForCausalLM
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-7B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=False),
########## ModernBertModel
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
]
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
example_prompts) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@@ -23,6 +23,7 @@ MODELS = [
 @pytest.mark.parametrize("model_info", MODELS)
 def test_models_mteb(hf_runner, vllm_runner,
                      model_info: EmbedModelInfo) -> None:
+    pytest.skip("Skipping mteb test.")
     from .mteb_utils import mteb_test_embed_models
     mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@@ -33,6 +34,9 @@ def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
     if not model_info.enable_test:
         pytest.skip("Skipping test.")
 
+    # ST will strip the input texts, see test_embedding.py
+    example_prompts = [str(s).strip() for s in example_prompts]
+
     with vllm_runner(model_info.name,
                      task="embed",
                      dtype=model_info.dtype,

View File

@@ -46,6 +46,7 @@ def test_models_mteb(
     vllm_runner,
     model_info: EmbedModelInfo,
 ) -> None:
+    pytest.skip("Skipping mteb test.")
     from .mteb_utils import mteb_test_embed_models
     mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@@ -60,6 +61,9 @@ def test_models_correctness(
     if not model_info.enable_test:
         pytest.skip("Skipping test.")
 
+    # ST will strip the input texts, see test_embedding.py
+    example_prompts = [str(s).strip() for s in example_prompts]
+
     with vllm_runner(model_info.name,
                      task="embed",
                      dtype=model_info.dtype,

View File

@@ -256,11 +256,17 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
+    "GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
+                                   trust_remote_code=True,
+                                   hf_overrides={"architectures":
+                                                 ["GteNewModel"]}),
     "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
                                                trust_remote_code=True),
     "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),  # noqa: E501
     "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
+    "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
+                                       trust_remote_code=True),
     "NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long",  # noqa: E501
                                       trust_remote_code=True),
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),

View File

@@ -354,7 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
     "gelu": lambda: GeluAndMul(),
     "silu": lambda: SiluAndMul(),
-    "gelu_and_mul": lambda: GeluAndMul(),
+    "geglu": lambda: GeluAndMul(),
 })
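Below, a self-contained sketch (for illustration only, not vLLM's internal class) of the gated-GELU pattern that the `"geglu"` entry maps to: `GeluAndMul` splits its input in half along the last dimension, applies GELU to the first half, and multiplies by the second.

import torch
import torch.nn.functional as F


def geglu(x: torch.Tensor) -> torch.Tensor:
    """Gated GELU: gelu(first half) * second half along the last dim."""
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate) * up


# The fused gate_up_proj emits 2 * intermediate_size features,
# which the activation collapses back to intermediate_size.
x = torch.randn(4, 2 * 8)
print(geglu(x).shape)  # torch.Size([4, 8])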

View File

@@ -456,6 +456,40 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         return self._scaling_factor_to_offset
 
 
+class NTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with fixed and mixed NTK scaling.
+    https://kexue.fm/archives/9706 """
+
+    def __init__(self,
+                 head_size: int,
+                 rotary_dim: int,
+                 max_position_embeddings: int,
+                 base: int,
+                 is_neox_style: bool,
+                 scaling_factor: float,
+                 dtype: torch.dtype,
+                 mixed_b: Optional[float] = None) -> None:
+        self.scaling_factor = scaling_factor
+        self.mixed_b = mixed_b
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
+        inv_freq = super()._compute_inv_freq(base)
+
+        if self.mixed_b is None:
+            inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim)
+        else:
+            a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
+                                                           2)**self.mixed_b
+            lambda_1_m = (a * torch.arange(
+                1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp()
+            inv_freq = inv_freq / lambda_1_m
+
+        return inv_freq
+
+
 class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
     """RotaryEmbedding extended with Dynamic NTK scaling.
@@ -1765,6 +1799,14 @@ def get_rope(
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
+        elif scaling_type == "ntk":
+            scaling_factor = rope_scaling["factor"]
+            mixed_b = rope_scaling.get('mixed_b', None)
+            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
+                                                   max_position, base,
+                                                   is_neox_style,
+                                                   scaling_factor, dtype,
+                                                   mixed_b)
         elif scaling_type == "dynamic":
             scaling_factor = rope_scaling["factor"]
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
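As a standalone illustration of the two branches in `_compute_inv_freq` above (fixed NTK rescales the base by the factor and applies a global correction; mixed NTK damps each frequency individually via `mixed_b`), with toy numbers chosen only for demonstration:

from typing import Optional

import torch


def ntk_inv_freq(rotary_dim: int, base: float, scaling_factor: float,
                 mixed_b: Optional[float] = None) -> torch.Tensor:
    exponents = torch.arange(0, rotary_dim, 2).float() / rotary_dim
    if mixed_b is None:
        # Fixed NTK: enlarge the base, then apply the global correction term.
        inv_freq = 1.0 / ((base * scaling_factor)**exponents)
        return inv_freq / scaling_factor**(2 / rotary_dim)
    # Mixed NTK: keep the base, damp each frequency by its own lambda.
    inv_freq = 1.0 / (base**exponents)
    a = torch.tensor(scaling_factor).log() / (rotary_dim / 2)**mixed_b
    lambda_1_m = (a * torch.arange(1, rotary_dim // 2 + 1).float()**mixed_b).exp()
    return inv_freq / lambda_1_m


print(ntk_inv_freq(8, 10000.0, 4.0))               # fixed NTK scaling
print(ntk_inv_freq(8, 10000.0, 4.0, mixed_b=2.0))  # mixed NTK scaling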

View File

@@ -32,11 +32,18 @@ class BertWithRopeEmbedding(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
-        assert config.type_vocab_size > 0
+        if config.position_embedding_type not in ["rope", "rotary"]:
+            raise ValueError("Only 'rotary'('rope') position_embedding_type" +
+                             " is supported")
 
         self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
-        self.token_type_embeddings = VocabParallelEmbedding(
-            config.type_vocab_size, config.hidden_size)
+        if config.type_vocab_size > 0:
+            self.token_type_embeddings = VocabParallelEmbedding(
+                config.type_vocab_size, config.hidden_size)
+        else:
+            self.token_type_embeddings = None
+
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
@@ -47,13 +54,17 @@ class BertWithRopeEmbedding(nn.Module):
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
 
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
-        embeddings = inputs_embeds + token_type_embeddings
+        embeddings = inputs_embeds
+        if self.token_type_embeddings is not None:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(input_shape,
+                                             dtype=torch.long,
+                                             device=inputs_embeds.device)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings += token_type_embeddings
+
         embeddings = self.LayerNorm(embeddings)
         return embeddings
@@ -321,7 +332,7 @@ class BertWithRopeBlock(nn.Module):
         if moe:
             self.mlp = NomicMoELayer(config=config, )
         else:
-            if config.hidden_act in ["silu", "gelu_and_mul"]:
+            if config.hidden_act in ["silu", "geglu"]:
                 self.mlp = BertWithRopeGatedMLP(
                     hidden_size=config.hidden_size,
                     intermediate_size=config.intermediate_size,
@@ -390,6 +401,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config
         self.config = self.config_verify(vllm_config)
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
@@ -420,7 +432,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
                                            torch.Tensor]]) -> Set[str]:
         weights = self.hf_to_vllm_mapper.apply(weights)
 
-        if self.config.hidden_act in ["silu", "gelu_and_mul"]:
+        if self.config.hidden_act in ["silu", "geglu"]:
             stacked_params_mapping = [
                 # (param_name, shard_name, shard_id)
                 ("gate_up_proj", "gate_proj", 0),
@@ -458,6 +470,8 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
 
 class NomicBertModel(BertWithRope):
+    # for https://huggingface.co/nomic-ai/nomic-bert-2048
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             "emb_ln": "embeddings.LayerNorm",
@@ -475,6 +489,9 @@ class NomicBertModel(BertWithRope):
         assert config.__class__.__name__ == "NomicBertConfig"
         assert config.activation_function in ["swiglu", "gelu"]
+        config.position_embedding_type = getattr(config,
+                                                 "position_embedding_type",
+                                                 "rope")
 
         if config.activation_function == "swiglu":
             config.hidden_act = "silu"
@@ -512,10 +529,13 @@ class NomicBertModel(BertWithRope):
         return config
 
 
-class GteModel(BertWithRope):
+class GteNewModel(BertWithRope):
+    # for https://huggingface.co/Alibaba-NLP/new-impl
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
-            "layer": 'layers',
+            "new.": "",
+            "layer": "layers",
             "attention.qkv_proj": "attn.qkv_proj",
             "attention.o_proj": "attn.out_proj",
         })
@@ -523,7 +543,7 @@ class GteNewModel(BertWithRope):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        # GteModel only gate_up_proj does not have bias.
+        # GteNewModel only gate_up_proj does not have bias.
         # Hack method learned from vllm/model_executor/models/glm.py
         for layer in self.encoder.layers:
             layer.mlp.gate_up_proj.bias = None
@@ -532,12 +552,10 @@ class GteModel(BertWithRope):
     def config_verify(self, vllm_config):
         config = vllm_config.model_config.hf_config
 
-        assert config.__class__.__name__ == "GteConfig"
-        assert config.position_embedding_type == "rope"
+        assert config.__class__.__name__ == "NewConfig"
         assert config.hidden_act == "gelu"
 
-        config.position_embedding_type = "rotary"
-        config.hidden_act = "gelu_and_mul"
+        config.hidden_act = "geglu"
 
         head_dim = config.hidden_size // config.num_attention_heads
         config.rotary_kwargs = {
@@ -559,13 +577,52 @@ class GteModel(BertWithRope):
             else:
                 yield name, weight
 
+    def ignore_unnecessary_layers(self,
+                                  weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, weight in weights:
+            if name.startswith("classifier"):
+                continue
+            yield name, weight
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
+        weights = self.ignore_unnecessary_layers(weights)
         weights = self.split_up_gate_proj(weights)
         return super().load_weights(weights)
 
 
+class SnowflakeGteNewModel(GteNewModel):
+    # for Snowflake/snowflake-arctic-embed-m-v2.0
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "layer": "layers",
+            "attention.qkv_proj": "attn.qkv_proj",
+            "attention.o_proj": "attn.out_proj",
+        })
+
+    def config_verify(self, vllm_config):
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "GteConfig"
+        assert config.hidden_act == "gelu"
+
+        config.hidden_act = "geglu"
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+            "max_position": config.max_position_embeddings,
+            "base": config.rope_theta,
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+
+        return config
+
+
 class JinaRobertaModel(BertWithRope):
+    # for https://huggingface.co/jinaai/jina-embeddings-v3
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             "emb_ln": "embeddings.LayerNorm",
@@ -579,6 +636,9 @@ class JinaRobertaModel(BertWithRope):
     def config_verify(self, vllm_config):
         config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "XLMRobertaFlashConfig"
+
         head_dim = config.hidden_size // config.num_attention_heads
         config.rotary_kwargs = {
             "head_size": head_dim,
@@ -611,6 +671,7 @@ class JinaRobertaModel(BertWithRope):
         # This is a temporary solution until we have a better way to handle
 
         scaling = self.config.lora_alpha / self.config.lora_rank
+        device = self.vllm_config.device_config.device
 
         weights = {name: weight for name, weight in weights}
@@ -628,13 +689,13 @@ class JinaRobertaModel(BertWithRope):
                 weight_name = name[:-len(o)]
 
                 if "embeddings" in weight_name:
-                    B = weights[weight_name + a][i].cuda().float()
-                    A = weights[weight_name + b][i].cuda().float()
+                    B = weights[weight_name + a][i].to(device).float()
+                    A = weights[weight_name + b][i].to(device).float()
                 else:
-                    B = weights[weight_name + b][i].cuda().float()
-                    A = weights[weight_name + a][i].cuda().float()
+                    B = weights[weight_name + b][i].to(device).float()
+                    A = weights[weight_name + a][i].to(device).float()
 
-                weight = (weights[weight_name + o].cuda() +
+                weight = (weights[weight_name + o].to(device) +
                           torch.matmul(B, A).view(shape) * scaling)
 
                 weight = weight.cpu().to(dtype)
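The loop above is the standard LoRA fold-in used to bake the text-matching adapter into the base checkpoint; a minimal sketch of the arithmetic, with illustrative shapes rather than the model's real parameters:

import torch

# W' = W + (lora_alpha / lora_rank) * B @ A; after merging, no adapter is
# needed at inference time.
hidden, rank = 16, 4
W = torch.randn(hidden, hidden)
B = torch.randn(hidden, rank)
A = torch.randn(rank, hidden)
lora_alpha = 8

scaling = lora_alpha / rank
W_merged = W + torch.matmul(B, A) * scaling
print(W_merged.shape)  # torch.Size([16, 16])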

View File

@@ -230,9 +230,12 @@ class ModernBertModel(nn.Module):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        positions: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
+        position_ids = positions if positions is not None else position_ids
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:

View File

@@ -127,7 +127,8 @@ _EMBEDDING_MODELS = {
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GritLM": ("gritlm", "GritLM"),
-    "GteModel": ("bert_with_rope", "GteModel"),
+    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
+    "GteNewModel": ("bert_with_rope", "GteNewModel"),
     "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
     "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
     "LlamaModel": ("llama", "LlamaForCausalLM"),
@@ -137,6 +138,7 @@ _EMBEDDING_MODELS = {
         if arch == "LlamaForCausalLM"
     },
     "MistralModel": ("llama", "LlamaForCausalLM"),
+    "ModernBertModel": ("modernbert", "ModernBertModel"),
     "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
     "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),