# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import numpy as np
import pytest
import torch

from tests.conftest import HfRunner

from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models

RERANK_MODELS = [
    LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
                               architecture="GemmaForSequenceClassification"),
]

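# Instruction prompt for BAAI/bge-reranker-v2-gemma: the model is asked to
# answer "Yes" or "No", and the logit of the "Yes" token is used as the
# relevance score (see GemmaRerankerHfRunner.predict below).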
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
|
|
|
|
|
|
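# HF reference runner: loads the checkpoint as a causal LM and scores each
# (query, passage) pair as the sigmoid of the "Yes"-token logit at the last
# position, following the pair-formatting recipe from the model card.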
class GemmaRerankerHfRunner(HfRunner):

    def __init__(self,
                 model_name: str,
                 dtype: str = "auto",
                 *args: Any,
                 **kwargs: Any) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       padding_side='left')
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")

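    # Each entry in `prompts` is a (query, doc, ...) tuple; pairs are scored
    # independently, one forward pass per pair.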
    @torch.no_grad()
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:

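        # Assemble each pair as:
        #     <bos>A: {query}\nB: {passage}\n{PROMPT}
        # truncating only the passage if the sequence is too long.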
        def get_inputs(pairs, tokenizer, prompt=None):
            if prompt is None:
                prompt = PROMPT

            sep = "\n"
            prompt_inputs = tokenizer(prompt,
                                      return_tensors=None,
                                      add_special_tokens=False)["input_ids"]
            sep_inputs = tokenizer(sep,
                                   return_tensors=None,
                                   add_special_tokens=False)["input_ids"]
            inputs = []
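            # Tokenize query and passage separately so prepare_for_model can
            # truncate only the passage ("only_second") and keep the query.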
            for query, passage in pairs:
                query_inputs = tokenizer(
                    f"A: {query}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                passage_inputs = tokenizer(
                    f"B: {passage}",
                    return_tensors=None,
                    add_special_tokens=False,
                    truncation=True,
                )
                item = tokenizer.prepare_for_model(
                    [tokenizer.bos_token_id] + query_inputs["input_ids"],
                    sep_inputs + passage_inputs["input_ids"],
                    truncation="only_second",
                    padding=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    add_special_tokens=False,
                )
                item["input_ids"] = item[
                    "input_ids"] + sep_inputs + prompt_inputs
                item["attention_mask"] = [1] * len(item["input_ids"])
                inputs.append(item)
            return tokenizer.pad(
                inputs,
                padding=True,
                return_tensors="pt",
            )

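        # MTEB may pass an extra per-pair prompt as the third element; the HF
        # path ignores it and always uses the fixed PROMPT above.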
        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
            inputs = get_inputs(pairs, self.tokenizer)
            inputs = inputs.to(self.model.device)
            _n_tokens = inputs["input_ids"].shape[1]
            logits = self.model(**inputs, return_dict=True).logits
            _scores = logits[:, -1, self.yes_loc].view(-1).float().sigmoid()
            scores.append(_scores[0].item())
        return torch.Tensor(scores)


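# vLLM-side encoder: applies the same "A: ..." / "B: ...\n{prompt}" templates
# before delegating to the generic VllmMtebEncoder.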
class GemmaMtebEncoder(VllmMtebEncoder):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = PROMPT
        self.query_template = "A: {query}\n"
        self.document_template = "B: {doc}\n{prompt}"

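    # Rewrites each (query, corpus, prompt) triple with the reranker
    # templates before scoring with the underlying vLLM model.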
    def predict(
        self,
        sentences: list[tuple[str, str,
                              Optional[str]]],  # query, corpus, prompt
        *args,
        **kwargs,
    ) -> np.ndarray:

        _sentences = []
        for query, corpus, prompt in sentences:
            query = self.query_template.format(query=query)
            corpus = self.document_template.format(doc=corpus, prompt=prompt)
            _sentences.append((query, corpus, prompt))

        return super().predict(_sentences, *args, **kwargs)


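# End-to-end MTEB check: compares vLLM's reranking scores against the HF
# reference runner defined above. The test is pinned to the V0 engine via
# VLLM_USE_V1=0.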
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
|
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
|
|
monkeypatch) -> None:
|
|
monkeypatch.setenv("VLLM_USE_V1", "0")
|
|
|
|
assert model_info.architecture == "GemmaForSequenceClassification"
|
|
|
|
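    # hf_overrides maps the causal-LM checkpoint onto vLLM's
    # GemmaForSequenceClassification: the classifier head is derived from
    # the "Yes" token ("classifier_from_token"), presumably from its LM-head
    # row, with no extra post-processing applied to the resulting score.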
    vllm_extra_kwargs: dict[str, Any] = {
        "hf_overrides": {
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        }
    }

    mteb_test_rerank_models(GemmaRerankerHfRunner,
                            vllm_runner,
                            model_info,
                            vllm_extra_kwargs,
                            vllm_mteb_encoder=GemmaMtebEncoder)