mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-26 01:03:05 +08:00)
[Frontend] Support using chat template as custom score template for reranking models (#30550)
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
parent 27c6c2f98c
commit 23daef548d
@@ -490,6 +490,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | |
| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | |
| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | |
| `LlamaBidirectionalModel`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-embed-1b-v2`, etc. | ✅︎ | ✅︎ |
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ |
@@ -543,8 +544,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen2ForSequenceClassification`<sup>C</sup> | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification`<sup>C</sup> | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
@@ -562,6 +564,11 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A

!!! note
    The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. Because that name is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
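For offline use, the same override can also be passed programmatically. A minimal sketch (not part of this diff; it assumes that the `hf_overrides` keyword of `LLM` accepts the same dictionary as the `--hf-overrides` CLI flag):

```python
from vllm import LLM

# Sketch only: hf_overrides is assumed to mirror the CLI `--hf-overrides` JSON.
llm = LLM(
    model="Alibaba-NLP/gte-multilingual-reranker-base",
    runner="pooling",
    trust_remote_code=True,
    hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
)
```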
!!! note
    `nvidia/llama-nemotron-rerank-1b-v2` requires a specific prompt format to work correctly.

    Examples: [offline_using_template.py](../../examples/pooling/score/offline_using_template.py), [online_using_template.py](../../examples/pooling/score/online_using_template.py)
!!! note
    Load the official original `mxbai-rerank-v2` by using the following command.
@@ -669,6 +669,21 @@ You can find the documentation for cross encoder models at [sbert.net](https://w

Code example: [examples/pooling/score/openai_cross_encoder_score.py](../../examples/pooling/score/openai_cross_encoder_score.py)

#### Score Template

Some scoring models require a specific prompt format to work correctly. You can specify a custom score template using the `--chat-template` parameter (see [Chat Template](#chat-template)).

Score templates are supported for **cross-encoder** models only. If you are using an **embedding** model for scoring, vLLM does not apply a score template.

Like chat templates, the score template receives a `messages` list. For scoring, each message has a `role` attribute—either `"query"` or `"document"`. For the usual kind of point-wise cross-encoder, you can expect exactly two messages: one query and one document. To access the query and document content, use Jinja's `selectattr` filter:

- **Query**: `{{ (messages | selectattr("role", "eq", "query") | first).content }}`
- **Document**: `{{ (messages | selectattr("role", "eq", "document") | first).content }}`

This approach is more robust than index-based access (`messages[0]`, `messages[1]`) because it selects messages by their semantic role. It also avoids assumptions about message ordering if additional message types are added to `messages` in the future.
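As an illustration (not part of this diff), the two expressions above can be checked with plain `jinja2`; the `messages` list below is a hypothetical example following the query/document roles described here:

```python
from jinja2 import Template

# Stand-alone sketch: vLLM renders the score template internally, but the same
# Jinja expressions can be exercised directly with jinja2 to inspect the prompt.
template = Template(
    'question:{{ (messages | selectattr("role", "eq", "query") | first).content }}\n'
    'passage:{{ (messages | selectattr("role", "eq", "document") | first).content }}'
)

messages = [
    {"role": "query", "content": "how much protein should a female eat?"},
    {"role": "document", "content": "The CDC's average requirement for women is 46 grams per day."},
]

print(template.render(messages=messages))
# question:how much protein should a female eat?
# passage:The CDC's average requirement for women is 46 grams per day.
```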
Example template file: [examples/pooling/score/template/nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja)

#### Single inference

You can pass a string to both `text_1` and `text_2`, forming a single sentence pair.
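For example, a minimal sketch (not part of this diff) that scores a single pair with the cross-encoder listed in the table above:

```python
from vllm import LLM

# One query string and one document string form a single sentence pair.
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", runner="pooling")
(output,) = llm.score(
    "how much protein should a female eat?",
    "The CDC's average protein requirement for women is 46 grams per day.",
)
print(output.outputs.score)
```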
examples/pooling/score/offline_using_template.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from pathlib import Path

from vllm import LLM

model_name = "nvidia/llama-nemotron-rerank-1b-v2"

# Path to template file
template_path = Path(__file__).parent / "template" / "nemotron-rerank.jinja"
chat_template = template_path.read_text()

llm = LLM(model=model_name, runner="pooling", trust_remote_code=True)

query = "how much protein should a female eat?"
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
    "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]

outputs = llm.score(query, documents, chat_template=chat_template)

print("-" * 30)
print([output.outputs.score for output in outputs])
print("-" * 30)
examples/pooling/score/online_using_template.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using the rerank API with template.

run:
vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja
"""

import json

import requests

url = "http://127.0.0.1:8000/rerank"

headers = {"accept": "application/json", "Content-Type": "application/json"}

query = "how much protein should a female eat?"
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
    "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]

data = {
    "model": "nvidia/llama-nemotron-rerank-1b-v2",
    "query": query,
    "documents": documents,
}


def main():
    response = requests.post(url, headers=headers, json=data)

    # Check the response
    if response.status_code == 200:
        print("Request successful!")
        print(json.dumps(response.json(), indent=2))
    else:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    main()
examples/pooling/score/template/nemotron-rerank.jinja (new file, 3 lines)
@@ -0,0 +1,3 @@
question:{{ (messages | selectattr("role", "eq", "query") | first).content }}

passage:{{ (messages | selectattr("role", "eq", "document") | first).content }}
tests/entrypoints/pooling/score/test_utils.py (new file, 352 lines)
@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import patch

import pytest

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.score_utils import get_score_prompt
from vllm.inputs import TokensPrompt
from vllm.tokenizers import get_tokenizer

# A cross-encoder model for testing
CROSS_ENCODER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"


def assert_prompt_tokenization_consistent(
    tokenizer, full_prompt, engine_prompt, add_special_tokens=True
):
    """Verify that engine_prompt token_ids match tokenizing full_prompt."""
    expected_ids = tokenizer(full_prompt, add_special_tokens=add_special_tokens)[
        "input_ids"
    ]
    actual_ids = engine_prompt["prompt_token_ids"]
    assert actual_ids == expected_ids, (
        f"Token IDs don't match.\nExpected: {expected_ids}\nActual: {actual_ids}"
    )


@pytest.fixture(scope="module")
def cross_encoder_model_config():
    return ModelConfig(
        CROSS_ENCODER_MODEL_ID,
        runner="pooling",
    )


@pytest.fixture(scope="module")
def cross_encoder_tokenizer(cross_encoder_model_config):
    return get_tokenizer(
        CROSS_ENCODER_MODEL_ID,
        trust_remote_code=cross_encoder_model_config.trust_remote_code,
    )


@pytest.fixture(scope="module")
def llm_reranker_model_config():
    """Model config for LLM-as-reranker style (no pad token)."""
    config = ModelConfig(
        CROSS_ENCODER_MODEL_ID,
        runner="pooling",
    )
    # use_pad_token is a property that reads from hf_config,
    # so we set it there to override the default (True)
    config.hf_config.use_pad_token = False
    return config


@pytest.fixture
def tokenization_kwargs():
    """Common tokenization kwargs used across tests."""
    return {"add_special_tokens": True, "return_tensors": None}


@pytest.fixture
def mock_model_with_score_template():
    """Mock model class that supports score template and tracks post_process calls."""

    class MockModelWithScoreTemplate:
        supports_score_template = True
        post_process_called: list[TokensPrompt] = []

        @staticmethod
        def get_score_template(p1: str, p2: str) -> str:
            return f"[QUERY]{p1}[SEP][DOC]{p2}"

        @staticmethod
        def post_process_tokens(prompt: TokensPrompt) -> None:
            MockModelWithScoreTemplate.post_process_called.append(prompt)

    return MockModelWithScoreTemplate


@pytest.fixture
def mock_model_no_score_template():
    """Mock model class that does not support score template."""

    class MockModelNoScoreTemplate:
        supports_score_template = False

    return MockModelNoScoreTemplate


class TestGetScorePrompt:
    """Tests for the get_score_prompt function."""

    def test_tokenization_kwargs_passed_through(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
    ):
        """Test that tokenization kwargs are properly passed through."""
        data_1 = "Query text"
        data_2 = "Document text"

        # Test with truncation - custom kwargs for this test
        custom_tokenization_kwargs = {
            "add_special_tokens": True,
            "return_tensors": None,
            "truncation": True,
            "max_length": 20,
        }

        full_prompt, engine_prompt = get_score_prompt(
            llm_reranker_model_config,
            cross_encoder_tokenizer,
            custom_tokenization_kwargs,
            data_1,
            data_2,
        )

        assert isinstance(full_prompt, str)
        assert "prompt_token_ids" in engine_prompt
        # With max_length=20 and truncation, should not exceed this
        assert len(engine_prompt["prompt_token_ids"]) <= 20
        # Since truncation was applied, token_ids should be a prefix of full encoding
        full_ids = cross_encoder_tokenizer(full_prompt, add_special_tokens=True)[
            "input_ids"
        ]
        actual_ids = engine_prompt["prompt_token_ids"]
        assert full_ids[: len(actual_ids)] == actual_ids, (
            f"Token IDs are not a prefix of full encoding.\n"
            f"Full IDs: {full_ids}\n"
            f"Actual IDs: {actual_ids}"
        )

    def test_model_supports_score_template(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template (no score_template arg)."""
        with patch(
            "vllm.model_executor.model_loader.get_model_cls",
            return_value=mock_model_with_score_template,
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query text",
                "document text",
            )

        assert full_prompt == "[QUERY]query text[SEP][DOC]document text"
        assert "prompt_token_ids" in engine_prompt
        assert len(engine_prompt["prompt_token_ids"]) > 0
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_model_supports_score_template_but_custom_template_provided(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template but custom template is provided."""
        template = (
            'TEMPLATE_USED {{ messages[0]["content"] }} {{ messages[1]["content"] }}'
        )
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
                score_template=template,  # Providing a template
            )

        assert "prompt_token_ids" in engine_prompt
        assert full_prompt == "TEMPLATE_USED query doc"

        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_not_using_default_template(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        # FIXME: Models implementing SupportsScoreTemplate must use their custom
        # template implementation by default to preserve existing functionality.
        # Attempting to use tokenizer_config.json templates would most likely break
        # these models, as often they just inherit the template from the original LLM.
        # CLI --chat-template overrides are still supported.
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                return_value="test querytest doc",
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "test query",
                "test doc",
            )

        assert full_prompt == "test querytest doc"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_fallback_with_pad_token(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        """Test fallback path when ChatTemplateResolutionError
        and use_pad_token=True."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,  # use_pad_token=True
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        assert "prompt_token_ids" in engine_prompt
        # Should have token_type_ids from text_pair encoding
        assert "token_type_ids" in engine_prompt
        assert "query" in full_prompt
        assert "document" in full_prompt
        assert full_prompt != "querydocument"
        assert (
            engine_prompt["prompt_token_ids"]
            == cross_encoder_tokenizer(
                "query", text_pair="document", add_special_tokens=True
            )["input_ids"]
        )

        # FIXME(?): add_special_tokens=False is needed because in this case
        # full_prompt is obtained by decoding the tokenized prompt, which includes
        # special tokens and we would get duplicated special tokens otherwise.
        # This is inconsistent with other cases.
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer,
            full_prompt,
            engine_prompt,
            add_special_tokens=False,
        )

    def test_fallback_without_pad_token(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        """Test fallback path when ChatTemplateResolutionError
        and use_pad_token=False."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,  # use_pad_token=False
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        assert full_prompt == "querydocument"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_post_process_tokens_called(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test that post_process_tokens is called on the engine prompt."""
        # Reset the call tracker
        mock_model_with_score_template.post_process_called.clear()

        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
            )

        # post_process_tokens should have been called once
        assert len(mock_model_with_score_template.post_process_called) == 1
        assert mock_model_with_score_template.post_process_called[0] is engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import tempfile
from pathlib import Path

import mteb
import numpy as np
@@ -19,6 +20,11 @@ from tests.models.utils import (
    get_vllm_extra_kwargs,
)

template_home = (
    Path(__file__).parent.parent.parent.parent.parent
    / "examples/pooling/score/template"
)

# Most embedding models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
#   results in differences less than 1e-4
@@ -102,30 +108,6 @@ class VllmMtebEncoder(mteb.EncoderProtocol):
        return sim


class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
    mteb_model_meta = _empty_model_meta

    def __init__(self, vllm_model):
        self.llm = vllm_model
        self.rng = np.random.default_rng(seed=42)

    def predict(
        self,
        inputs1: DataLoader[mteb.types.BatchedInput],
        inputs2: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        queries = [text for batch in inputs1 for text in batch["text"]]
        corpus = [text for batch in inputs2 for text in batch["text"]]

        outputs = self.llm.score(
            queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
        )
        scores = np.array(outputs)
        return scores


class OpenAIClientMtebEncoder(VllmMtebEncoder):
    def __init__(self, model_name: str, client):
        self.model_name = model_name
@@ -153,6 +135,35 @@ class OpenAIClientMtebEncoder(VllmMtebEncoder):
        return embeds


class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
    mteb_model_meta = _empty_model_meta

    def __init__(self, vllm_model):
        self.llm = vllm_model
        self.rng = np.random.default_rng(seed=42)
        self.chat_template: str | None = getattr(vllm_model, "chat_template", None)

    def predict(
        self,
        inputs1: DataLoader[mteb.types.BatchedInput],
        inputs2: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        queries = [text for batch in inputs1 for text in batch["text"]]
        corpus = [text for batch in inputs2 for text in batch["text"]]

        outputs = self.llm.score(
            queries,
            corpus,
            truncate_prompt_tokens=-1,
            use_tqdm=False,
            chat_template=self.chat_template,
        )
        scores = np.array(outputs)
        return scores


class ScoreClientMtebEncoder(mteb.CrossEncoderProtocol):
    mteb_model_meta = _empty_model_meta

@@ -387,6 +398,11 @@ def mteb_test_rerank_models(
        == model_info.default_pooling_type
    )

    chat_template: str | None = None
    if model_info.chat_template_name is not None:
        chat_template = (template_home / model_info.chat_template_name).read_text()
        vllm_model.chat_template = chat_template

    vllm_main_score = run_mteb_rerank(
        vllm_mteb_encoder(vllm_model),
        tasks=MTEB_RERANK_TASKS,
tests/models/language/pooling_mteb_test/test_nemotron.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from tests.models.utils import (
    EmbedModelInfo,
    LASTPoolingEmbedModelInfo,
    LASTPoolingRerankModelInfo,
    RerankModelInfo,
)

from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models

EMBEDDING_MODELS = [
    LASTPoolingEmbedModelInfo(
        "nvidia/llama-nemotron-embed-1b-v2",
        architecture="LlamaBidirectionalModel",
        mteb_score=0.689164662128673,
    )
]

RERANK_MODELS = [
    LASTPoolingRerankModelInfo(
        "nvidia/llama-nemotron-rerank-1b-v2",
        architecture="LlamaBidirectionalForSequenceClassification",
        chat_template_name="nemotron-rerank.jinja",
        mteb_score=0.33994,
    ),
]


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
    mteb_test_embed_models(hf_runner, vllm_runner, model_info)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
    hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
@@ -488,6 +488,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
    ),
    "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
    "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
    "LlamaBidirectionalModel": _HfExamplesInfo(
        "nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True
    ),
    "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
    "ModernBertModel": _HfExamplesInfo(
        "Alibaba-NLP/gte-modernbert-base", trust_remote_code=True
@@ -554,6 +557,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
        trust_remote_code=True,
        hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
    ),
    "LlamaBidirectionalForSequenceClassification": _HfExamplesInfo(
        "nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True
    ),
    "ModernBertForSequenceClassification": _HfExamplesInfo(
        "Alibaba-NLP/gte-reranker-modernbert-base"
    ),
@@ -399,6 +399,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo):
@dataclass
class RerankModelInfo(ModelInfo):
    mteb_score: float | None = None
    chat_template_name: str | None = None


@dataclass
@@ -67,6 +67,15 @@ else:

logger = init_logger(__name__)


class ChatTemplateResolutionError(ValueError):
    """Raised when chat template resolution fails.

    This is a subclass of ValueError for backward compatibility with
    existing exception handlers.
    """


MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
@@ -1814,7 +1823,7 @@ def apply_hf_chat_template(
        )

    if hf_chat_template is None:
        raise ValueError(
        raise ChatTemplateResolutionError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
@@ -1280,6 +1280,7 @@ class LLM:
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        score_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        model_config = self.model_config

@@ -1313,6 +1314,7 @@ class LLM:
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            if token_type_ids := engine_prompt.pop("token_type_ids", None):
@@ -1347,6 +1349,7 @@ class LLM:
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.
@@ -1379,6 +1382,8 @@ class LLM:
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            chat_template: The chat template to use for the scoring. If None, we
                use the model's default chat template.
        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.
@@ -1406,6 +1411,11 @@ class LLM:
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if not model_config.is_cross_encoder and chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
@@ -1475,6 +1485,7 @@ class LLM:
                use_tqdm,
                pooling_params,
                lora_request,
                score_template=chat_template,
            )
        else:
            return self._embedding_score(
@@ -1145,6 +1145,7 @@ async def init_app_state(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
@@ -495,6 +495,7 @@ async def run_batch(
            engine_client,
            openai_serving_models,
            request_logger=request_logger,
            score_template=None,
        )
        if ("embed" in supported_tasks or enable_serving_reranking)
        else None
@@ -52,6 +52,7 @@ class ServingScores(OpenAIServing):
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        score_template: str | None = None,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
@@ -60,6 +61,7 @@ class ServingScores(OpenAIServing):
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )
        self.score_template = score_template

    async def _embedding_score(
        self,
@@ -169,6 +171,7 @@ class ServingScores(OpenAIServing):
            data_2=data_2,
            tokenizer=tokenizer,
            tokenization_kwargs=tokenization_kwargs,
            score_template=self.score_template,
        )
        self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt)
        if request.mm_processor_kwargs is not None:
@@ -11,9 +11,11 @@ from vllm.entrypoints.chat_utils import (
    ChatCompletionContentPartImageEmbedsParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartTextParam,
    ChatTemplateResolutionError,
    MultiModalItemTracker,
    _ContentPart,
    _parse_chat_message_content_part,
    apply_hf_chat_template,
)
from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
@@ -139,10 +141,8 @@ def _parse_score_content(
    return next(iter(mm_placeholder_storage.values()))[0]


def apply_score_template(
    model_config: ModelConfig,
    prompt_1: str,
    prompt_2: str,
def _apply_model_score_template(
    model_config: ModelConfig, prompt_1: str, prompt_2: str
) -> str:
    # NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls
@@ -181,6 +181,7 @@ def get_score_prompt(
    tokenization_kwargs: dict[str, Any],
    data_1: str | ScoreContentPartParam,
    data_2: str | ScoreContentPartParam,
    score_template: str | None = None,
) -> tuple[str, TokensPrompt]:
    prompt_1, prompt_2, mm_data = parse_score_data(
        data_1,
@@ -190,19 +191,48 @@ def get_score_prompt(
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)
    if supports_score_template(model):
        full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
        prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
    elif model_config.use_pad_token:
        # cross_encoder models defaults to using pad_token.
        prompt_inputs = tokenizer(
            text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
        )
        full_prompt = tokenizer.decode(prompt_inputs["input_ids"])

    def default_tokenizer_encode():
        if supports_score_template(model):
            full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
            prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
        else:
            if model_config.use_pad_token:
                # cross_encoder models defaults to using pad_token.
                prompt_inputs = tokenizer(
                    text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
                )
                full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
            else:
                # `llm as reranker` models defaults to not using pad_token.
                full_prompt = prompt_1 + prompt_2
                prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
        return full_prompt, prompt_inputs

    # FIXME: For now, we only apply a template when one is explicitly provided.
    # We cannot rely on the tokenizer's chat template because many models
    # inherit junk templates from their base LLM, which breaks both the models
    # and the tests that use them.
    if score_template is None:
        full_prompt, prompt_inputs = default_tokenizer_encode()
    else:
        # `llm as reranker` models defaults to not using pad_token.
        full_prompt = prompt_1 + prompt_2
        prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
        # FIXME: Try applying a score template from the CLI arg or tokenizer_config.json
        # If that fails because there is no such template,
        # fall back to the default implementation.
        try:
            full_prompt = apply_hf_chat_template(
                tokenizer,
                [
                    {"role": "query", "content": prompt_1},
                    {"role": "document", "content": prompt_2},
                ],
                score_template,
                tools=None,
                model_config=model_config,
            )
            prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
        except ChatTemplateResolutionError:
            full_prompt, prompt_inputs = default_tokenizer_encode()

    engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
@@ -88,6 +88,26 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
        }


class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        from vllm.config.pooler import PoolingTypeStr

        hf_config = vllm_config.model_config.hf_config
        hf_config.is_causal = False

        pooling_type_map: dict[str, PoolingTypeStr] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        pooling_type = pooling_type_map.get(hf_config.pooling, None)
        if pooling_type is None:
            raise ValueError(f"pool_type {hf_config.pooling} not supported")
        vllm_config.model_config.pooler_config.pooling_type = pooling_type


class NomicBertModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -509,6 +529,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteNewModel": GteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
@@ -57,7 +57,14 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
from .adapters import as_embedding_model, as_seq_cls_model
from .interfaces import (
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
from .interfaces_base import attn_type
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
@@ -698,3 +705,17 @@ class LlamaForCausalLM(
            name = name.replace(item, mapping[item])

        return name, loaded_weight


@attn_type("encoder_only")
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
    # This class sets the correct attention type and pooling type
    # through LlamaBidirectionalConfig.
    pass


@attn_type("encoder_only")
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
    # This class sets the correct attention type and pooling type
    # through LlamaBidirectionalConfig.
    pass
@@ -203,6 +203,7 @@ _EMBEDDING_MODELS = {
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
@@ -246,6 +247,11 @@ _CROSS_ENCODER_MODELS = {
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
@@ -259,8 +265,6 @@ _CROSS_ENCODER_MODELS = {
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501,
}

_MULTIMODAL_MODELS = {