[New Model]: jinaai/jina-embeddings-v3 (#16120)

wang.yuqi 2025-04-08 23:39:12 +08:00 committed by GitHub
parent 90cb44eb02
commit 1f5d13ab9f
6 changed files with 297 additions and 86 deletions

View File

@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def main(args: Namespace):
# Sample prompts.
prompts = [
"Follow the white rabbit.", # English
"Sigue al conejo blanco.", # Spanish
"Suis le lapin blanc.", # French
"跟着白兔走。", # Chinese
"اتبع الأرنب الأبيض.", # Arabic
"Folge dem weißen Kaninchen.", # German
]
# Create an LLM.
# You should pass task="embed" for embedding models
model = LLM(**vars(args))
# Generate embeddings. The output is a list of EmbeddingRequestOutputs.
# Only the text-matching task is supported for now. See #16120
outputs = model.embed(prompts)
# Print the outputs.
print("\nGenerated Outputs:")
print("Only the text-matching task is supported for now. See #16120")
print("-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] +
", ...]") if len(embeds) > 16 else embeds)
print(f"Prompt: {prompt!r} \n"
f"Embeddings for text matching: {embeds_trimmed} "
f"(size={len(embeds)})")
print("-" * 60)
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example-specific arguments
parser.set_defaults(model="jinaai/jina-embeddings-v3",
task="embed",
trust_remote_code=True)
args = parser.parse_args()
main(args)
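The embeddings returned by model.embed() are plain Python lists of floats, so they can be compared directly. Below is a minimal sketch of the usual follow-up, cosine similarity between two of the prompts above; it is illustrative only, not part of the example file, and reuses the outputs variable defined in main.

import torch

def cosine_similarity(a: list[float], b: list[float]) -> float:
    # Cosine similarity between two embedding vectors.
    ta, tb = torch.tensor(a), torch.tensor(b)
    return torch.nn.functional.cosine_similarity(ta, tb, dim=0).item()

# The English and Spanish prompts are translations of the same sentence,
# so their text-matching embeddings should score close to 1.0:
# cosine_similarity(outputs[0].outputs.embedding, outputs[1].outputs.embedding)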

View File

@ -671,8 +671,9 @@ class HfRunner:
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]:
return self.model.encode(prompts)
def encode(self, prompts: list[str], *args,
**kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs)
def predict(self, prompts: list[list[str]]) -> torch.Tensor:
return self.model.predict(prompts, convert_to_tensor=True)
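Forwarding *args/**kwargs matters because jina-embeddings-v3 selects its per-task LoRA adapter through a task keyword that sentence-transformers passes down to the model's custom code. A minimal sketch of the direct call the test runner now supports (assumes the sentence-transformers package is installed and remote code is trusted):

from sentence_transformers import SentenceTransformer

# jina-embeddings-v3 ships custom modeling code that picks a LoRA adapter per
# task; the extra keyword is forwarded through encode() to that code.
model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
embeddings = model.encode(["Follow the white rabbit."], task="text-matching")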

View File

@ -2,13 +2,15 @@
# ruff: noqa: E501
"""Compare the scoring outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_jina_reranker_v2.py`.
Run `pytest tests/models/embedding/language/test_jina.py`.
"""
import math
import pytest
MODELS = [
from tests.models.embedding.utils import check_embeddings_close
SCORING_MODELS = [
"jinaai/jina-reranker-v2-base-multilingual", # Roberta
]
@ -27,8 +29,21 @@ TEXTS_2 = [
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
]
EMBEDDING_MODELS = [
"jinaai/jina-embeddings-v3",
]
@pytest.fixture(scope="module", params=MODELS)
EMBEDDING_PROMPTS = [
"Follow the white rabbit.", # English
"Sigue al conejo blanco.", # Spanish
"Suis le lapin blanc.", # French
"跟着白兔走。", # Chinese
"اتبع الأرنب الأبيض.", # Arabic
"Folge dem weißen Kaninchen.", # German
]
@pytest.fixture(scope="module", params=SCORING_MODELS)
def model_name(request):
yield request.param
@ -68,3 +83,46 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
def emb_model_name(request):
yield request.param
def test_is_matryoshka(vllm_runner, emb_model_name):
with vllm_runner(emb_model_name, task="embed",
max_model_len=None) as vllm_model:
assert vllm_model.model.llm_engine.model_config.is_matryoshka
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_embeddings(
hf_runner,
vllm_runner,
model,
dtype: str,
monkeypatch,
) -> None:
example_prompts = EMBEDDING_PROMPTS
with hf_runner(
model,
dtype=dtype,
is_sentence_transformer=True,
) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching")
with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
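For reference, check_embeddings_close compares the two embedding lists pairwise within the given tolerance; cosine similarity is the usual metric for this kind of check. A rough standalone sketch of such a comparison (not the actual helper, which lives in tests/models/embedding/utils.py):

import torch

def embeddings_roughly_equal(a: list[list[float]],
                             b: list[list[float]],
                             tol: float = 1e-2) -> bool:
    # Pairwise cosine-similarity check: 1 - similarity must stay within tol.
    for x, y in zip(a, b):
        sim = torch.nn.functional.cosine_similarity(torch.tensor(x),
                                                    torch.tensor(y),
                                                    dim=0).item()
        if 1.0 - sim > tol:
            return False
    return True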

View File

@ -1130,6 +1130,11 @@ class ModelConfig:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_v1_compatible(architectures)
@property
def is_matryoshka(self) -> bool:
return (hasattr(self.hf_config, "matryoshka_dimensions")
or getattr(self.hf_config, "is_matryoshka", False))
class CacheConfig:
"""Configuration for the KV cache.

View File

@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")
if self.position_embedding_type == "absolute":
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
elif self.position_embedding_type == "rotary":
self.position_embeddings = None
self.position_ids = None
else:
raise ValueError("Only 'absolute' and 'rotary' " +
"position_embedding_type is supported")
def forward(
self,
@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
@ -98,7 +106,10 @@ class BertPooler(nn.Module):
@support_torch_compile
class BertEncoder(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
vllm_config: VllmConfig,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
BertLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
for layer in self.layer:
hidden_states = layer(hidden_states)
hidden_states = layer(positions, hidden_states)
return hidden_states
@ -126,6 +139,7 @@ class BertLayer(nn.Module):
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
@ -135,6 +149,7 @@ class BertLayer(nn.Module):
layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention")
self.intermediate = BertIntermediate(
@ -150,8 +165,8 @@ class BertLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.output")
def forward(self, hidden_states: torch.Tensor):
attn_output = self.attention(hidden_states)
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
attn_output = self.attention(positions, hidden_states)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
@ -166,6 +181,7 @@ class BertAttention(nn.Module):
layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
super().__init__()
@ -174,6 +190,7 @@ class BertAttention(nn.Module):
num_attention_heads=num_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output")
self.output = BertSelfOutput(hidden_size=hidden_size,
@ -183,9 +200,10 @@ class BertAttention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
self_output = self.self(hidden_states)
self_output = self.self(positions, hidden_states)
return self.output(self_output, hidden_states)
@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
super().__init__()
@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
if rotary_kwargs:
self.rotary_emb = get_rope(**rotary_kwargs)
else:
self.rotary_emb = None
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb:
q, k = self.rotary_emb(positions, q, k)
output = self.attn(q, k, v)
return output
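For context, get_rope returns a rotary-embedding module that rotates pairs of query/key dimensions by a position-dependent angle before attention. A simplified standalone sketch of that rotation (vLLM's implementation is fused and supports more variants):

import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor,
               base: float = 10000.0) -> torch.Tensor:
    # x: [num_tokens, head_dim]; positions: [num_tokens].
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out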
@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding,
rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None
@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)
return self.encoder(position_ids, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.config = vllm_config.model_config.hf_config
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)

View File

@ -22,30 +22,6 @@ from vllm.transformers_utils.config import (
from .interfaces import SupportsCrossEncoding, SupportsV0Only
def roberta_task_weights_filter(
all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
torch.Tensor]]]:
"""
Separate task-specific weights that are applied on top
of the encoder-decoder bert base.
To do so, return two generators over the original iterator.
Also, remove the "roberta." prefix to make it loadable
from vanilla BertModel.
"""
# Copy of a lazy iterator without in-memory overhead so both
# iterators can be iterated upon independently.
all_weights1, all_weights2 = itertools.tee(all_weights)
def encoder_decoder_weights():
for name, weight in all_weights1:
if name.startswith("roberta."):
yield (name[len("roberta."):], weight)
return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
if not n.startswith("roberta."))
class RobertaEmbedding(nn.Module):
def __init__(self, config: RobertaConfig):
@ -119,30 +95,6 @@ class RobertaEmbedding(nn.Module):
return embeddings
# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
padding_idx,
past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
past_key_values_length) * mask
return incremental_indices.long() + padding_idx
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
@ -174,15 +126,38 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=RobertaEmbedding)
if (vllm_config.model_config.hf_config.position_embedding_type ==
"rotary"):
config = vllm_config.model_config.hf_config
head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rotary_emb_base,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return BertModel(vllm_config=vllm_config,
rotary_kwargs=rotary_kwargs,
prefix=prefix)
else:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=RobertaEmbedding)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if getattr(self.config, "lora_rank", 0) > 0:
scaling = self.config.lora_alpha / self.config.lora_rank
weights = jina_merge_lora_weights(weights, scaling)
weights = self.hf_to_vllm_mapper.apply(weights)
# Split the weights into "roberta."-prefixed ones and everything else
# (lazily, without loading them all into memory).
# Used for checkpoints like FacebookAI/roberta-base.
bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = jina_to_vllm_mapper.apply(bert_weights)
loaded = self.model.load_weights(bert_weights)
if not len(loaded):
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
@ -203,18 +178,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations.
"""
jina_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
'emb_ln': "embeddings.LayerNorm",
'layers': "layer",
'mixer.Wqkv': "attention.self.qkv_proj",
'mixer.out_proj': "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc1': "intermediate.dense",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@ -232,7 +195,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
bert_weights = jina_to_vllm_mapper.apply(bert_weights)
self.roberta.load_weights(bert_weights)
@ -265,3 +228,105 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids)
# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
padding_idx,
past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids: torch.Tensor
padding_idx: int
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
past_key_values_length) * mask
return incremental_indices.long() + padding_idx
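A small worked example of the padding-aware position ids: padding tokens keep padding_idx, and real tokens count up from padding_idx + 1 regardless of where the padding sits.

import torch

ids = torch.tensor([5, 6, 7, 1, 1])  # two trailing pad tokens, padding_idx=1
# mask = [1, 1, 1, 0, 0]; cumsum * mask = [1, 2, 3, 0, 0]; + padding_idx = [2, 3, 4, 1, 1]
print(create_position_ids_from_input_ids(ids, padding_idx=1))
# tensor([2, 3, 4, 1, 1])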
def roberta_task_weights_filter(
all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
torch.Tensor]]]:
"""
Separate the task-specific weights that are applied on top
of the BERT encoder base.
To do so, return two generators over the original iterator.
Also, remove the "roberta." prefix to make it loadable
from vanilla BertModel.
"""
# Copy of a lazy iterator without in-memory overhead so both
# iterators can be iterated upon independently.
all_weights1, all_weights2 = itertools.tee(all_weights)
def encoder_decoder_weights():
for name, weight in all_weights1:
if name.startswith("roberta."):
yield (name[len("roberta."):], weight)
return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
if not n.startswith("roberta."))
jina_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
'emb_ln': "embeddings.LayerNorm",
'layers': "layer",
'mixer.Wqkv': "attention.self.qkv_proj",
'mixer.out_proj': "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc1': "intermediate.dense",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
@torch.inference_mode()
def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
scaling: float = 1.0):
# Used for jina-embeddings-v3.
# Merge the task-specific LoRA weights into the base weight tensors.
# This is a temporary solution until we have a better way to
# handle LoRA adapters.
weights = {name: weight for name, weight in weights}
o = ".original"
a = ".0.lora_A"
b = ".0.lora_B"
# Merge only the "text-matching" adapter, which is the last LoRA index.
i = -1
for name in list(weights.keys()):
if o in name:
dtype = weights[name].dtype
shape = weights[name].shape
weight_name = name[:-len(o)]
if "embeddings" in weight_name:
B = weights[weight_name + a][i].cuda().float()
A = weights[weight_name + b][i].cuda().float()
else:
B = weights[weight_name + b][i].cuda().float()
A = weights[weight_name + a][i].cuda().float()
weight = (weights[weight_name + o].cuda() +
torch.matmul(B, A).view(shape) * scaling)
weight = weight.cpu().to(dtype)
weights[weight_name.replace(".parametrizations", "")] = weight
del weights[weight_name + o], weights[weight_name +
a], weights[weight_name + b]
return [(name, weight) for name, weight in weights.items()]
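The merge applies the standard LoRA update W' = W + scaling * (B @ A) for the selected adapter and then drops the adapter tensors (for the embedding matrix the two factors are looked up in swapped order, as handled above). A tiny shape-level sketch of the formula, with illustrative names rather than the checkpoint's actual keys:

import torch

out_features, in_features, rank = 8, 4, 2
W = torch.zeros(out_features, in_features)  # the ".original" base weight
A = torch.randn(rank, in_features)          # lora_A slice for the chosen task
B = torch.randn(out_features, rank)         # lora_B slice for the chosen task
scaling = 1.0                               # lora_alpha / lora_rank
W_merged = W + (B @ A) * scaling            # same shape as the base weight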