[New Model]: nomic-embed-text-v2-moe (#17785)

This commit is contained in:
wang.yuqi 2025-05-11 15:59:43 +08:00 committed by GitHub
parent 06c0922a69
commit e4b8713380
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 899 additions and 364 deletions

View File

@ -622,7 +622,7 @@ Specified using `--task embed`.
* [PP](#distributed-serving) * [PP](#distributed-serving)
- * `BertModel` - * `BertModel`
* BERT-based * BERT-based
* `BAAI/bge-base-en-v1.5`, etc. * `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc.
* *
* *
- * `Gemma2Model` - * `Gemma2Model`
@ -635,6 +635,16 @@ Specified using `--task embed`.
* `parasail-ai/GritLM-7B-vllm`. * `parasail-ai/GritLM-7B-vllm`.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `GteModel`
* GteModel
* `Snowflake/snowflake-arctic-embed-m-v2.0`.
*
*
- * `NomicBertModel`
* NomicBertModel
* `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
*
*
- * `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. - * `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc.
* Llama-based * Llama-based
* `intfloat/e5-mistral-7b-instruct`, etc. * `intfloat/e5-mistral-7b-instruct`, etc.
@ -647,12 +657,12 @@ Specified using `--task embed`.
* ✅︎ * ✅︎
- * `RobertaModel`, `RobertaForMaskedLM` - * `RobertaModel`, `RobertaForMaskedLM`
* RoBERTa-based * RoBERTa-based
* `sentence-transformers/all-roberta-large-v1`, `sentence-transformers/all-roberta-large-v1`, etc. * `sentence-transformers/all-roberta-large-v1`, etc.
* *
* *
- * `XLMRobertaModel` - * `XLMRobertaModel`
* XLM-RoBERTa-based * XLM-RoBERTa-based
* `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, etc. * `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, `Snowflake/snowflake-arctic-embed-l-v2.0`, `jinaai/jina-embeddings-v3`(see note), etc.
* *
* *
::: :::
@ -670,6 +680,10 @@ For both the 1.5B and 7B variants, you also need to enable `--trust-remote-code`
See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882). See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882).
::: :::
:::{note}
`jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights.
:::
If your model is not in the above list, we will try to automatically convert the model using If your model is not in the above list, we will try to automatically convert the model using
{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings {func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token. of the whole prompt are extracted from the normalized hidden state corresponding to the last token.

View File

@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Sequence
import mteb
import numpy as np
import pytest
from tests.models.utils import EmbedModelInfo
# Most models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
# results in differences less than 1e-4
# - Different model results in differences more than 1e-3
# 1e-4 is a good tolerance threshold
MTEB_EMBED_TASKS = ["STS12"]
MTEB_EMBED_TOL = 1e-4
class VllmMtebEncoder(mteb.Encoder):
def __init__(self, vllm_model):
super().__init__()
self.model = vllm_model
self.rng = np.random.default_rng(seed=42)
def encode(
self,
sentences: Sequence[str],
*args,
**kwargs,
) -> np.ndarray:
# Hoping to discover potential scheduling
# issues by randomizing the order.
r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r]
outputs = self.model.encode(sentences, use_tqdm=False)
embeds = np.array(outputs)
embeds = embeds[np.argsort(r)]
return embeds
class OpenAIClientMtebEncoder(mteb.Encoder):
def __init__(self, model_name: str, client):
super().__init__()
self.model_name = model_name
self.client = client
self.rng = np.random.default_rng(seed=42)
def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray:
# Hoping to discover potential scheduling
# issues by randomizing the order.
r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r]
embeddings = self.client.embeddings.create(model=self.model_name,
input=sentences)
outputs = [d.embedding for d in embeddings.data]
embeds = np.array(outputs)
embeds = embeds[np.argsort(r)]
return embeds
def run_mteb_embed_task(encoder, tasks):
tasks = mteb.get_tasks(tasks=tasks)
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(encoder, verbosity=0, output_folder=None)
main_score = results[0].scores["test"][0]["main_score"]
return main_score
def run_mteb_embed_task_st(model_name, tasks):
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
return run_mteb_embed_task(model, tasks)
def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
with vllm_runner(model_info.name,
task="embed",
max_model_len=None,
dtype=model_info.dtype) as vllm_model:
if model_info.architecture:
assert (model_info.architecture
in vllm_model.model.llm_engine.model_config.architectures)
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
model_dtype = getattr(
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype)
with hf_runner(model_info.name,
is_sentence_transformer=True,
dtype=model_dtype) as hf_model:
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformer:", model_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score)
assert math.isclose(st_main_score, vllm_main_score, rel_tol=MTEB_EMBED_TOL)

View File

@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test
MODELS = [
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
architecture="NomicBertModel",
dtype="float32",
enable_test=True),
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
architecture="NomicBertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
architecture="NomicBertModel",
dtype="float32",
enable_test=True)
]
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
example_prompts) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@ -1,12 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest import pytest
from ...utils import EmbedModelInfo, check_embeddings_close from ...utils import EmbedModelInfo, run_embedding_correctness_test
EMBEDDING_PROMPTS = [
'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
'Mexico City of Course!'
]
MODELS = [ MODELS = [
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
@ -45,51 +41,34 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) def test_models_mteb(
def test_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts,
model_info: EmbedModelInfo, model_info: EmbedModelInfo,
dtype: str, ) -> None:
monkeypatch, from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(
hf_runner,
vllm_runner,
model_info: EmbedModelInfo,
example_prompts,
) -> None: ) -> None:
if not model_info.enable_test: if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
example_prompts = example_prompts + EMBEDDING_PROMPTS
vllm_extra_kwargs = {
"hf_overrides": {
"is_matryoshka": model_info.is_matryoshka
}
}
with hf_runner(model_info.name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
task="embed", task="embed",
dtype=dtype, dtype=model_info.dtype,
max_model_len=None, max_model_len=None) as vllm_model:
**vllm_extra_kwargs) as vllm_model:
assert (vllm_model.model.llm_engine.model_config.is_matryoshka ==
model_info.is_matryoshka)
if model_info.architecture:
assert (model_info.architecture
in vllm_model.model.llm_engine.model_config.architectures)
vllm_outputs = vllm_model.encode(example_prompts) vllm_outputs = vllm_model.encode(example_prompts)
check_embeddings_close( with hf_runner(
embeddings_0_lst=hf_outputs, model_info.name,
embeddings_1_lst=vllm_outputs, dtype=model_info.dtype,
name_0="hf", is_sentence_transformer=True,
name_1="vllm", ) as hf_model:
tol=1e-2, run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
)

View File

@ -332,9 +332,10 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
class EmbedModelInfo(NamedTuple): class EmbedModelInfo(NamedTuple):
name: str name: str
is_matryoshka: bool is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None matryoshka_dimensions: Optional[list[int]] = None
architecture: str = "" architecture: str = ""
dtype: str = "auto"
enable_test: bool = True enable_test: bool = True

View File

@ -11,16 +11,13 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import (get_act_and_mul_fn, from vllm.model_executor.layers.activation import get_act_fn
get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
PoolingType) PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -41,24 +38,19 @@ class BertEmbedding(nn.Module):
self.size = config.hidden_size self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size, self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding( self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size) config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.position_embedding_type = config.position_embedding_type self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "absolute": if self.position_embedding_type != "absolute":
self.position_embeddings = VocabParallelEmbedding( raise ValueError("Only 'absolute' position_embedding_type" +
config.max_position_embeddings, config.hidden_size) " is supported")
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
elif self.position_embedding_type == "rotary":
self.position_embeddings = None
self.position_ids = None
else:
raise ValueError("Only 'absolute' and 'rotary' " +
"position_embedding_type is supported")
def forward( def forward(
self, self,
@ -72,6 +64,9 @@ class BertEmbedding(nn.Module):
# Input embeddings. # Input embeddings.
inputs_embeds = self.word_embeddings(input_ids) inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, token_type_ids = torch.zeros(input_shape,
dtype=torch.long, dtype=torch.long,
@ -79,12 +74,7 @@ class BertEmbedding(nn.Module):
token_type_embeddings = self.token_type_embeddings(token_type_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings embeddings = inputs_embeds + token_type_embeddings + position_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings) embeddings = self.LayerNorm(embeddings)
return embeddings return embeddings
@ -108,11 +98,7 @@ class BertPooler(nn.Module):
@support_torch_compile @support_torch_compile
class BertEncoder(nn.Module): class BertEncoder(nn.Module):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
vllm_config: VllmConfig,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
@ -121,19 +107,16 @@ class BertEncoder(nn.Module):
BertLayer(config=config, BertLayer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}") prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
def forward( def forward(
self, self,
positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
for layer in self.layer: for layer in self.layer:
hidden_states = layer(positions, hidden_states) hidden_states = layer(hidden_states)
return hidden_states return hidden_states
@ -143,8 +126,6 @@ class BertLayer(nn.Module):
config: BertConfig, config: BertConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
@ -154,36 +135,23 @@ class BertLayer(nn.Module):
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention") prefix=f"{prefix}.attention")
if config.hidden_act in ["silu", "gelu_and_mul"]: self.intermediate = BertIntermediate(
self.intermediate = BertGatedIntermediate( hidden_size=config.hidden_size,
hidden_size=config.hidden_size, intermediate_size=config.intermediate_size,
intermediate_size=config.intermediate_size, hidden_act=config.hidden_act,
hidden_act=config.hidden_act, quant_config=quant_config,
bias=bias, prefix=f"{prefix}.intermediate")
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
else:
self.intermediate = BertIntermediate(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.intermediate")
self.output = BertOutput(hidden_size=config.hidden_size, self.output = BertOutput(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor):
attn_output = self.attention(positions, hidden_states) attn_output = self.attention(hidden_states)
intermediate_output = self.intermediate(attn_output) intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output) output = self.output(intermediate_output, attn_output)
return output return output
@ -198,8 +166,6 @@ class BertAttention(nn.Module):
layer_norm_eps: float, layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
@ -208,22 +174,18 @@ class BertAttention(nn.Module):
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size, self.output = BertSelfOutput(hidden_size=hidden_size,
layer_norm_eps=layer_norm_eps, layer_norm_eps=layer_norm_eps,
bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
def forward( def forward(
self, self,
positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
self_output = self.self(positions, hidden_states) self_output = self.self(hidden_states)
return self.output(self_output, hidden_states) return self.output(self_output, hidden_states)
@ -235,8 +197,6 @@ class BertSelfAttention(nn.Module):
num_attention_heads: int, num_attention_heads: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
@ -261,15 +221,10 @@ class BertSelfAttention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.total_num_heads, total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads, total_num_kv_heads=self.total_num_kv_heads,
bias=bias, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj") prefix=f"{prefix}.qkv_proj")
if rotary_kwargs:
self.rotary_emb = get_rope(**rotary_kwargs)
else:
self.rotary_emb = None
self.attn = Attention(num_heads=self.num_heads, self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scaling, scale=self.scaling,
@ -281,15 +236,10 @@ class BertSelfAttention(nn.Module):
def forward( def forward(
self, self,
positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb:
q, k = self.rotary_emb(positions, q, k)
output = self.attn(q, k, v) output = self.attn(q, k, v)
return output return output
@ -299,13 +249,12 @@ class BertSelfOutput(nn.Module):
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
layer_norm_eps: float, layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = RowParallelLinear(input_size=hidden_size, self.dense = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size, output_size=hidden_size,
bias=bias, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
@ -323,13 +272,12 @@ class BertIntermediate(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = ColumnParallelLinear(input_size=hidden_size, self.dense = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size, output_size=intermediate_size,
bias=bias, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
self.intermediate_act_fn = get_act_fn(hidden_act) self.intermediate_act_fn = get_act_fn(hidden_act)
@ -340,46 +288,19 @@ class BertIntermediate(nn.Module):
return hidden_states return hidden_states
class BertGatedIntermediate(nn.Module):
# for NomciBert and GteModel
def __init__(self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.act_fn = get_act_and_mul_fn(hidden_act)
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(hidden_states)
hidden_states = self.act_fn(gate_up)
return hidden_states
class BertOutput(nn.Module): class BertOutput(nn.Module):
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
layer_norm_eps: float, layer_norm_eps: float,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.dense = RowParallelLinear(input_size=intermediate_size, self.dense = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size, output_size=hidden_size,
bias=bias, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense") prefix=f"{prefix}.dense")
@ -393,33 +314,18 @@ class BertOutput(nn.Module):
class BertModel(nn.Module, SupportsQuant): class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = { packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
"qkv_proj": ["query", "key", "value"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, def __init__(self,
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
embedding_class: type = BertEmbedding, embedding_class: type = BertEmbedding,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False): add_pooling_layer: bool = False):
super().__init__() super().__init__()
"""
For BertModel, all linear layers have bias.
For NomicBertModel, all linear layers do not have bias.
"""
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config) self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config, self.encoder = BertEncoder(vllm_config=vllm_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None self.pooler = BertPooler(config) if add_pooling_layer else None
@ -441,7 +347,7 @@ class BertModel(nn.Module, SupportsQuant):
seq_lens=attn_metadata.seq_lens_tensor, seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
return self.encoder(position_ids, hidden_states) return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
@ -450,8 +356,6 @@ class BertModel(nn.Module, SupportsQuant):
("qkv_proj", "query", "q"), ("qkv_proj", "query", "q"),
("qkv_proj", "key", "k"), ("qkv_proj", "key", "k"),
("qkv_proj", "value", "v"), ("qkv_proj", "value", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
@ -497,7 +401,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.config = vllm_config.model_config.hf_config
self.model = self._build_model(vllm_config=vllm_config, self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config) self._pooler = self._build_pooler(pooler_config)
@ -611,115 +514,3 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
class NomicBertEmbeddingModel(BertEmbeddingModel):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm",
"layers": "layer",
"attn.Wqkv": "attention.self.qkv_proj",
"attn.out_proj": "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc11': "intermediate.up_proj",
'mlp.fc12': "intermediate.gate_proj",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function == "swiglu"
# Assume NomicBertModel all linear layers do not have bias
assert not config.mlp_fc1_bias
assert not config.mlp_fc2_bias
assert not config.qkv_proj_bias
config.layer_norm_eps = config.layer_norm_epsilon
config.position_embedding_type = "rotary"
config.intermediate_size = config.n_inner
config.hidden_act = "silu"
config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_trained_positions,
"base": config.rotary_emb_base,
"rope_scaling": {
"rope_type": "dynamic",
"factor": config.rotary_scaling_factor
}
}
return BertModel(vllm_config=vllm_config,
prefix=prefix,
bias=False,
rotary_kwargs=rotary_kwargs,
embedding_class=BertEmbedding)
class GteEmbeddingModel(BertEmbeddingModel):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"attention.qkv_proj": "attention.self.qkv_proj",
"attention.o_proj": "attention.output.dense",
'attn_ln': "attention.output.LayerNorm",
'mlp.down_proj': "output.dense",
'mlp_ln': "output.LayerNorm",
})
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.position_embedding_type == "rope"
assert config.hidden_act == "gelu"
config.position_embedding_type = "rotary"
config.hidden_act = "gelu_and_mul"
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
}
model = BertModel(vllm_config=vllm_config,
prefix=prefix,
rotary_kwargs=rotary_kwargs,
embedding_class=BertEmbedding)
# GteModel only gate_up_proj does not have bias.
# Hack method learned from vllm/model_executor/models/glm.py
for layer in model.encoder.layer:
layer.intermediate.gate_up_proj.bias = None
layer.intermediate.skip_bias_add = True
return model
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj"
for name, weight in weights:
if n in name:
up, gate = weight.chunk(2, dim=0)
yield name.replace(n, "intermediate.up_proj"), up
yield name.replace(n, "intermediate.gate_proj"), gate
else:
yield name, weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = self.split_up_gate_proj(weights)
self.model.load_weights(weights)

View File

@ -0,0 +1,652 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsV0Only
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
class BertWithRopeEmbedding(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
assert config.type_vocab_size > 0
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
class BertWithRopeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = self.total_num_heads
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.rotary_emb = get_rope(**rotary_kwargs)
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_ONLY)
self.out_proj = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.dense")
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output)
return output
class BertWithRopeGatedMLP(nn.Module):
def __init__(self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.act_fn = get_act_and_mul_fn(hidden_act)
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(hidden_states)
hidden_states = self.act_fn(gate_up)
hidden_states, _ = self.down_proj(hidden_states)
return hidden_states
class BertWithRopeMLP(nn.Module):
def __init__(self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.act_fn = get_act_fn(hidden_act)
self.up_proj = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.up_proj(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states, _ = self.down_proj(hidden_states)
return hidden_states
class NomicRouter(nn.Module):
def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
super().__init__()
self.moe_top_k = moe_top_k
self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
dim=-1, dtype=torch.float32)
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
weights = weights.to(x.dtype)
top_weights = top_weights.to(x.dtype)
return weights, top_weights, top_experts # type: ignore
class NomicExpertMLP(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int,
moe_num_experts: int, ffn_act_fn: str):
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.moe_num_experts = moe_num_experts
self.w1 = nn.Parameter(
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
self.w2 = nn.Parameter(
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
self.activation_fn = get_act_fn(ffn_act_fn)
def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
self.hidden_size)[expert_idx]
expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
self.hidden_size)[expert_idx]
x1 = x.matmul(expert_w1.t())
act_out = self.activation_fn(x1)
x2 = act_out.matmul(expert_w2)
return x2
class NomicExperts(nn.Module):
def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
moe_num_experts: int):
super().__init__()
self.moe_num_experts = moe_num_experts
self.mlp = NomicExpertMLP(hidden_size=config.n_embd,
ffn_hidden_size=config.n_inner,
moe_num_experts=moe_num_experts,
ffn_act_fn=config.hidden_act)
self.bias = nn.Parameter(torch.zeros(config.n_embd))
def forward(self, x: torch.Tensor, weights: torch.Tensor,
top_weights: torch.Tensor,
top_experts: torch.LongTensor) -> torch.Tensor:
q_len, hidden_size = x.shape
x = x.view(-1, hidden_size)
out = torch.zeros_like(x)
expert_mask = nn.functional.one_hot(
top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
for expert_idx in range(0, self.moe_num_experts):
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
if token_idx.shape[0] == 0:
continue
token_list = token_idx.tolist()
topk_list = topk_idx.tolist()
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
expert_out = self.mlp(
expert_tokens, expert_idx) * top_weights[token_list, topk_list,
None]
out.index_add_(0, token_idx, expert_out)
out = out.reshape(q_len, hidden_size)
return out + self.bias
class NomicMoELayer(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.router = NomicRouter(
config.n_embd,
moe_num_experts=config.num_experts,
moe_top_k=config.moe_top_k,
)
self.experts = NomicExperts(
config,
hidden_size=config.n_embd,
ffn_hidden_size=config.n_inner,
moe_num_experts=config.num_experts,
)
def forward(self, x: torch.Tensor):
weights, top_weights, top_experts = self.router(x)
out = self.experts(x, weights, top_weights, top_experts)
return out
class BertWithRopeBlock(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
moe: bool = False,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
self.attn = BertWithRopeAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
bias=bias,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention")
if moe:
self.mlp = NomicMoELayer(config=config, )
else:
if config.hidden_act in ["silu", "gelu_and_mul"]:
self.mlp = BertWithRopeGatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = BertWithRopeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.attn_ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp_ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
attn_output = self.attn(positions, hidden_states)
hidden_states = self.attn_ln(hidden_states + attn_output)
mlp_out = self.mlp(hidden_states)
hidden_states = self.mlp_ln(hidden_states + mlp_out)
return hidden_states
@support_torch_compile
class BertWithRopeEncoder(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
bias: bool = True,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
every_n = getattr(config, "moe_every_n_layers", 0)
self.layers = nn.ModuleList([
BertWithRopeBlock(config=config,
cache_config=cache_config,
quant_config=quant_config,
bias=bias,
moe=every_n > 0 and (layer_idx % every_n == 1),
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
for layer in self.layers:
hidden_states = layer(positions, hidden_states)
return hidden_states
class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = self.config_verify(vllm_config)
self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder(
vllm_config=vllm_config,
bias=getattr(self.config, "bias", True),
rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder")
def config_verify(self, vllm_config):
raise NotImplementedError
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids=input_ids,
token_type_ids=token_type_ids)
return self.encoder(positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
if self.config.hidden_act in ["silu", "gelu_and_mul"]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
else:
stacked_params_mapping = []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "pooler" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class NomicBertModel(BertWithRope):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm",
"attn.Wqkv": "attn.qkv_proj",
"norm1": "attn_ln",
"mlp.fc1.": "mlp.up_proj.",
"mlp.fc11": "mlp.up_proj",
"mlp.fc12": "mlp.gate_proj",
"mlp.fc2": "mlp.down_proj",
"norm2": "mlp_ln",
})
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"]
if config.activation_function == "swiglu":
config.hidden_act = "silu"
else:
config.hidden_act = config.activation_function
assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
config.qkv_proj_bias)
config.bias = config.qkv_proj_bias
assert config.rotary_emb_scale_base is None
assert not config.rotary_emb_interleaved
config.layer_norm_eps = config.layer_norm_epsilon
config.intermediate_size = config.n_inner
config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = head_dim * config.rotary_emb_fraction
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": config.max_trained_positions,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
# we ignore config.rotary_scaling_factor so that for datasets shorter
# than max_trained_positions 2048, the results are consistent
# with SentenceTransformer.
# The context extension uses vllm style rope_theta and rope_scaling.
# See #17785
return config
class GteModel(BertWithRope):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"layer": 'layers',
"attention.qkv_proj": "attn.qkv_proj",
"attention.o_proj": "attn.out_proj",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# GteModel only gate_up_proj does not have bias.
# Hack method learned from vllm/model_executor/models/glm.py
for layer in self.encoder.layers:
layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "GteConfig"
assert config.position_embedding_type == "rope"
assert config.hidden_act == "gelu"
config.position_embedding_type = "rotary"
config.hidden_act = "gelu_and_mul"
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj"
for name, weight in weights:
if n in name:
up, gate = weight.chunk(2, dim=0)
yield name.replace(n, "mlp.up_proj"), up
yield name.replace(n, "mlp.gate_proj"), gate
else:
yield name, weight
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
weights = self.split_up_gate_proj(weights)
return super().load_weights(weights)
class JinaRobertaModel(BertWithRope):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"emb_ln": "embeddings.LayerNorm",
"mixer.Wqkv": "attn.qkv_proj",
"mixer.out_proj": "attn.out_proj",
"norm1": "attn_ln",
"mlp.fc1.": "mlp.up_proj.",
"mlp.fc2": "mlp.down_proj",
"norm2": "mlp_ln",
})
def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config
head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return super().forward(input_ids=input_ids,
positions=position_ids,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
token_type_ids=token_type_ids)
@torch.inference_mode()
def jina_merge_lora_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]):
# use for jina-embeddings-v3
# Merge Lora weights into a single weight tensor.
# This is a temporary solution until we have a better way to handle
scaling = self.config.lora_alpha / self.config.lora_rank
weights = {name: weight for name, weight in weights}
o = ".original"
a = ".0.lora_A"
b = ".0.lora_B"
# text-matching
i = -1
for name in list(weights.keys()):
if o in name:
dtype = weights[name].dtype
shape = weights[name].shape
weight_name = name[:-len(o)]
if "embeddings" in weight_name:
B = weights[weight_name + a][i].cuda().float()
A = weights[weight_name + b][i].cuda().float()
else:
B = weights[weight_name + b][i].cuda().float()
A = weights[weight_name + a][i].cuda().float()
weight = (weights[weight_name + o].cuda() +
torch.matmul(B, A).view(shape) * scaling)
weight = weight.cpu().to(dtype)
weights[weight_name.replace(".parametrizations", "")] = weight
del weights[weight_name + o], weights[weight_name +
a], weights[weight_name +
b]
return [(name, weight) for name, weight in weights.items()]
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
weights = self.jina_merge_lora_weights(weights)
return super().load_weights(weights)

View File

@ -126,7 +126,7 @@ _EMBEDDING_MODELS = {
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"), "GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert", "GteEmbeddingModel"), "GteModel": ("bert_with_rope", "GteModel"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaModel": ("llama", "LlamaForCausalLM"), "LlamaModel": ("llama", "LlamaForCausalLM"),
@ -136,7 +136,7 @@ _EMBEDDING_MODELS = {
if arch == "LlamaForCausalLM" if arch == "LlamaForCausalLM"
}, },
"MistralModel": ("llama", "LlamaForCausalLM"), "MistralModel": ("llama", "LlamaForCausalLM"),
"NomicBertModel": ("bert", "NomicBertEmbeddingModel"), "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import itertools import itertools
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
get_cross_encoder_activation_function) get_cross_encoder_activation_function)
from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsV0Only
@ -125,39 +126,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def _build_model(self, def _build_model(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "") -> BertModel: prefix: str = "") -> Union[BertModel, BertWithRope]:
if (vllm_config.model_config.hf_config.position_embedding_type == if (vllm_config.model_config.hf_config.position_embedding_type ==
"rotary"): "rotary"):
config = vllm_config.model_config.hf_config return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rotary_emb_base,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return BertModel(vllm_config=vllm_config,
rotary_kwargs=rotary_kwargs,
prefix=prefix)
else: else:
return BertModel(vllm_config=vllm_config, return BertModel(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
embedding_class=RobertaEmbedding) embedding_class=RobertaEmbedding)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if getattr(self.config, "lora_rank", 0) > 0:
scaling = self.config.lora_alpha / self.config.lora_rank
weights = jina_merge_lora_weights(weights, scaling)
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
# Separate weights in "roberta"-prefixed and all else (not in memory). # Separate weights in "roberta"-prefixed and all else (not in memory).
# For use with models like FacebookAI/roberta-base. # For use with models like FacebookAI/roberta-base.
bert_weights, task_weights = roberta_task_weights_filter(weights) bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = jina_to_vllm_mapper.apply(bert_weights)
loaded = self.model.load_weights(bert_weights) loaded = self.model.load_weights(bert_weights)
if not len(loaded): if not len(loaded):
# Fix for models like `sentence-transformers/stsb-roberta-base-v2` # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
@ -178,6 +160,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
jina_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
'emb_ln': "embeddings.LayerNorm",
'layers': "layer",
'mixer.Wqkv': "attention.self.qkv_proj",
'mixer.out_proj': "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc1': "intermediate.dense",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@ -195,7 +189,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights) bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = jina_to_vllm_mapper.apply(bert_weights) bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
self.roberta.load_weights(bert_weights) self.roberta.load_weights(bert_weights)
@ -276,57 +270,3 @@ def roberta_task_weights_filter(
return encoder_decoder_weights(), ((n, w) for n, w in all_weights2 return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
if not n.startswith("roberta.")) if not n.startswith("roberta."))
jina_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
'emb_ln': "embeddings.LayerNorm",
'layers': "layer",
'mixer.Wqkv': "attention.self.qkv_proj",
'mixer.out_proj': "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc1': "intermediate.dense",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
@torch.inference_mode()
def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
scaling: float = 1.0):
# use for jina-embeddings-v3
# Merge Lora weights into a single weight tensor.
# This is a temporary solution until we have a better way to handle
weights = {name: weight for name, weight in weights}
o = ".original"
a = ".0.lora_A"
b = ".0.lora_B"
# text-matching
i = -1
for name in list(weights.keys()):
if o in name:
dtype = weights[name].dtype
shape = weights[name].shape
weight_name = name[:-len(o)]
if "embeddings" in weight_name:
B = weights[weight_name + a][i].cuda().float()
A = weights[weight_name + b][i].cuda().float()
else:
B = weights[weight_name + b][i].cuda().float()
A = weights[weight_name + a][i].cuda().float()
weight = (weights[weight_name + o].cuda() +
torch.matmul(B, A).view(shape) * scaling)
weight = weight.cpu().to(dtype)
weights[weight_name.replace(".parametrizations", "")] = weight
del weights[weight_name + o], weights[weight_name +
a], weights[weight_name + b]
return [(name, weight) for name, weight in weights.items()]