# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
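
"""Tests for get_score_prompt from vllm.entrypoints.score_utils: score
templates, chat-template handling, and the tokenizer fallback paths."""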

from unittest.mock import patch

import pytest

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.score_utils import get_score_prompt
from vllm.inputs import TokensPrompt
from vllm.tokenizers import get_tokenizer

# A cross-encoder model for testing
CROSS_ENCODER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"


def assert_prompt_tokenization_consistent(
    tokenizer, full_prompt, engine_prompt, add_special_tokens=True
):
    """Verify that engine_prompt token_ids match tokenizing full_prompt."""
    expected_ids = tokenizer(full_prompt, add_special_tokens=add_special_tokens)[
        "input_ids"
    ]
    actual_ids = engine_prompt["prompt_token_ids"]
    assert actual_ids == expected_ids, (
        f"Token IDs don't match.\nExpected: {expected_ids}\nActual: {actual_ids}"
    )


@pytest.fixture(scope="module")
def cross_encoder_model_config():
    return ModelConfig(
        CROSS_ENCODER_MODEL_ID,
        runner="pooling",
    )


@pytest.fixture(scope="module")
def cross_encoder_tokenizer(cross_encoder_model_config):
    return get_tokenizer(
        CROSS_ENCODER_MODEL_ID,
        trust_remote_code=cross_encoder_model_config.trust_remote_code,
    )


@pytest.fixture(scope="module")
def llm_reranker_model_config():
    """Model config for LLM-as-reranker style (no pad token)."""
    config = ModelConfig(
        CROSS_ENCODER_MODEL_ID,
        runner="pooling",
    )
    # use_pad_token is a property that reads from hf_config,
    # so we set it there to override the default (True)
    config.hf_config.use_pad_token = False
    return config


@pytest.fixture
def tokenization_kwargs():
    """Common tokenization kwargs used across tests."""
    return {"add_special_tokens": True, "return_tensors": None}


@pytest.fixture
def mock_model_with_score_template():
    """Mock model class that supports score template and tracks post_process calls."""

    class MockModelWithScoreTemplate:
        supports_score_template = True
        post_process_called: list[TokensPrompt] = []

        @staticmethod
        def get_score_template(p1: str, p2: str) -> str:
            return f"[QUERY]{p1}[SEP][DOC]{p2}"

        @staticmethod
        def post_process_tokens(prompt: TokensPrompt) -> None:
            MockModelWithScoreTemplate.post_process_called.append(prompt)

    return MockModelWithScoreTemplate


@pytest.fixture
def mock_model_no_score_template():
    """Mock model class that does not support score template."""

    class MockModelNoScoreTemplate:
        supports_score_template = False

    return MockModelNoScoreTemplate


class TestGetScorePrompt:
    """Tests for the get_score_prompt function."""

    def test_tokenization_kwargs_passed_through(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
    ):
        """Test that tokenization kwargs are properly passed through."""
        data_1 = "Query text"
        data_2 = "Document text"

        # Test with truncation - custom kwargs for this test
        custom_tokenization_kwargs = {
            "add_special_tokens": True,
            "return_tensors": None,
            "truncation": True,
            "max_length": 20,
        }

        full_prompt, engine_prompt = get_score_prompt(
            llm_reranker_model_config,
            cross_encoder_tokenizer,
            custom_tokenization_kwargs,
            data_1,
            data_2,
        )

        assert isinstance(full_prompt, str)
        assert "prompt_token_ids" in engine_prompt
        # With max_length=20 and truncation, should not exceed this
        assert len(engine_prompt["prompt_token_ids"]) <= 20
        # Since truncation was applied, token_ids should be a prefix of full encoding
        full_ids = cross_encoder_tokenizer(full_prompt, add_special_tokens=True)[
            "input_ids"
        ]
        actual_ids = engine_prompt["prompt_token_ids"]
        assert full_ids[: len(actual_ids)] == actual_ids, (
            f"Token IDs are not a prefix of full encoding.\n"
            f"Full IDs: {full_ids}\n"
            f"Actual IDs: {actual_ids}"
        )

    def test_model_supports_score_template(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template (no score_template arg)."""
        with patch(
            "vllm.model_executor.model_loader.get_model_cls",
            return_value=mock_model_with_score_template,
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query text",
                "document text",
            )

        assert full_prompt == "[QUERY]query text[SEP][DOC]document text"
        assert "prompt_token_ids" in engine_prompt
        assert len(engine_prompt["prompt_token_ids"]) > 0
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_model_supports_score_template_but_custom_template_provided(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template but custom template is provided."""
        template = (
            'TEMPLATE_USED {{ messages[0]["content"] }} {{ messages[1]["content"] }}'
        )
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
                score_template=template,  # Providing a template
            )

        assert "prompt_token_ids" in engine_prompt
        assert full_prompt == "TEMPLATE_USED query doc"

        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_not_using_default_template(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        # FIXME: Models implementing SupportsScoreTemplate must use their custom
        # template implementation by default to preserve existing functionality.
        # Attempting to use tokenizer_config.json templates would most likely break
        # these models, as often they just inherit the template from the original LLM.
        # CLI --chat-template overrides are still supported.
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                return_value="test querytest doc",
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "test query",
                "test doc",
            )

        assert full_prompt == "test querytest doc"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_fallback_with_pad_token(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
"""Test fallback path when ChatTemplateResolutionError
|
|
and use_pad_token=True."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,  # use_pad_token=True
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        assert "prompt_token_ids" in engine_prompt
        # Should have token_type_ids from text_pair encoding
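        # For a BERT-style cross-encoder tokenizer like the one used here, the
        # text_pair encoding is expected to look like
        #   [CLS] query [SEP] document [SEP]
        # with token_type_ids of 0 for the query segment and 1 for the document
        # segment (an assumption about the underlying HF tokenizer; only the
        # key's presence is asserted below).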
        assert "token_type_ids" in engine_prompt
        assert "query" in full_prompt
        assert "document" in full_prompt
        assert full_prompt != "querydocument"
        assert (
            engine_prompt["prompt_token_ids"]
            == cross_encoder_tokenizer(
                "query", text_pair="document", add_special_tokens=True
            )["input_ids"]
        )

        # FIXME(?): add_special_tokens=False is needed because in this case
        # full_prompt is obtained by decoding the tokenized prompt, which includes
        # special tokens and we would get duplicated special tokens otherwise.
        # This is inconsistent with other cases.
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer,
            full_prompt,
            engine_prompt,
            add_special_tokens=False,
        )

    def test_fallback_without_pad_token(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
"""Test fallback path when ChatTemplateResolutionError
|
|
and use_pad_token=False."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,  # use_pad_token=False
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        assert full_prompt == "querydocument"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_post_process_tokens_called(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test that post_process_tokens is called on the engine prompt."""
        # Reset the call tracker
        mock_model_with_score_template.post_process_called.clear()

        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
            )

        # post_process_tokens should have been called once
        assert len(mock_model_with_score_template.post_process_called) == 1
        assert mock_model_with_score_template.post_process_called[0] is engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )