From 1f5d13ab9f1232751a09cdb6b1cdbb4687393833 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 8 Apr 2025 23:39:12 +0800 Subject: [PATCH] [New Model]: jinaai/jina-embeddings-v3 (#16120) --- .../embed_jina_embeddings_v3.py | 50 +++++ tests/conftest.py | 5 +- ...{test_jina_reranker_v2.py => test_jina.py} | 64 +++++- vllm/config.py | 5 + vllm/model_executor/models/bert.py | 66 ++++-- vllm/model_executor/models/roberta.py | 193 ++++++++++++------ 6 files changed, 297 insertions(+), 86 deletions(-) create mode 100644 examples/offline_inference/embed_jina_embeddings_v3.py rename tests/models/embedding/language/{test_jina_reranker_v2.py => test_jina.py} (59%) diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/embed_jina_embeddings_v3.py new file mode 100644 index 0000000000000..f7d9e47e7953e --- /dev/null +++ b/examples/offline_inference/embed_jina_embeddings_v3.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Follow the white rabbit.", # English + "Sigue al conejo blanco.", # Spanish + "Suis le lapin blanc.", # French + "跟着白兔走。", # Chinese + "اتبع الأرنب الأبيض.", # Arabic + "Folge dem weißen Kaninchen.", # German + ] + + # Create an LLM. + # You should pass task="embed" for embedding models + model = LLM(**vars(args)) + + # Generate embedding. The output is a list of EmbeddingRequestOutputs. + # Only text matching task is supported for now. See #16120 + outputs = model.embed(prompts) + + # Print the outputs. + print("\nGenerated Outputs:") + print("Only text matching task is supported for now. See #16120") + print("-" * 60) + for prompt, output in zip(prompts, outputs): + embeds = output.outputs.embedding + embeds_trimmed = ((str(embeds[:16])[:-1] + + ", ...]") if len(embeds) > 16 else embeds) + print(f"Prompt: {prompt!r} \n" + f"Embeddings for text matching: {embeds_trimmed} " + f"(size={len(embeds)})") + print("-" * 60) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) + args = parser.parse_args() + main(args) diff --git a/tests/conftest.py b/tests/conftest.py index b833cff4db7c0..c5d393907ec8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -671,8 +671,9 @@ class HfRunner: return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] - def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]: - return self.model.encode(prompts) + def encode(self, prompts: list[str], *args, + **kwargs) -> list[list[torch.Tensor]]: + return self.model.encode(prompts, *args, **kwargs) def predict(self, prompts: list[list[str]]) -> torch.Tensor: return self.model.predict(prompts, convert_to_tensor=True) diff --git a/tests/models/embedding/language/test_jina_reranker_v2.py b/tests/models/embedding/language/test_jina.py similarity index 59% rename from tests/models/embedding/language/test_jina_reranker_v2.py rename to tests/models/embedding/language/test_jina.py index ab88fa9ba636c..2a3eab02ddd9e 100644 --- a/tests/models/embedding/language/test_jina_reranker_v2.py +++ b/tests/models/embedding/language/test_jina.py @@ -2,13 +2,15 @@ # ruff: noqa: E501 """Compare the scoring outputs of HF and vLLM models. 
-Run `pytest tests/models/embedding/language/test_jina_reranker_v2.py`. +Run `pytest tests/models/embedding/language/test_jina.py`. """ import math import pytest -MODELS = [ +from tests.models.embedding.utils import check_embeddings_close + +SCORING_MODELS = [ "jinaai/jina-reranker-v2-base-multilingual", # Roberta ] @@ -27,8 +29,21 @@ TEXTS_2 = [ "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています", ] +EMBEDDING_MODELS = [ + "jinaai/jina-embeddings-v3", +] -@pytest.fixture(scope="module", params=MODELS) +EMBEDDING_PROMPTS = [ + "Follow the white rabbit.", # English + "Sigue al conejo blanco.", # Spanish + "Suis le lapin blanc.", # French + "跟着白兔走。", # Chinese + "اتبع الأرنب الأبيض.", # Arabic + "Folge dem weißen Kaninchen.", # German +] + + +@pytest.fixture(scope="module", params=SCORING_MODELS) def model_name(request): yield request.param @@ -68,3 +83,46 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) + + +@pytest.fixture(scope="module", params=EMBEDDING_MODELS) +def emb_model_name(request): + yield request.param + + +def test_is_matryoshka(vllm_runner, emb_model_name): + with vllm_runner(emb_model_name, task="embed", + max_model_len=None) as vllm_model: + assert vllm_model.model.llm_engine.model_config.is_matryoshka + + +@pytest.mark.parametrize("model", EMBEDDING_MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_embeddings( + hf_runner, + vllm_runner, + model, + dtype: str, + monkeypatch, +) -> None: + + example_prompts = EMBEDDING_PROMPTS + + with hf_runner( + model, + dtype=dtype, + is_sentence_transformer=True, + ) as hf_model: + hf_outputs = hf_model.encode(example_prompts, task="text-matching") + + with vllm_runner(model, task="embed", dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/vllm/config.py b/vllm/config.py index 439e27b154ab3..2662c6a84990c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1130,6 +1130,11 @@ class ModelConfig: architectures = getattr(self.hf_config, "architectures", []) return ModelRegistry.is_v1_compatible(architectures) + @property + def is_matryoshka(self) -> bool: + return (hasattr(self.hf_config, "matryoshka_dimensions") + or getattr(self.hf_config, "is_matryoshka", False)) + class CacheConfig: """Configuration for the KV cache. 
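Note on the `is_matryoshka` property added to `vllm/config.py` above: the test `test_is_matryoshka` asserts that jina-embeddings-v3 reports this flag, i.e. its embeddings can be truncated to a smaller dimension and re-normalized with little quality loss. The property only advertises that capability; the truncation itself is the usual slice-and-renormalize step. The snippet below is an illustrative sketch of that step and is not code from this patch; the input size (1024) and target dimension (512) are assumed example values.

# Illustrative sketch (not part of this patch): truncating a Matryoshka
# embedding to a smaller dimension and re-normalizing it to unit length.
# The input size (1024) and target dimension (512) are assumed example values.
import torch


def truncate_matryoshka(embedding: torch.Tensor,
                        dim: int = 512) -> torch.Tensor:
    # Keep the first `dim` components, then rescale to unit norm.
    truncated = embedding[..., :dim]
    return truncated / truncated.norm(dim=-1, keepdim=True)


full = torch.randn(1024)                    # stand-in for a full-size embedding
small = truncate_matryoshka(full, dim=512)  # 512-dim, unit-norm
print(small.shape, float(small.norm()))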
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 111b49ab8dd2a..e1d77646f47e8 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
         self.size = config.hidden_size
         self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
-        self.position_embeddings = VocabParallelEmbedding(
-            config.max_position_embeddings, config.hidden_size)
+
         self.token_type_embeddings = VocabParallelEmbedding(
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
 
         self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError("Only 'absolute' position_embedding_type" +
-                             " is supported")
+        if self.position_embedding_type == "absolute":
+            self.position_embeddings = VocabParallelEmbedding(
+                config.max_position_embeddings, config.hidden_size)
+            self.position_ids = nn.Parameter(
+                torch.empty((1, config.max_position_embeddings)), )
+        elif self.position_embedding_type == "rotary":
+            self.position_embeddings = None
+            self.position_ids = None
+        else:
+            raise ValueError("Only 'absolute' and 'rotary' " +
+                             "position_embedding_type are supported")
 
     def forward(
         self,
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
         # Input embeddings.
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # Position embeddings.
- position_embeddings = self.position_embeddings(position_ids) - if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, @@ -74,7 +77,12 @@ class BertEmbedding(nn.Module): token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = inputs_embeds + token_type_embeddings + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) return embeddings @@ -98,7 +106,10 @@ class BertPooler(nn.Module): @support_torch_compile class BertEncoder(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + vllm_config: VllmConfig, + rotary_kwargs: Optional[dict] = None, + prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -107,16 +118,18 @@ class BertEncoder(nn.Module): BertLayer(config=config, cache_config=cache_config, quant_config=quant_config, + rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.layer.{layer_idx}") for layer_idx in range(config.num_hidden_layers) ]) def forward( self, + positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: for layer in self.layer: - hidden_states = layer(hidden_states) + hidden_states = layer(positions, hidden_states) return hidden_states @@ -126,6 +139,7 @@ class BertLayer(nn.Module): config: BertConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + rotary_kwargs: Optional[dict] = None, prefix: str = ""): super().__init__() @@ -135,6 +149,7 @@ class BertLayer(nn.Module): layer_norm_eps=config.layer_norm_eps, cache_config=cache_config, quant_config=quant_config, + rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.attention") self.intermediate = BertIntermediate( @@ -150,8 +165,8 @@ class BertLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.output") - def forward(self, hidden_states: torch.Tensor): - attn_output = self.attention(hidden_states) + def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): + attn_output = self.attention(positions, hidden_states) intermediate_output = self.intermediate(attn_output) output = self.output(intermediate_output, attn_output) return output @@ -166,6 +181,7 @@ class BertAttention(nn.Module): layer_norm_eps: float, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + rotary_kwargs: Optional[dict] = None, prefix: str = "", ): super().__init__() @@ -174,6 +190,7 @@ class BertAttention(nn.Module): num_attention_heads=num_attention_heads, cache_config=cache_config, quant_config=quant_config, + rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.output") self.output = BertSelfOutput(hidden_size=hidden_size, @@ -183,9 +200,10 @@ class BertAttention(nn.Module): def forward( self, + positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - self_output = self.self(hidden_states) + self_output = self.self(positions, hidden_states) return self.output(self_output, hidden_states) @@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module): num_attention_heads: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + rotary_kwargs: Optional[dict] = None, prefix: str = "", ): super().__init__() @@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module): quant_config=quant_config, 
prefix=f"{prefix}.qkv_proj") + if rotary_kwargs: + self.rotary_emb = get_rope(**rotary_kwargs) + else: + self.rotary_emb = None + self.attn = Attention(num_heads=self.num_heads, head_size=self.head_dim, scale=self.scaling, @@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module): def forward( self, + positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + if self.rotary_emb: + q, k = self.rotary_emb(positions, q, k) + output = self.attn(q, k, v) return output @@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant): vllm_config: VllmConfig, prefix: str = "", embedding_class: type = BertEmbedding, + rotary_kwargs: Optional[dict] = None, add_pooling_layer: bool = False): super().__init__() config = vllm_config.model_config.hf_config self.embeddings = embedding_class(config) self.encoder = BertEncoder(vllm_config=vllm_config, + rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.encoder") self.pooler = BertPooler(config) if add_pooling_layer else None @@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant): seq_lens=attn_metadata.seq_lens_tensor, position_ids=position_ids, token_type_ids=token_type_ids) - return self.encoder(hidden_states) + return self.encoder(position_ids, hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() pooler_config = vllm_config.model_config.pooler_config + self.config = vllm_config.model_config.hf_config self.model = self._build_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self._pooler = self._build_pooler(pooler_config) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index a09741a559755..4c23d72a41952 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -22,30 +22,6 @@ from vllm.transformers_utils.config import ( from .interfaces import SupportsCrossEncoding, SupportsV0Only -def roberta_task_weights_filter( - all_weights: Iterable[Tuple[str, torch.Tensor]] -) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str, - torch.Tensor]]]: - """ - Separate task-specific weights that are applied on top - of the encoder-decoder bert base. - To do so, return two generators over the original iterator. - Also, remove the "roberta." prefix to make it loadable - from vanilla BertModel. - """ - # Copy of a lazy iterator without in-memory overhead so both - # iterators can be iterated upon independently. - all_weights1, all_weights2 = itertools.tee(all_weights) - - def encoder_decoder_weights(): - for name, weight in all_weights1: - if name.startswith("roberta."): - yield (name[len("roberta."):], weight) - - return encoder_decoder_weights(), ((n, w) for n, w in all_weights2 - if not n.startswith("roberta.")) - - class RobertaEmbedding(nn.Module): def __init__(self, config: RobertaConfig): @@ -119,30 +95,6 @@ class RobertaEmbedding(nn.Module): return embeddings -# Adapted from transformers -def create_position_ids_from_input_ids(input_ids, - padding_idx, - past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. - Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. 
- - Args: - x: torch.Tensor x: - - Returns: torch.Tensor - """ - # The series of casts and type-conversions here are carefully - # balanced to both work with ONNX export and XLA. - mask = input_ids.ne(padding_idx).int() - - incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + - past_key_values_length) * mask - - return incremental_indices.long() + padding_idx - - # Adapted from transformers class RobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -174,15 +126,38 @@ class RobertaEmbeddingModel(BertEmbeddingModel): def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel: - return BertModel(vllm_config=vllm_config, - prefix=prefix, - embedding_class=RobertaEmbedding) + if (vllm_config.model_config.hf_config.position_embedding_type == + "rotary"): + config = vllm_config.model_config.hf_config + head_dim = config.hidden_size // config.num_attention_heads + + rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rotary_emb_base, + "rope_scaling": getattr(config, "rope_scaling", None) + } + + return BertModel(vllm_config=vllm_config, + rotary_kwargs=rotary_kwargs, + prefix=prefix) + else: + return BertModel(vllm_config=vllm_config, + prefix=prefix, + embedding_class=RobertaEmbedding) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + if getattr(self.config, "lora_rank", 0) > 0: + scaling = self.config.lora_alpha / self.config.lora_rank + weights = jina_merge_lora_weights(weights, scaling) + weights = self.hf_to_vllm_mapper.apply(weights) # Separate weights in "roberta"-prefixed and all else (not in memory). # For use with models like FacebookAI/roberta-base. bert_weights, task_weights = roberta_task_weights_filter(weights) + bert_weights = jina_to_vllm_mapper.apply(bert_weights) + loaded = self.model.load_weights(bert_weights) if not len(loaded): # Fix for models like `sentence-transformers/stsb-roberta-base-v2` @@ -203,18 +178,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, _pooler: An instance of Pooler used for pooling operations. """ - jina_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={ - 'emb_ln': "embeddings.LayerNorm", - 'layers': "layer", - 'mixer.Wqkv': "attention.self.qkv_proj", - 'mixer.out_proj': "attention.output.dense", - 'norm1': "attention.output.LayerNorm", - 'mlp.fc1': "intermediate.dense", - 'mlp.fc2': "output.dense", - 'norm2': "output.LayerNorm", - }) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -232,7 +195,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): bert_weights, task_weights = roberta_task_weights_filter(weights) - bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) + bert_weights = jina_to_vllm_mapper.apply(bert_weights) self.roberta.load_weights(bert_weights) @@ -265,3 +228,105 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, token_type_ids=token_type_ids) + + +# Adapted from transformers +def create_position_ids_from_input_ids(input_ids, + padding_idx, + past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols + are ignored. 
This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: torch.Tensor
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+
+    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
+                           past_key_values_length) * mask
+
+    return incremental_indices.long() + padding_idx
+
+
+def roberta_task_weights_filter(
+    all_weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
+                                                              torch.Tensor]]]:
+    """
+    Separate task-specific weights that are applied on top
+    of the encoder-only BERT base.
+    To do so, return two generators over the original iterator.
+    Also, remove the "roberta." prefix to make it loadable
+    from vanilla BertModel.
+    """
+    # Copy of a lazy iterator without in-memory overhead so both
+    # iterators can be iterated upon independently.
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def encoder_decoder_weights():
+        for name, weight in all_weights1:
+            if name.startswith("roberta."):
+                yield (name[len("roberta."):], weight)
+
+    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
+                                       if not n.startswith("roberta."))
+
+
+jina_to_vllm_mapper = WeightsMapper(
+    orig_to_new_substr={
+        'emb_ln': "embeddings.LayerNorm",
+        'layers': "layer",
+        'mixer.Wqkv': "attention.self.qkv_proj",
+        'mixer.out_proj': "attention.output.dense",
+        'norm1': "attention.output.LayerNorm",
+        'mlp.fc1': "intermediate.dense",
+        'mlp.fc2': "output.dense",
+        'norm2': "output.LayerNorm",
+    })
+
+
+@torch.inference_mode()
+def jina_merge_lora_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                            scaling: float = 1.0):
+    # Used for jina-embeddings-v3.
+    # Merge the LoRA adapter weights into the base weight tensors.
+    # This is a temporary workaround until LoRA adapters are handled properly.
+
+    weights = {name: weight for name, weight in weights}
+
+    o = ".original"
+    a = ".0.lora_A"
+    b = ".0.lora_B"
+
+    # Adapter index -1 selects the text-matching LoRA adapter.
+    i = -1
+
+    for name in list(weights.keys()):
+        if o in name:
+            dtype = weights[name].dtype
+            shape = weights[name].shape
+            weight_name = name[:-len(o)]
+
+            if "embeddings" in weight_name:
+                B = weights[weight_name + a][i].cuda().float()
+                A = weights[weight_name + b][i].cuda().float()
+            else:
+                B = weights[weight_name + b][i].cuda().float()
+                A = weights[weight_name + a][i].cuda().float()
+
+            weight = (weights[weight_name + o].cuda() +
+                      torch.matmul(B, A).view(shape) * scaling)
+            weight = weight.cpu().to(dtype)
+
+            weights[weight_name.replace(".parametrizations", "")] = weight
+
+            del weights[weight_name + o], weights[weight_name +
+                                                  a], weights[weight_name + b]
+
+    return [(name, weight) for name, weight in weights.items()]
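For reference, the merge performed by jina_merge_lora_weights above is the standard LoRA fold-in, W_merged = W + (lora_alpha / lora_rank) * (B @ A), applied once at load time so the rest of the stack only ever sees plain dense weights. The standalone snippet below reproduces just that arithmetic as a sketch; it is not code from this patch, and the shapes, rank, and alpha are made-up example values rather than anything read from the jina-embeddings-v3 checkpoint.

# Standalone sketch of the LoRA fold-in used by jina_merge_lora_weights above.
# Shapes, rank, and alpha are made-up example values, not checkpoint values.
import torch


def merge_lora(weight: torch.Tensor, lora_a: torch.Tensor,
               lora_b: torch.Tensor, alpha: float, rank: int) -> torch.Tensor:
    # W_merged = W + (alpha / rank) * (B @ A)
    scaling = alpha / rank
    return weight + torch.matmul(lora_b, lora_a).view(weight.shape) * scaling


out_features, in_features, rank, alpha = 8, 16, 4, 4.0
base = torch.randn(out_features, in_features)
lora_a = torch.randn(rank, in_features)   # low-rank factor A
lora_b = torch.randn(out_features, rank)  # low-rank factor B
merged = merge_lora(base, lora_a, lora_b, alpha, rank)
print(merged.shape)  # torch.Size([8, 16])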