[New Model]: jinaai/jina-reranker-v2-base-multilingual (#15876)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
wang.yuqi 2025-04-01 23:32:26 +08:00 committed by GitHub
parent 2b93162fb0
commit 085cbc4f9f
3 changed files with 86 additions and 3 deletions
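
The new model plugs into vLLM's cross-encoder scoring path. A minimal offline usage sketch (the `trust_remote_code` flag is an assumption: the checkpoint ships a custom HF config class and may need it at load time; drop it if the model loads without):

from vllm import LLM

# Load the reranker in scoring mode (cross-encoder).
# trust_remote_code is an assumption for the custom config class.
llm = LLM(model="jinaai/jina-reranker-v2-base-multilingual",
          task="score",
          trust_remote_code=True)

# Score one (query, document) pair; each output carries a relevance score.
(output,) = llm.score("Organic skincare products for sensitive skin",
                      "Organic skincare for sensitive skin with aloe vera and chamomile.")
print(output.outputs.score)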

docs/source/models/supported_models.md

@@ -566,7 +566,7 @@ you should explicitly specify the task type to ensure that the model is used in
   *
 - * `XLMRobertaModel`
   * XLM-RoBERTa-based
-  * `intfloat/multilingual-e5-large`, etc.
+  * `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, etc.
   *
   *
 :::
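
The row above sits in a section where the task type must be given explicitly. For online serving, the same model can be exposed through vLLM's Score API; a sketch, assuming the standard `/score` request shape with `text_1`/`text_2` fields (details may vary by vLLM version):

# Start the server first (sketch):
#   vllm serve jinaai/jina-reranker-v2-base-multilingual --task score

import requests

# One query scored against two candidate documents; the response contains
# one relevance score per (text_1, text_2[i]) pair.
resp = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "jinaai/jina-reranker-v2-base-multilingual",
        "text_1": "Organic skincare products for sensitive skin",
        "text_2": [
            "Organic skincare for sensitive skin with aloe vera and chamomile.",
            "New makeup trends focus on bold colors and innovative techniques",
        ],
    },
)
print(resp.json())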

tests/models/embedding/language/test_jina_reranker_v2.py

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
"""Compare the scoring outputs of HF and vLLM models.

Run `pytest tests/models/embedding/language/test_jina_reranker_v2.py`.
"""
import math

import pytest

MODELS = [
    "jinaai/jina-reranker-v2-base-multilingual",  # Roberta
]

TEXTS_1 = ["Organic skincare products for sensitive skin"]

TEXTS_2 = [
    "Organic skincare for sensitive skin with aloe vera and chamomile.",
    "New makeup trends focus on bold colors and innovative techniques",
    "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
    "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
    "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
    "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
    "针对敏感肌专门设计的天然有机护肤产品",
    "新的化妆趋势注重鲜艳的颜色和创新的技巧",
    "敏感肌のために特別に設計された天然有機スキンケア製品",
    "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
]


@pytest.fixture(scope="module", params=MODELS)
def model_name(request):
    yield request.param


@pytest.mark.parametrize("dtype", ["half"])
def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
    text_pair = [TEXTS_1[0], TEXTS_2[0]]

    with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict([text_pair]).tolist()

    with vllm_runner(model_name, task="score", dtype=dtype,
                     max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(text_pair[0], text_pair[1])

    assert len(vllm_outputs) == 1
    assert len(hf_outputs) == 1

    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)


@pytest.mark.parametrize("dtype", ["half"])
def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
    text_pairs = [[TEXTS_1[0], text] for text in TEXTS_2]

    with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict(text_pairs).tolist()

    with vllm_runner(model_name, task="score", dtype=dtype,
                     max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2)

    assert len(vllm_outputs) == 10
    assert len(hf_outputs) == 10

    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
    assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
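
The `hf_runner(..., is_cross_encoder=True)` fixture produces `predict()` scores in the style of sentence-transformers' `CrossEncoder`, which the vLLM scores are checked against at 1% relative tolerance. A standalone sketch of that reference path (assuming `CrossEncoder` accepts `trust_remote_code` and forwards it to the HF loader):

from sentence_transformers import CrossEncoder

# Reference scorer; trust_remote_code is an assumption for the custom
# modeling code shipped with the checkpoint.
model = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual",
                     trust_remote_code=True)

scores = model.predict([
    ["Organic skincare products for sensitive skin",
     "Organic skincare for sensitive skin with aloe vera and chamomile."],
])
print(scores)  # one relevance score per (query, document) pair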

vllm/model_executor/models/roberta.py

@@ -13,7 +13,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
-from vllm.model_executor.models.utils import maybe_prefix
+from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
@@ -203,6 +203,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
         _pooler: An instance of Pooler used for pooling operations.
     """
 
+    jina_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            'emb_ln': "embeddings.LayerNorm",
+            'layers': "layer",
+            'mixer.Wqkv': "attention.self.qkv_proj",
+            'mixer.out_proj': "attention.output.dense",
+            'norm1': "attention.output.LayerNorm",
+            'mlp.fc1': "intermediate.dense",
+            'mlp.fc2': "output.dense",
+            'norm2': "output.LayerNorm",
+        })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
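
The mapper above translates Jina's flash-attention checkpoint layout into vLLM's BERT/RoBERTa parameter names, including folding the already-fused `Wqkv` projection onto vLLM's fused `qkv_proj`. A plain-Python stand-in for the substring rewrite (illustration only, not the actual `WeightsMapper` implementation):

# Hypothetical stand-in showing what an orig_to_new_substr rewrite does
# to a single checkpoint parameter name.
SUBSTR_MAP = {
    "emb_ln": "embeddings.LayerNorm",
    "layers": "layer",
    "mixer.Wqkv": "attention.self.qkv_proj",
    "mixer.out_proj": "attention.output.dense",
}

def map_name(name: str) -> str:
    # Apply every substring replacement to the parameter name.
    for old, new in SUBSTR_MAP.items():
        name = name.replace(old, new)
    return name

print(map_name("layers.0.mixer.Wqkv.weight"))
# -> layer.0.attention.self.qkv_proj.weight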
@@ -219,8 +231,9 @@
         self._pooler = CrossEncodingPooler(config, self.classifier)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         bert_weights, task_weights = roberta_task_weights_filter(weights)
+        bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
         self.roberta.load_weights(bert_weights)
         params_dict = dict(self.named_parameters())
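
With this one-line change, the filtered backbone weights pass through the mapper before `self.roberta.load_weights`, so Jina checkpoint names like `mixer.Wqkv` resolve to vLLM parameters, while the classification-head weights continue through the existing `params_dict` path unchanged.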