mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 00:54:34 +08:00
[New Model]: nomic-embed-text-v2-moe (#17785)
This commit is contained in:
parent
06c0922a69
commit
e4b8713380
@ -622,7 +622,7 @@ Specified using `--task embed`.
|
||||
* [PP](#distributed-serving)
|
||||
- * `BertModel`
|
||||
* BERT-based
|
||||
* `BAAI/bge-base-en-v1.5`, etc.
|
||||
* `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc.
|
||||
*
|
||||
*
|
||||
- * `Gemma2Model`
|
||||
@ -635,6 +635,16 @@ Specified using `--task embed`.
|
||||
* `parasail-ai/GritLM-7B-vllm`.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `GteModel`
|
||||
* GteModel
|
||||
* `Snowflake/snowflake-arctic-embed-m-v2.0`.
|
||||
*
|
||||
* ︎
|
||||
- * `NomicBertModel`
|
||||
* NomicBertModel
|
||||
* `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
|
||||
* ︎
|
||||
* ︎
|
||||
- * `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc.
|
||||
* Llama-based
|
||||
* `intfloat/e5-mistral-7b-instruct`, etc.
|
||||
@ -647,12 +657,12 @@ Specified using `--task embed`.
|
||||
* ✅︎
|
||||
- * `RobertaModel`, `RobertaForMaskedLM`
|
||||
* RoBERTa-based
|
||||
* `sentence-transformers/all-roberta-large-v1`, `sentence-transformers/all-roberta-large-v1`, etc.
|
||||
* `sentence-transformers/all-roberta-large-v1`, etc.
|
||||
*
|
||||
*
|
||||
- * `XLMRobertaModel`
|
||||
* XLM-RoBERTa-based
|
||||
* `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, etc.
|
||||
* `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, `Snowflake/snowflake-arctic-embed-l-v2.0`, `jinaai/jina-embeddings-v3`(see note), etc.
|
||||
*
|
||||
*
|
||||
:::
|
||||
@ -670,6 +680,10 @@ For both the 1.5B and 7B variants, you also need to enable `--trust-remote-code`
|
||||
See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882).
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
`jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights.
|
||||
:::
|
||||
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
|
||||
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
|
||||
|
||||
111
tests/models/language/pooling/mteb_utils.py
Normal file
111
tests/models/language/pooling/mteb_utils.py
Normal file
@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
|
||||
import mteb
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from tests.models.utils import EmbedModelInfo
|
||||
|
||||
# Most models on the STS12 task (See #17175):
|
||||
# - Model implementation and minor changes in tensor dtype
|
||||
# results in differences less than 1e-4
|
||||
# - Different model results in differences more than 1e-3
|
||||
# 1e-4 is a good tolerance threshold
|
||||
MTEB_EMBED_TASKS = ["STS12"]
|
||||
MTEB_EMBED_TOL = 1e-4
|
||||
|
||||
|
||||
class VllmMtebEncoder(mteb.Encoder):
|
||||
|
||||
def __init__(self, vllm_model):
|
||||
super().__init__()
|
||||
self.model = vllm_model
|
||||
self.rng = np.random.default_rng(seed=42)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sentences: Sequence[str],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
# Hoping to discover potential scheduling
|
||||
# issues by randomizing the order.
|
||||
r = self.rng.permutation(len(sentences))
|
||||
sentences = [sentences[i] for i in r]
|
||||
outputs = self.model.encode(sentences, use_tqdm=False)
|
||||
embeds = np.array(outputs)
|
||||
embeds = embeds[np.argsort(r)]
|
||||
return embeds
|
||||
|
||||
|
||||
class OpenAIClientMtebEncoder(mteb.Encoder):
|
||||
|
||||
def __init__(self, model_name: str, client):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.client = client
|
||||
self.rng = np.random.default_rng(seed=42)
|
||||
|
||||
def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray:
|
||||
# Hoping to discover potential scheduling
|
||||
# issues by randomizing the order.
|
||||
r = self.rng.permutation(len(sentences))
|
||||
sentences = [sentences[i] for i in r]
|
||||
|
||||
embeddings = self.client.embeddings.create(model=self.model_name,
|
||||
input=sentences)
|
||||
outputs = [d.embedding for d in embeddings.data]
|
||||
embeds = np.array(outputs)
|
||||
embeds = embeds[np.argsort(r)]
|
||||
return embeds
|
||||
|
||||
|
||||
def run_mteb_embed_task(encoder, tasks):
|
||||
tasks = mteb.get_tasks(tasks=tasks)
|
||||
evaluation = mteb.MTEB(tasks=tasks)
|
||||
results = evaluation.run(encoder, verbosity=0, output_folder=None)
|
||||
|
||||
main_score = results[0].scores["test"][0]["main_score"]
|
||||
return main_score
|
||||
|
||||
|
||||
def run_mteb_embed_task_st(model_name, tasks):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer(model_name)
|
||||
return run_mteb_embed_task(model, tasks)
|
||||
|
||||
|
||||
def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
task="embed",
|
||||
max_model_len=None,
|
||||
dtype=model_info.dtype) as vllm_model:
|
||||
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture
|
||||
in vllm_model.model.llm_engine.model_config.architectures)
|
||||
|
||||
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
||||
MTEB_EMBED_TASKS)
|
||||
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
|
||||
model_dtype = getattr(
|
||||
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
|
||||
vllm_dtype)
|
||||
|
||||
with hf_runner(model_info.name,
|
||||
is_sentence_transformer=True,
|
||||
dtype=model_dtype) as hf_model:
|
||||
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
|
||||
|
||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
||||
print("SentenceTransformer:", model_dtype, st_main_score)
|
||||
print("Difference:", st_main_score - vllm_main_score)
|
||||
|
||||
assert math.isclose(st_main_score, vllm_main_score, rel_tol=MTEB_EMBED_TOL)
|
||||
47
tests/models/language/pooling/test_nomic.py
Normal file
47
tests/models/language/pooling/test_nomic.py
Normal file
@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from ...utils import EmbedModelInfo, run_embedding_correctness_test
|
||||
|
||||
MODELS = [
|
||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
|
||||
architecture="NomicBertModel",
|
||||
dtype="float32",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
|
||||
architecture="NomicBertModel",
|
||||
dtype="float32",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
|
||||
architecture="NomicBertModel",
|
||||
dtype="float32",
|
||||
enable_test=True)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
def test_models_mteb(hf_runner, vllm_runner,
|
||||
model_info: EmbedModelInfo) -> None:
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
|
||||
example_prompts) -> None:
|
||||
if not model_info.enable_test:
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
task="embed",
|
||||
dtype=model_info.dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.encode(example_prompts)
|
||||
|
||||
with hf_runner(
|
||||
model_info.name,
|
||||
dtype=model_info.dtype,
|
||||
is_sentence_transformer=True,
|
||||
) as hf_model:
|
||||
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
|
||||
@ -1,12 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from ...utils import EmbedModelInfo, check_embeddings_close
|
||||
|
||||
EMBEDDING_PROMPTS = [
|
||||
'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
|
||||
'Mexico City of Course!'
|
||||
]
|
||||
from ...utils import EmbedModelInfo, run_embedding_correctness_test
|
||||
|
||||
MODELS = [
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
|
||||
@ -45,51 +41,34 @@ MODELS = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models(
|
||||
def test_models_mteb(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model_info: EmbedModelInfo,
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
def test_models_correctness(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model_info: EmbedModelInfo,
|
||||
example_prompts,
|
||||
) -> None:
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
example_prompts = example_prompts + EMBEDDING_PROMPTS
|
||||
|
||||
vllm_extra_kwargs = {
|
||||
"hf_overrides": {
|
||||
"is_matryoshka": model_info.is_matryoshka
|
||||
}
|
||||
}
|
||||
|
||||
with hf_runner(model_info.name, dtype=dtype,
|
||||
is_sentence_transformer=True) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
task="embed",
|
||||
dtype=dtype,
|
||||
max_model_len=None,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
|
||||
assert (vllm_model.model.llm_engine.model_config.is_matryoshka ==
|
||||
model_info.is_matryoshka)
|
||||
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture
|
||||
in vllm_model.model.llm_engine.model_config.architectures)
|
||||
|
||||
dtype=model_info.dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.encode(example_prompts)
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
with hf_runner(
|
||||
model_info.name,
|
||||
dtype=model_info.dtype,
|
||||
is_sentence_transformer=True,
|
||||
) as hf_model:
|
||||
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
|
||||
|
||||
@ -332,9 +332,10 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
|
||||
|
||||
class EmbedModelInfo(NamedTuple):
|
||||
name: str
|
||||
is_matryoshka: bool
|
||||
is_matryoshka: bool = False
|
||||
matryoshka_dimensions: Optional[list[int]] = None
|
||||
architecture: str = ""
|
||||
dtype: str = "auto"
|
||||
enable_test: bool = True
|
||||
|
||||
|
||||
|
||||
@ -11,16 +11,13 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
||||
get_act_fn)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
|
||||
PoolingType)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -41,24 +38,19 @@ class BertEmbedding(nn.Module):
|
||||
self.size = config.hidden_size
|
||||
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
|
||||
self.position_embeddings = VocabParallelEmbedding(
|
||||
config.max_position_embeddings, config.hidden_size)
|
||||
self.token_type_embeddings = VocabParallelEmbedding(
|
||||
config.type_vocab_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.position_ids = nn.Parameter(
|
||||
torch.empty((1, config.max_position_embeddings)), )
|
||||
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
if self.position_embedding_type == "absolute":
|
||||
self.position_embeddings = VocabParallelEmbedding(
|
||||
config.max_position_embeddings, config.hidden_size)
|
||||
self.position_ids = nn.Parameter(
|
||||
torch.empty((1, config.max_position_embeddings)), )
|
||||
elif self.position_embedding_type == "rotary":
|
||||
self.position_embeddings = None
|
||||
self.position_ids = None
|
||||
else:
|
||||
raise ValueError("Only 'absolute' and 'rotary' " +
|
||||
"position_embedding_type is supported")
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError("Only 'absolute' position_embedding_type" +
|
||||
" is supported")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -72,6 +64,9 @@ class BertEmbedding(nn.Module):
|
||||
# Input embeddings.
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
# Position embeddings.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape,
|
||||
dtype=torch.long,
|
||||
@ -79,12 +74,7 @@ class BertEmbedding(nn.Module):
|
||||
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = inputs_embeds + token_type_embeddings
|
||||
|
||||
if self.position_embedding_type == "absolute":
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings += position_embeddings
|
||||
|
||||
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
return embeddings
|
||||
|
||||
@ -108,11 +98,7 @@ class BertPooler(nn.Module):
|
||||
@support_torch_compile
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
@ -121,19 +107,16 @@ class BertEncoder(nn.Module):
|
||||
BertLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layer:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
hidden_states = layer(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -143,8 +126,6 @@ class BertLayer(nn.Module):
|
||||
config: BertConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@ -154,36 +135,23 @@ class BertLayer(nn.Module):
|
||||
layer_norm_eps=config.layer_norm_eps,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
if config.hidden_act in ["silu", "gelu_and_mul"]:
|
||||
self.intermediate = BertGatedIntermediate(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.intermediate")
|
||||
else:
|
||||
self.intermediate = BertIntermediate(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.intermediate")
|
||||
self.intermediate = BertIntermediate(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.intermediate")
|
||||
|
||||
self.output = BertOutput(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
layer_norm_eps=config.layer_norm_eps,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
|
||||
attn_output = self.attention(positions, hidden_states)
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
attn_output = self.attention(hidden_states)
|
||||
intermediate_output = self.intermediate(attn_output)
|
||||
output = self.output(intermediate_output, attn_output)
|
||||
return output
|
||||
@ -198,8 +166,6 @@ class BertAttention(nn.Module):
|
||||
layer_norm_eps: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@ -208,22 +174,18 @@ class BertAttention(nn.Module):
|
||||
num_attention_heads=num_attention_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
self.output = BertSelfOutput(hidden_size=hidden_size,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
self_output = self.self(positions, hidden_states)
|
||||
self_output = self.self(hidden_states)
|
||||
return self.output(self_output, hidden_states)
|
||||
|
||||
|
||||
@ -235,8 +197,6 @@ class BertSelfAttention(nn.Module):
|
||||
num_attention_heads: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@ -261,15 +221,10 @@ class BertSelfAttention(nn.Module):
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
if rotary_kwargs:
|
||||
self.rotary_emb = get_rope(**rotary_kwargs)
|
||||
else:
|
||||
self.rotary_emb = None
|
||||
|
||||
self.attn = Attention(num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
@ -281,15 +236,10 @@ class BertSelfAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
if self.rotary_emb:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
output = self.attn(q, k, v)
|
||||
return output
|
||||
|
||||
@ -299,13 +249,12 @@ class BertSelfOutput(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
layer_norm_eps: float,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.dense = RowParallelLinear(input_size=hidden_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
@ -323,13 +272,12 @@ class BertIntermediate(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.dense = ColumnParallelLinear(input_size=hidden_size,
|
||||
output_size=intermediate_size,
|
||||
bias=bias,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
self.intermediate_act_fn = get_act_fn(hidden_act)
|
||||
@ -340,46 +288,19 @@ class BertIntermediate(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertGatedIntermediate(nn.Module):
|
||||
# for NomciBert and GteModel
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.act_fn = get_act_and_mul_fn(hidden_act)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
gate_up, _ = self.gate_up_proj(hidden_states)
|
||||
hidden_states = self.act_fn(gate_up)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOutput(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
layer_norm_eps: float,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.dense = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
|
||||
@ -393,33 +314,18 @@ class BertOutput(nn.Module):
|
||||
|
||||
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["query", "key", "value"],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type = BertEmbedding,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
add_pooling_layer: bool = False):
|
||||
super().__init__()
|
||||
"""
|
||||
For BertModel, all linear layers have bias.
|
||||
For NomicBertModel, all linear layers do not have bias.
|
||||
"""
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.embeddings = embedding_class(config)
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
|
||||
@ -441,7 +347,7 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
seq_lens=attn_metadata.seq_lens_tensor,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
return self.encoder(position_ids, hidden_states)
|
||||
return self.encoder(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
@ -450,8 +356,6 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
("qkv_proj", "query", "q"),
|
||||
("qkv_proj", "key", "k"),
|
||||
("qkv_proj", "value", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
@ -497,7 +401,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = self._build_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = self._build_pooler(pooler_config)
|
||||
@ -611,115 +514,3 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
token_type_ids=token_type_ids)
|
||||
|
||||
|
||||
class NomicBertEmbeddingModel(BertEmbeddingModel):
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"emb_ln": "embeddings.LayerNorm",
|
||||
"layers": "layer",
|
||||
"attn.Wqkv": "attention.self.qkv_proj",
|
||||
"attn.out_proj": "attention.output.dense",
|
||||
'norm1': "attention.output.LayerNorm",
|
||||
'mlp.fc11': "intermediate.up_proj",
|
||||
'mlp.fc12': "intermediate.gate_proj",
|
||||
'mlp.fc2': "output.dense",
|
||||
'norm2': "output.LayerNorm",
|
||||
})
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "NomicBertConfig"
|
||||
assert config.activation_function == "swiglu"
|
||||
|
||||
# Assume NomicBertModel all linear layers do not have bias
|
||||
assert not config.mlp_fc1_bias
|
||||
assert not config.mlp_fc2_bias
|
||||
assert not config.qkv_proj_bias
|
||||
|
||||
config.layer_norm_eps = config.layer_norm_epsilon
|
||||
config.position_embedding_type = "rotary"
|
||||
config.intermediate_size = config.n_inner
|
||||
config.hidden_act = "silu"
|
||||
config.hidden_size = config.n_embd
|
||||
config.num_hidden_layers = config.n_layer
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_trained_positions,
|
||||
"base": config.rotary_emb_base,
|
||||
"rope_scaling": {
|
||||
"rope_type": "dynamic",
|
||||
"factor": config.rotary_scaling_factor
|
||||
}
|
||||
}
|
||||
|
||||
return BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
bias=False,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
embedding_class=BertEmbedding)
|
||||
|
||||
|
||||
class GteEmbeddingModel(BertEmbeddingModel):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"attention.qkv_proj": "attention.self.qkv_proj",
|
||||
"attention.o_proj": "attention.output.dense",
|
||||
'attn_ln': "attention.output.LayerNorm",
|
||||
'mlp.down_proj': "output.dense",
|
||||
'mlp_ln': "output.LayerNorm",
|
||||
})
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "GteConfig"
|
||||
assert config.position_embedding_type == "rope"
|
||||
assert config.hidden_act == "gelu"
|
||||
|
||||
config.position_embedding_type = "rotary"
|
||||
config.hidden_act = "gelu_and_mul"
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": config.rope_theta,
|
||||
}
|
||||
|
||||
model = BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
embedding_class=BertEmbedding)
|
||||
|
||||
# GteModel only gate_up_proj does not have bias.
|
||||
# Hack method learned from vllm/model_executor/models/glm.py
|
||||
for layer in model.encoder.layer:
|
||||
layer.intermediate.gate_up_proj.bias = None
|
||||
layer.intermediate.skip_bias_add = True
|
||||
return model
|
||||
|
||||
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
n = "mlp.up_gate_proj"
|
||||
for name, weight in weights:
|
||||
if n in name:
|
||||
up, gate = weight.chunk(2, dim=0)
|
||||
yield name.replace(n, "intermediate.up_proj"), up
|
||||
yield name.replace(n, "intermediate.gate_proj"), gate
|
||||
else:
|
||||
yield name, weight
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
weights = self.split_up_gate_proj(weights)
|
||||
self.model.load_weights(weights)
|
||||
|
||||
652
vllm/model_executor/models/bert_with_rope.py
Normal file
652
vllm/model_executor/models/bert_with_rope.py
Normal file
@ -0,0 +1,652 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
||||
get_act_fn)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsV0Only
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class BertWithRopeEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
|
||||
super().__init__()
|
||||
assert config.type_vocab_size > 0
|
||||
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.token_type_embeddings = VocabParallelEmbedding(
|
||||
config.type_vocab_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
input_shape = input_ids.size()
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device)
|
||||
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
embeddings = inputs_embeds + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertWithRopeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.total_num_heads = num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
assert self.head_dim * self.total_num_heads == self.hidden_size
|
||||
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.rotary_emb = get_rope(**rotary_kwargs)
|
||||
|
||||
self.attn = Attention(num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=AttentionType.ENCODER_ONLY)
|
||||
|
||||
self.out_proj = RowParallelLinear(input_size=hidden_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BertWithRopeGatedMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.act_fn = get_act_and_mul_fn(hidden_act)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
gate_up, _ = self.gate_up_proj(hidden_states)
|
||||
hidden_states = self.act_fn(gate_up)
|
||||
hidden_states, _ = self.down_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertWithRopeMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.act_fn = get_act_fn(hidden_act)
|
||||
self.up_proj = ColumnParallelLinear(input_size=hidden_size,
|
||||
output_size=intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.up_proj")
|
||||
self.down_proj = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.up_proj(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states, _ = self.down_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NomicRouter(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
|
||||
super().__init__()
|
||||
self.moe_top_k = moe_top_k
|
||||
self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
|
||||
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
|
||||
dim=-1, dtype=torch.float32)
|
||||
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
|
||||
weights = weights.to(x.dtype)
|
||||
top_weights = top_weights.to(x.dtype)
|
||||
return weights, top_weights, top_experts # type: ignore
|
||||
|
||||
|
||||
class NomicExpertMLP(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, ffn_hidden_size: int,
|
||||
moe_num_experts: int, ffn_act_fn: str):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.moe_num_experts = moe_num_experts
|
||||
|
||||
self.w1 = nn.Parameter(
|
||||
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
||||
self.w2 = nn.Parameter(
|
||||
torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
||||
self.activation_fn = get_act_fn(ffn_act_fn)
|
||||
|
||||
def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
|
||||
expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
|
||||
self.hidden_size)[expert_idx]
|
||||
expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
|
||||
self.hidden_size)[expert_idx]
|
||||
|
||||
x1 = x.matmul(expert_w1.t())
|
||||
act_out = self.activation_fn(x1)
|
||||
x2 = act_out.matmul(expert_w2)
|
||||
return x2
|
||||
|
||||
|
||||
class NomicExperts(nn.Module):
|
||||
|
||||
def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
|
||||
moe_num_experts: int):
|
||||
super().__init__()
|
||||
self.moe_num_experts = moe_num_experts
|
||||
|
||||
self.mlp = NomicExpertMLP(hidden_size=config.n_embd,
|
||||
ffn_hidden_size=config.n_inner,
|
||||
moe_num_experts=moe_num_experts,
|
||||
ffn_act_fn=config.hidden_act)
|
||||
self.bias = nn.Parameter(torch.zeros(config.n_embd))
|
||||
|
||||
def forward(self, x: torch.Tensor, weights: torch.Tensor,
|
||||
top_weights: torch.Tensor,
|
||||
top_experts: torch.LongTensor) -> torch.Tensor:
|
||||
q_len, hidden_size = x.shape
|
||||
x = x.view(-1, hidden_size)
|
||||
out = torch.zeros_like(x)
|
||||
|
||||
expert_mask = nn.functional.one_hot(
|
||||
top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
|
||||
for expert_idx in range(0, self.moe_num_experts):
|
||||
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
|
||||
if token_idx.shape[0] == 0:
|
||||
continue
|
||||
|
||||
token_list = token_idx.tolist()
|
||||
topk_list = topk_idx.tolist()
|
||||
|
||||
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
|
||||
expert_out = self.mlp(
|
||||
expert_tokens, expert_idx) * top_weights[token_list, topk_list,
|
||||
None]
|
||||
|
||||
out.index_add_(0, token_idx, expert_out)
|
||||
|
||||
out = out.reshape(q_len, hidden_size)
|
||||
return out + self.bias
|
||||
|
||||
|
||||
class NomicMoELayer(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
|
||||
self.router = NomicRouter(
|
||||
config.n_embd,
|
||||
moe_num_experts=config.num_experts,
|
||||
moe_top_k=config.moe_top_k,
|
||||
)
|
||||
|
||||
self.experts = NomicExperts(
|
||||
config,
|
||||
hidden_size=config.n_embd,
|
||||
ffn_hidden_size=config.n_inner,
|
||||
moe_num_experts=config.num_experts,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
weights, top_weights, top_experts = self.router(x)
|
||||
out = self.experts(x, weights, top_weights, top_experts)
|
||||
return out
|
||||
|
||||
|
||||
class BertWithRopeBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
moe: bool = False,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.attn = BertWithRopeAttention(
|
||||
hidden_size=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
if moe:
|
||||
self.mlp = NomicMoELayer(config=config, )
|
||||
else:
|
||||
if config.hidden_act in ["silu", "gelu_and_mul"]:
|
||||
self.mlp = BertWithRopeGatedMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = BertWithRopeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
self.attn_ln = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp_ln = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
|
||||
attn_output = self.attn(positions, hidden_states)
|
||||
hidden_states = self.attn_ln(hidden_states + attn_output)
|
||||
mlp_out = self.mlp(hidden_states)
|
||||
hidden_states = self.mlp_ln(hidden_states + mlp_out)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class BertWithRopeEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
every_n = getattr(config, "moe_every_n_layers", 0)
|
||||
self.layers = nn.ModuleList([
|
||||
BertWithRopeBlock(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
moe=every_n > 0 and (layer_idx % every_n == 1),
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = self.config_verify(vllm_config)
|
||||
self.embeddings = BertWithRopeEmbedding(self.config)
|
||||
self.encoder = BertWithRopeEncoder(
|
||||
vllm_config=vllm_config,
|
||||
bias=getattr(self.config, "bias", True),
|
||||
rotary_kwargs=self.config.rotary_kwargs,
|
||||
prefix=f"{prefix}.encoder")
|
||||
|
||||
def config_verify(self, vllm_config):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embeddings(input_ids=input_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
return self.encoder(positions, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
|
||||
if self.config.hidden_act in ["silu", "gelu_and_mul"]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
else:
|
||||
stacked_params_mapping = []
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "pooler" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class NomicBertModel(BertWithRope):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"emb_ln": "embeddings.LayerNorm",
|
||||
"attn.Wqkv": "attn.qkv_proj",
|
||||
"norm1": "attn_ln",
|
||||
"mlp.fc1.": "mlp.up_proj.",
|
||||
"mlp.fc11": "mlp.up_proj",
|
||||
"mlp.fc12": "mlp.gate_proj",
|
||||
"mlp.fc2": "mlp.down_proj",
|
||||
"norm2": "mlp_ln",
|
||||
})
|
||||
|
||||
def config_verify(self, vllm_config):
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "NomicBertConfig"
|
||||
assert config.activation_function in ["swiglu", "gelu"]
|
||||
|
||||
if config.activation_function == "swiglu":
|
||||
config.hidden_act = "silu"
|
||||
else:
|
||||
config.hidden_act = config.activation_function
|
||||
|
||||
assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
|
||||
config.qkv_proj_bias)
|
||||
config.bias = config.qkv_proj_bias
|
||||
|
||||
assert config.rotary_emb_scale_base is None
|
||||
assert not config.rotary_emb_interleaved
|
||||
|
||||
config.layer_norm_eps = config.layer_norm_epsilon
|
||||
config.intermediate_size = config.n_inner
|
||||
config.hidden_size = config.n_embd
|
||||
config.num_hidden_layers = config.n_layer
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_emb_dim = head_dim * config.rotary_emb_fraction
|
||||
config.rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": rotary_emb_dim,
|
||||
"max_position": config.max_trained_positions,
|
||||
"base": getattr(config, "rope_theta", config.rotary_emb_base),
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
}
|
||||
|
||||
# we ignore config.rotary_scaling_factor so that for datasets shorter
|
||||
# than max_trained_positions 2048, the results are consistent
|
||||
# with SentenceTransformer.
|
||||
# The context extension uses vllm style rope_theta and rope_scaling.
|
||||
# See #17785
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class GteModel(BertWithRope):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"layer": 'layers',
|
||||
"attention.qkv_proj": "attn.qkv_proj",
|
||||
"attention.o_proj": "attn.out_proj",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
# GteModel only gate_up_proj does not have bias.
|
||||
# Hack method learned from vllm/model_executor/models/glm.py
|
||||
for layer in self.encoder.layers:
|
||||
layer.mlp.gate_up_proj.bias = None
|
||||
layer.mlp.gate_up_proj.skip_bias_add = True
|
||||
|
||||
def config_verify(self, vllm_config):
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "GteConfig"
|
||||
assert config.position_embedding_type == "rope"
|
||||
assert config.hidden_act == "gelu"
|
||||
|
||||
config.position_embedding_type = "rotary"
|
||||
config.hidden_act = "gelu_and_mul"
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
config.rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": config.rope_theta,
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
}
|
||||
return config
|
||||
|
||||
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
n = "mlp.up_gate_proj"
|
||||
for name, weight in weights:
|
||||
if n in name:
|
||||
up, gate = weight.chunk(2, dim=0)
|
||||
yield name.replace(n, "mlp.up_proj"), up
|
||||
yield name.replace(n, "mlp.gate_proj"), gate
|
||||
else:
|
||||
yield name, weight
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
weights = self.split_up_gate_proj(weights)
|
||||
return super().load_weights(weights)
|
||||
|
||||
|
||||
class JinaRobertaModel(BertWithRope):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"emb_ln": "embeddings.LayerNorm",
|
||||
"mixer.Wqkv": "attn.qkv_proj",
|
||||
"mixer.out_proj": "attn.out_proj",
|
||||
"norm1": "attn_ln",
|
||||
"mlp.fc1.": "mlp.up_proj.",
|
||||
"mlp.fc2": "mlp.down_proj",
|
||||
"norm2": "mlp_ln",
|
||||
})
|
||||
|
||||
def config_verify(self, vllm_config):
|
||||
config = vllm_config.model_config.hf_config
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
config.rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": getattr(config, "rope_theta", config.rotary_emb_base),
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
}
|
||||
return config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
token_type_ids=token_type_ids)
|
||||
|
||||
@torch.inference_mode()
|
||||
def jina_merge_lora_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]):
|
||||
# use for jina-embeddings-v3
|
||||
# Merge Lora weights into a single weight tensor.
|
||||
# This is a temporary solution until we have a better way to handle
|
||||
|
||||
scaling = self.config.lora_alpha / self.config.lora_rank
|
||||
|
||||
weights = {name: weight for name, weight in weights}
|
||||
|
||||
o = ".original"
|
||||
a = ".0.lora_A"
|
||||
b = ".0.lora_B"
|
||||
|
||||
# text-matching
|
||||
i = -1
|
||||
|
||||
for name in list(weights.keys()):
|
||||
if o in name:
|
||||
dtype = weights[name].dtype
|
||||
shape = weights[name].shape
|
||||
weight_name = name[:-len(o)]
|
||||
|
||||
if "embeddings" in weight_name:
|
||||
B = weights[weight_name + a][i].cuda().float()
|
||||
A = weights[weight_name + b][i].cuda().float()
|
||||
else:
|
||||
B = weights[weight_name + b][i].cuda().float()
|
||||
A = weights[weight_name + a][i].cuda().float()
|
||||
|
||||
weight = (weights[weight_name + o].cuda() +
|
||||
torch.matmul(B, A).view(shape) * scaling)
|
||||
weight = weight.cpu().to(dtype)
|
||||
|
||||
weights[weight_name.replace(".parametrizations", "")] = weight
|
||||
|
||||
del weights[weight_name + o], weights[weight_name +
|
||||
a], weights[weight_name +
|
||||
b]
|
||||
|
||||
return [(name, weight) for name, weight in weights.items()]
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
weights = self.jina_merge_lora_weights(weights)
|
||||
return super().load_weights(weights)
|
||||
@ -126,7 +126,7 @@ _EMBEDDING_MODELS = {
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"GritLM": ("gritlm", "GritLM"),
|
||||
"GteModel": ("bert", "GteEmbeddingModel"),
|
||||
"GteModel": ("bert_with_rope", "GteModel"),
|
||||
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
@ -136,7 +136,7 @@ _EMBEDDING_MODELS = {
|
||||
if arch == "LlamaForCausalLM"
|
||||
},
|
||||
"MistralModel": ("llama", "LlamaForCausalLM"),
|
||||
"NomicBertModel": ("bert", "NomicBertEmbeddingModel"),
|
||||
"NomicBertModel": ("bert_with_rope", "NomicBertModel"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from typing import Iterable, Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||
|
||||
|
||||
@ -125,39 +126,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
prefix: str = "") -> Union[BertModel, BertWithRope]:
|
||||
if (vllm_config.model_config.hf_config.position_embedding_type ==
|
||||
"rotary"):
|
||||
config = vllm_config.model_config.hf_config
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
|
||||
rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": config.rotary_emb_base,
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
}
|
||||
|
||||
return BertModel(vllm_config=vllm_config,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=prefix)
|
||||
return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
|
||||
else:
|
||||
return BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
embedding_class=RobertaEmbedding)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
if getattr(self.config, "lora_rank", 0) > 0:
|
||||
scaling = self.config.lora_alpha / self.config.lora_rank
|
||||
weights = jina_merge_lora_weights(weights, scaling)
|
||||
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
# Separate weights in "roberta"-prefixed and all else (not in memory).
|
||||
# For use with models like FacebookAI/roberta-base.
|
||||
bert_weights, task_weights = roberta_task_weights_filter(weights)
|
||||
bert_weights = jina_to_vllm_mapper.apply(bert_weights)
|
||||
|
||||
loaded = self.model.load_weights(bert_weights)
|
||||
if not len(loaded):
|
||||
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
|
||||
@ -178,6 +160,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
jina_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
'emb_ln': "embeddings.LayerNorm",
|
||||
'layers': "layer",
|
||||
'mixer.Wqkv': "attention.self.qkv_proj",
|
||||
'mixer.out_proj': "attention.output.dense",
|
||||
'norm1': "attention.output.LayerNorm",
|
||||
'mlp.fc1': "intermediate.dense",
|
||||
'mlp.fc2': "output.dense",
|
||||
'norm2': "output.LayerNorm",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -195,7 +189,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
bert_weights, task_weights = roberta_task_weights_filter(weights)
|
||||
bert_weights = jina_to_vllm_mapper.apply(bert_weights)
|
||||
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
|
||||
|
||||
self.roberta.load_weights(bert_weights)
|
||||
|
||||
@ -276,57 +270,3 @@ def roberta_task_weights_filter(
|
||||
|
||||
return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
|
||||
if not n.startswith("roberta."))
|
||||
|
||||
|
||||
jina_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
'emb_ln': "embeddings.LayerNorm",
|
||||
'layers': "layer",
|
||||
'mixer.Wqkv': "attention.self.qkv_proj",
|
||||
'mixer.out_proj': "attention.output.dense",
|
||||
'norm1': "attention.output.LayerNorm",
|
||||
'mlp.fc1': "intermediate.dense",
|
||||
'mlp.fc2': "output.dense",
|
||||
'norm2': "output.LayerNorm",
|
||||
})
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
scaling: float = 1.0):
|
||||
# use for jina-embeddings-v3
|
||||
# Merge Lora weights into a single weight tensor.
|
||||
# This is a temporary solution until we have a better way to handle
|
||||
|
||||
weights = {name: weight for name, weight in weights}
|
||||
|
||||
o = ".original"
|
||||
a = ".0.lora_A"
|
||||
b = ".0.lora_B"
|
||||
|
||||
# text-matching
|
||||
i = -1
|
||||
|
||||
for name in list(weights.keys()):
|
||||
if o in name:
|
||||
dtype = weights[name].dtype
|
||||
shape = weights[name].shape
|
||||
weight_name = name[:-len(o)]
|
||||
|
||||
if "embeddings" in weight_name:
|
||||
B = weights[weight_name + a][i].cuda().float()
|
||||
A = weights[weight_name + b][i].cuda().float()
|
||||
else:
|
||||
B = weights[weight_name + b][i].cuda().float()
|
||||
A = weights[weight_name + a][i].cuda().float()
|
||||
|
||||
weight = (weights[weight_name + o].cuda() +
|
||||
torch.matmul(B, A).view(shape) * scaling)
|
||||
weight = weight.cpu().to(dtype)
|
||||
|
||||
weights[weight_name.replace(".parametrizations", "")] = weight
|
||||
|
||||
del weights[weight_name + o], weights[weight_name +
|
||||
a], weights[weight_name + b]
|
||||
|
||||
return [(name, weight) for name, weight in weights.items()]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user