[New Model]: jinaai/jina-embeddings-v3 (#16120)
parent 90cb44eb02
commit 1f5d13ab9f

examples/offline_inference/embed_jina_embeddings_v3.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from argparse import Namespace
+
+from vllm import LLM, EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def main(args: Namespace):
+    # Sample prompts.
+    prompts = [
+        "Follow the white rabbit.",  # English
+        "Sigue al conejo blanco.",  # Spanish
+        "Suis le lapin blanc.",  # French
+        "跟着白兔走。",  # Chinese
+        "اتبع الأرنب الأبيض.",  # Arabic
+        "Folge dem weißen Kaninchen.",  # German
+    ]
+
+    # Create an LLM.
+    # You should pass task="embed" for embedding models
+    model = LLM(**vars(args))
+
+    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
+    # Only text matching task is supported for now. See #16120
+    outputs = model.embed(prompts)
+
+    # Print the outputs.
+    print("\nGenerated Outputs:")
+    print("Only text matching task is supported for now. See #16120")
+    print("-" * 60)
+    for prompt, output in zip(prompts, outputs):
+        embeds = output.outputs.embedding
+        embeds_trimmed = ((str(embeds[:16])[:-1] +
+                           ", ...]") if len(embeds) > 16 else embeds)
+        print(f"Prompt: {prompt!r} \n"
+              f"Embeddings for text matching: {embeds_trimmed} "
+              f"(size={len(embeds)})")
+        print("-" * 60)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="jinaai/jina-embeddings-v3",
+                        task="embed",
+                        trust_remote_code=True)
+    args = parser.parse_args()
+    main(args)
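Editor's note (not part of the diff): the example above is driven entirely by EngineArgs CLI flags. A minimal sketch of the same flow without the argument parser, assuming only the defaults the example itself sets, would be:

# Minimal sketch; mirrors the API calls used in the example above.
from vllm import LLM

llm = LLM(model="jinaai/jina-embeddings-v3",
          task="embed",
          trust_remote_code=True)
outputs = llm.embed(["Follow the white rabbit."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality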
@@ -671,8 +671,9 @@ class HfRunner:
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]

-    def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]:
-        return self.model.encode(prompts)
+    def encode(self, prompts: list[str], *args,
+               **kwargs) -> list[list[torch.Tensor]]:
+        return self.model.encode(prompts, *args, **kwargs)

     def predict(self, prompts: list[list[str]]) -> torch.Tensor:
         return self.model.predict(prompts, convert_to_tensor=True)

@@ -2,13 +2,15 @@
 # ruff: noqa: E501
 """Compare the scoring outputs of HF and vLLM models.

-Run `pytest tests/models/embedding/language/test_jina_reranker_v2.py`.
+Run `pytest tests/models/embedding/language/test_jina.py`.
 """
 import math

 import pytest

-MODELS = [
+from tests.models.embedding.utils import check_embeddings_close
+
+SCORING_MODELS = [
     "jinaai/jina-reranker-v2-base-multilingual",  # Roberta
 ]

@@ -27,8 +29,21 @@ TEXTS_2 = [
     "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
 ]

+EMBEDDING_MODELS = [
+    "jinaai/jina-embeddings-v3",
+]
+
-@pytest.fixture(scope="module", params=MODELS)
+EMBEDDING_PROMPTS = [
+    "Follow the white rabbit.",  # English
+    "Sigue al conejo blanco.",  # Spanish
+    "Suis le lapin blanc.",  # French
+    "跟着白兔走。",  # Chinese
+    "اتبع الأرنب الأبيض.",  # Arabic
+    "Folge dem weißen Kaninchen.",  # German
+]
+

+@pytest.fixture(scope="module", params=SCORING_MODELS)
 def model_name(request):
     yield request.param

@@ -68,3 +83,46 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):

     assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
     assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
+
+
+@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
+def emb_model_name(request):
+    yield request.param
+
+
+def test_is_matryoshka(vllm_runner, emb_model_name):
+    with vllm_runner(emb_model_name, task="embed",
+                     max_model_len=None) as vllm_model:
+        assert vllm_model.model.llm_engine.model_config.is_matryoshka
+
+
+@pytest.mark.parametrize("model", EMBEDDING_MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_embeddings(
+    hf_runner,
+    vllm_runner,
+    model,
+    dtype: str,
+    monkeypatch,
+) -> None:
+
+    example_prompts = EMBEDDING_PROMPTS
+
+    with hf_runner(
+            model,
+            dtype=dtype,
+            is_sentence_transformer=True,
+    ) as hf_model:
+        hf_outputs = hf_model.encode(example_prompts, task="text-matching")
+
+    with vllm_runner(model, task="embed", dtype=dtype,
+                     max_model_len=None) as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
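Editor's note (not part of the diff): test_embeddings compares sentence-transformers output for task="text-matching" against vLLM's embed path with tol=1e-2. The real helper is check_embeddings_close from tests/models/embedding/utils.py; the function below is only an illustrative sketch of that kind of cosine-similarity check, not the actual implementation.

import torch

def cosine_close(a: list[float], b: list[float], tol: float = 1e-2) -> bool:
    # Two embeddings are "close" if their cosine similarity is within tol of 1.
    a_t = torch.tensor(a, dtype=torch.float32)
    b_t = torch.tensor(b, dtype=torch.float32)
    sim = torch.nn.functional.cosine_similarity(a_t, b_t, dim=0)
    return bool(1.0 - sim.item() <= tol)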
@@ -1130,6 +1130,11 @@ class ModelConfig:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_v1_compatible(architectures)

+    @property
+    def is_matryoshka(self) -> bool:
+        return (hasattr(self.hf_config, "matryoshka_dimensions")
+                or getattr(self.hf_config, "is_matryoshka", False))
+

 class CacheConfig:
     """Configuration for the KV cache.
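Editor's note (not part of the diff): is_matryoshka only reports whether the HF config advertises Matryoshka embeddings (a matryoshka_dimensions list or an is_matryoshka flag). Truncating to a smaller dimension is left to the caller; a minimal client-side sketch, with 256 chosen arbitrarily for illustration:

import torch

def shrink(embedding: list[float], dim: int = 256) -> torch.Tensor:
    # Keep the leading `dim` components of a Matryoshka embedding and re-normalize.
    vec = torch.tensor(embedding[:dim], dtype=torch.float32)
    return vec / vec.norm()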
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
         self.size = config.hidden_size
         self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
-        self.position_embeddings = VocabParallelEmbedding(
-            config.max_position_embeddings, config.hidden_size)
+
         self.token_type_embeddings = VocabParallelEmbedding(
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )

         self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError("Only 'absolute' position_embedding_type" +
-                             " is supported")
+        if self.position_embedding_type == "absolute":
+            self.position_embeddings = VocabParallelEmbedding(
+                config.max_position_embeddings, config.hidden_size)
+            self.position_ids = nn.Parameter(
+                torch.empty((1, config.max_position_embeddings)), )
+        elif self.position_embedding_type == "rotary":
+            self.position_embeddings = None
+            self.position_ids = None
+        else:
+            raise ValueError("Only 'absolute' and 'rotary' " +
+                             "position_embedding_type is supported")

     def forward(
         self,
@@ -64,9 +70,6 @@
         # Input embeddings.
         inputs_embeds = self.word_embeddings(input_ids)

-        # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
-
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape,
                                          dtype=torch.long,
@@ -74,7 +77,12 @@

         token_type_embeddings = self.token_type_embeddings(token_type_ids)

-        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = inputs_embeds + token_type_embeddings
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+
         embeddings = self.LayerNorm(embeddings)
         return embeddings

@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
 @support_torch_compile
 class BertEncoder(nn.Module):

-    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rotary_kwargs: Optional[dict] = None,
+                 prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -107,16 +118,18 @@
             BertLayer(config=config,
                       cache_config=cache_config,
                       quant_config=quant_config,
+                      rotary_kwargs=rotary_kwargs,
                       prefix=f"{prefix}.layer.{layer_idx}")
             for layer_idx in range(config.num_hidden_layers)
         ])

     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         for layer in self.layer:
-            hidden_states = layer(hidden_states)
+            hidden_states = layer(positions, hidden_states)
         return hidden_states

@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
                  config: BertConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
+                 rotary_kwargs: Optional[dict] = None,
                  prefix: str = ""):
         super().__init__()

@@ -135,6 +149,7 @@
             layer_norm_eps=config.layer_norm_eps,
             cache_config=cache_config,
             quant_config=quant_config,
+            rotary_kwargs=rotary_kwargs,
             prefix=f"{prefix}.attention")

         self.intermediate = BertIntermediate(
@@ -150,8 +165,8 @@
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.output")

-    def forward(self, hidden_states: torch.Tensor):
-        attn_output = self.attention(hidden_states)
+    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
+        attn_output = self.attention(positions, hidden_states)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
         return output
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
         layer_norm_eps: float,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -174,6 +190,7 @@
                                       num_attention_heads=num_attention_heads,
                                       cache_config=cache_config,
                                       quant_config=quant_config,
+                                      rotary_kwargs=rotary_kwargs,
                                       prefix=f"{prefix}.output")

         self.output = BertSelfOutput(hidden_size=hidden_size,
@@ -183,9 +200,10 @@

     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        self_output = self.self(hidden_states)
+        self_output = self.self(positions, hidden_states)
         return self.output(self_output, hidden_states)

@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
         num_attention_heads: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -225,6 +244,11 @@
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.qkv_proj")

+        if rotary_kwargs:
+            self.rotary_emb = get_rope(**rotary_kwargs)
+        else:
+            self.rotary_emb = None
+
         self.attn = Attention(num_heads=self.num_heads,
                               head_size=self.head_dim,
                               scale=self.scaling,
@@ -236,10 +260,15 @@

     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        if self.rotary_emb:
+            q, k = self.rotary_emb(positions, q, k)
+
         output = self.attn(q, k, v)
         return output

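Editor's note (not part of the diff): the hunks above thread positions from BertEncoder down to BertSelfAttention so that, when rotary_kwargs is configured, get_rope rotates q and k before attention. vLLM's rotary kernel is more general than this, but as a self-contained reference, a plain per-head rotation looks roughly like:

import torch

def rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: [num_tokens, head_dim] (head_dim even); positions: [num_tokens].
    # Each channel pair is rotated by an angle that grows with position.
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = positions.to(torch.float32)[:, None] * inv_freq[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out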
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
                  vllm_config: VllmConfig,
                  prefix: str = "",
                  embedding_class: type = BertEmbedding,
+                 rotary_kwargs: Optional[dict] = None,
                  add_pooling_layer: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         self.embeddings = embedding_class(config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
+                                   rotary_kwargs=rotary_kwargs,
                                    prefix=f"{prefix}.encoder")
         self.pooler = BertPooler(config) if add_pooling_layer else None

@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
                                            seq_lens=attn_metadata.seq_lens_tensor,
                                            position_ids=position_ids,
                                            token_type_ids=token_type_ids)
-        return self.encoder(hidden_states)
+        return self.encoder(position_ids, hidden_states)

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         pooler_config = vllm_config.model_config.pooler_config
+        self.config = vllm_config.model_config.hf_config
         self.model = self._build_model(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
         self._pooler = self._build_pooler(pooler_config)
@@ -22,30 +22,6 @@ from vllm.transformers_utils.config import (
 from .interfaces import SupportsCrossEncoding, SupportsV0Only


-def roberta_task_weights_filter(
-    all_weights: Iterable[Tuple[str, torch.Tensor]]
-) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
-                                                              torch.Tensor]]]:
-    """
-    Separate task-specific weights that are applied on top
-    of the encoder-decoder bert base.
-    To do so, return two generators over the original iterator.
-    Also, remove the "roberta." prefix to make it loadable
-    from vanilla BertModel.
-    """
-    # Copy of a lazy iterator without in-memory overhead so both
-    # iterators can be iterated upon independently.
-    all_weights1, all_weights2 = itertools.tee(all_weights)
-
-    def encoder_decoder_weights():
-        for name, weight in all_weights1:
-            if name.startswith("roberta."):
-                yield (name[len("roberta."):], weight)
-
-    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
-                                       if not n.startswith("roberta."))
-
-
 class RobertaEmbedding(nn.Module):

     def __init__(self, config: RobertaConfig):
@@ -119,30 +95,6 @@
         return embeddings


-# Adapted from transformers
-def create_position_ids_from_input_ids(input_ids,
-                                       padding_idx,
-                                       past_key_values_length=0):
-    """
-    Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols
-    are ignored. This is modified from fairseq's `utils.make_positions`.
-
-    Args:
-        x: torch.Tensor x:
-
-    Returns: torch.Tensor
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA.
-    mask = input_ids.ne(padding_idx).int()
-
-    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
-                           past_key_values_length) * mask
-
-    return incremental_indices.long() + padding_idx
-
-
 # Adapted from transformers
 class RobertaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
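Editor's note (not part of the diff): create_position_ids_from_input_ids is not dropped here; the final hunk of this file re-adds it further down, below RobertaForSequenceClassification. It assigns padding tokens the position padding_idx and numbers real tokens from padding_idx + 1. A small worked example of the arithmetic:

import torch

input_ids = torch.tensor([1, 5, 7, 1])  # padding_idx == 1 at both ends
padding_idx = 1
mask = input_ids.ne(padding_idx).int()                                    # [0, 1, 1, 0]
positions = torch.cumsum(mask, dim=0).type_as(mask) * mask + padding_idx
print(positions.tolist())                                                 # [1, 2, 3, 1]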
@@ -174,15 +126,38 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     def _build_model(self,
                      vllm_config: VllmConfig,
                      prefix: str = "") -> BertModel:
-        return BertModel(vllm_config=vllm_config,
-                         prefix=prefix,
-                         embedding_class=RobertaEmbedding)
+        if (vllm_config.model_config.hf_config.position_embedding_type ==
+                "rotary"):
+            config = vllm_config.model_config.hf_config
+            head_dim = config.hidden_size // config.num_attention_heads
+
+            rotary_kwargs = {
+                "head_size": head_dim,
+                "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+                "max_position": config.max_position_embeddings,
+                "base": config.rotary_emb_base,
+                "rope_scaling": getattr(config, "rope_scaling", None)
+            }
+
+            return BertModel(vllm_config=vllm_config,
+                             rotary_kwargs=rotary_kwargs,
+                             prefix=prefix)
+        else:
+            return BertModel(vllm_config=vllm_config,
+                             prefix=prefix,
+                             embedding_class=RobertaEmbedding)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        if getattr(self.config, "lora_rank", 0) > 0:
+            scaling = self.config.lora_alpha / self.config.lora_rank
+            weights = jina_merge_lora_weights(weights, scaling)
+
         weights = self.hf_to_vllm_mapper.apply(weights)
         # Separate weights in "roberta"-prefixed and all else (not in memory).
         # For use with models like FacebookAI/roberta-base.
         bert_weights, task_weights = roberta_task_weights_filter(weights)
+        bert_weights = jina_to_vllm_mapper.apply(bert_weights)
+
         loaded = self.model.load_weights(bert_weights)
         if not len(loaded):
             # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
@@ -203,18 +178,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
     _pooler: An instance of Pooler used for pooling operations.
     """

-    jina_to_vllm_mapper = WeightsMapper(
-        orig_to_new_substr={
-            'emb_ln': "embeddings.LayerNorm",
-            'layers': "layer",
-            'mixer.Wqkv': "attention.self.qkv_proj",
-            'mixer.out_proj': "attention.output.dense",
-            'norm1': "attention.output.LayerNorm",
-            'mlp.fc1': "intermediate.dense",
-            'mlp.fc2': "output.dense",
-            'norm2': "output.LayerNorm",
-        })
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -232,7 +195,7 @@

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         bert_weights, task_weights = roberta_task_weights_filter(weights)
-        bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
+        bert_weights = jina_to_vllm_mapper.apply(bert_weights)

         self.roberta.load_weights(bert_weights)

@@ -265,3 +228,105 @@
                             inputs_embeds=inputs_embeds,
                             intermediate_tensors=intermediate_tensors,
                             token_type_ids=token_type_ids)
+
+
+# Adapted from transformers
+def create_position_ids_from_input_ids(input_ids,
+                                       padding_idx,
+                                       past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers.
+    Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        x: torch.Tensor x:
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+
+    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
+                           past_key_values_length) * mask
+
+    return incremental_indices.long() + padding_idx
+
+
+def roberta_task_weights_filter(
+    all_weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
+                                                              torch.Tensor]]]:
+    """
+    Separate task-specific weights that are applied on top
+    of the encoder-decoder bert base.
+    To do so, return two generators over the original iterator.
+    Also, remove the "roberta." prefix to make it loadable
+    from vanilla BertModel.
+    """
+    # Copy of a lazy iterator without in-memory overhead so both
+    # iterators can be iterated upon independently.
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def encoder_decoder_weights():
+        for name, weight in all_weights1:
+            if name.startswith("roberta."):
+                yield (name[len("roberta."):], weight)
+
+    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
+                                       if not n.startswith("roberta."))
+
+
+jina_to_vllm_mapper = WeightsMapper(
+    orig_to_new_substr={
+        'emb_ln': "embeddings.LayerNorm",
+        'layers': "layer",
+        'mixer.Wqkv': "attention.self.qkv_proj",
+        'mixer.out_proj': "attention.output.dense",
+        'norm1': "attention.output.LayerNorm",
+        'mlp.fc1': "intermediate.dense",
+        'mlp.fc2': "output.dense",
+        'norm2': "output.LayerNorm",
+    })
+
+
+@torch.inference_mode()
+def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                            scaling: float = 1.0):
+    # use for jina-embeddings-v3
+    # Merge Lora weights into a single weight tensor.
+    # This is a temporary solution until we have a better way to handle
+
+    weights = {name: weight for name, weight in weights}
+
+    o = ".original"
+    a = ".0.lora_A"
+    b = ".0.lora_B"
+
+    # text-matching
+    i = -1
+
+    for name in list(weights.keys()):
+        if o in name:
+            dtype = weights[name].dtype
+            shape = weights[name].shape
+            weight_name = name[:-len(o)]
+
+            if "embeddings" in weight_name:
+                B = weights[weight_name + a][i].cuda().float()
+                A = weights[weight_name + b][i].cuda().float()
+            else:
+                B = weights[weight_name + b][i].cuda().float()
+                A = weights[weight_name + a][i].cuda().float()
+
+            weight = (weights[weight_name + o].cuda() +
+                      torch.matmul(B, A).view(shape) * scaling)
+            weight = weight.cpu().to(dtype)
+
+            weights[weight_name.replace(".parametrizations", "")] = weight
+
+            del weights[weight_name + o], weights[weight_name +
+                                                  a], weights[weight_name + b]
+
+    return [(name, weight) for name, weight in weights.items()]
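Editor's note (not part of the diff): jina_merge_lora_weights folds the selected task adapter (index -1, labelled text-matching above) into the base ".original" tensors as W' = W + scaling * (B @ A), where load_weights passes scaling = lora_alpha / lora_rank from the HF config; the diff swaps the order of the two factors for embedding weights. A tiny shape-level sketch of that merge, with illustrative sizes:

import torch

out_features, in_features, rank = 8, 4, 2
W = torch.zeros(out_features, in_features)   # base weight (the ".original" tensor)
A = torch.randn(rank, in_features)           # lora_A factor
B = torch.randn(out_features, rank)          # lora_B factor
scaling = 1.0                                # lora_alpha / lora_rank
W_merged = W + scaling * (B @ A)             # same shape as W
assert W_merged.shape == W.shape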