mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 00:17:11 +08:00
Merge branch 'main' into upstream_mori_
This commit is contained in:
commit
03343276fa
@ -490,6 +490,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | |
|
||||
| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | |
|
||||
| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | |
|
||||
| `LlamaBidirectionalModel`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-embed-1b-v2`, etc. | ✅︎ | ✅︎ |
|
||||
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ |
|
||||
@ -543,8 +544,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | |
|
||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ |
|
||||
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | |
|
||||
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
|
||||
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2` (see note), etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForSequenceClassification`<sup>C</sup> | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3ForSequenceClassification`<sup>C</sup> | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | |
|
||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
|
||||
@ -562,6 +564,11 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
!!! note
|
||||
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. Because the name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
|
||||
|
||||
!!! note
|
||||
`nvidia/llama-nemotron-rerank-1b-v2` requires a specific prompt format to work correctly.
|
||||
|
||||
Examples: [offline_using_template.py](../../examples/pooling/score/offline_using_template.py), [online_using_template.py](../../examples/pooling/score/online_using_template.py)
|
||||
|
||||
!!! note
|
||||
Load the official original `mxbai-rerank-v2` by using the following command.
|
||||
|
||||
|
||||
@ -669,6 +669,21 @@ You can find the documentation for cross encoder models at [sbert.net](https://w
|
||||
|
||||
Code example: [examples/pooling/score/openai_cross_encoder_score.py](../../examples/pooling/score/openai_cross_encoder_score.py)
|
||||
|
||||
#### Score Template
|
||||
|
||||
Some scoring models require a specific prompt format to work correctly. You can specify a custom score template using the `--chat-template` parameter (see [Chat Template](#chat-template)).
|
||||
|
||||
Score templates are supported for **cross-encoder** models only. If you are using an **embedding** model for scoring, vLLM does not apply a score template.
|
||||
|
||||
Like chat templates, the score template receives a `messages` list. For scoring, each message has a `role` attribute—either `"query"` or `"document"`. For the usual kind of point-wise cross-encoder, you can expect exactly two messages: one query and one document. To access the query and document content, use Jinja's `selectattr` filter:
|
||||
|
||||
- **Query**: `{{ (messages | selectattr("role", "eq", "query") | first).content }}`
|
||||
- **Document**: `{{ (messages | selectattr("role", "eq", "document") | first).content }}`
|
||||
|
||||
This approach is more robust than index-based access (`messages[0]`, `messages[1]`) because it selects messages by their semantic role. It also avoids assumptions about message ordering if additional message types are added to `messages` in the future.
|
||||
|
||||
Example template file: [examples/pooling/score/template/nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja)
|
||||
|
||||
#### Single inference
|
||||
|
||||
You can pass a string to both `text_1` and `text_2`, forming a single sentence pair.
|
||||
|
||||
27
examples/pooling/score/offline_using_template.py
Normal file
27
examples/pooling/score/offline_using_template.py
Normal file
@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Offline scoring example: rerank documents with a custom score template."""

from pathlib import Path

from vllm import LLM

model_name = "nvidia/llama-nemotron-rerank-1b-v2"

# The score template ships next to this example script.
template_path = Path(__file__).parent / "template" / "nemotron-rerank.jinja"
chat_template = template_path.read_text()

llm = LLM(model=model_name, runner="pooling", trust_remote_code=True)

query = "how much protein should a female eat?"
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
    "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]

# Score every document against the query; prompts are rendered via the template.
outputs = llm.score(query, documents, chat_template=chat_template)

divider = "-" * 30
print(divider)
print([output.outputs.score for output in outputs])
print(divider)
46
examples/pooling/score/online_using_template.py
Normal file
46
examples/pooling/score/online_using_template.py
Normal file
@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using the rerank API with template.

run:
vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja
"""

import json

import requests

url = "http://127.0.0.1:8000/rerank"

headers = {"accept": "application/json", "Content-Type": "application/json"}

query = "how much protein should a female eat?"
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
    "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]

data = {
    "model": "nvidia/llama-nemotron-rerank-1b-v2",
    "query": query,
    "documents": documents,
}


def main():
    """POST the rerank request to a running vLLM server and print the result."""
    response = requests.post(url, headers=headers, json=data)
    status = response.status_code

    # Report failures with the server's raw body to ease debugging.
    if status != 200:
        print(f"Request failed with status code: {status}")
        print(response.text)
        return

    print("Request successful!")
    print(json.dumps(response.json(), indent=2))


if __name__ == "__main__":
    main()
3
examples/pooling/score/template/nemotron-rerank.jinja
Normal file
3
examples/pooling/score/template/nemotron-rerank.jinja
Normal file
@ -0,0 +1,3 @@
|
||||
question:{{ (messages | selectattr("role", "eq", "query") | first).content }}
|
||||
|
||||
passage:{{ (messages | selectattr("role", "eq", "document") | first).content }}
|
||||
352
tests/entrypoints/pooling/score/test_utils.py
Normal file
352
tests/entrypoints/pooling/score/test_utils.py
Normal file
@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
|
||||
from vllm.entrypoints.score_utils import get_score_prompt
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
# A small cross-encoder model used as the subject of every test in this module.
CROSS_ENCODER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
def assert_prompt_tokenization_consistent(
    tokenizer, full_prompt, engine_prompt, add_special_tokens=True
):
    """Check that the engine prompt's token IDs equal the IDs produced by
    re-tokenizing ``full_prompt`` with the same tokenizer."""
    reference = tokenizer(full_prompt, add_special_tokens=add_special_tokens)
    expected_ids = reference["input_ids"]
    actual_ids = engine_prompt["prompt_token_ids"]
    failure_message = (
        f"Token IDs don't match.\nExpected: {expected_ids}\nActual: {actual_ids}"
    )
    assert actual_ids == expected_ids, failure_message
|
||||
|
||||
@pytest.fixture(scope="module")
def cross_encoder_model_config():
    """Module-scoped pooling-runner config for the cross-encoder test model."""
    return ModelConfig(CROSS_ENCODER_MODEL_ID, runner="pooling")
|
||||
|
||||
@pytest.fixture(scope="module")
def cross_encoder_tokenizer(cross_encoder_model_config):
    """Tokenizer matching the cross-encoder test model."""
    trust = cross_encoder_model_config.trust_remote_code
    return get_tokenizer(CROSS_ENCODER_MODEL_ID, trust_remote_code=trust)
|
||||
|
||||
@pytest.fixture(scope="module")
def llm_reranker_model_config():
    """Model config for LLM-as-reranker style (no pad token)."""
    cfg = ModelConfig(CROSS_ENCODER_MODEL_ID, runner="pooling")
    # use_pad_token is a property backed by hf_config, so overriding the
    # default (True) has to happen on hf_config itself.
    cfg.hf_config.use_pad_token = False
    return cfg
|
||||
|
||||
@pytest.fixture
def tokenization_kwargs():
    """Baseline tokenization kwargs shared across tests."""
    return dict(add_special_tokens=True, return_tensors=None)
|
||||
|
||||
@pytest.fixture
def mock_model_with_score_template():
    """Model stub that advertises score-template support and records every
    prompt passed to ``post_process_tokens``."""

    class _ScoreTemplateModel:
        supports_score_template = True
        # Shared across calls on purpose: tests clear this list up front and
        # then inspect which prompts post_process_tokens received.
        post_process_called: list[TokensPrompt] = []

        @staticmethod
        def get_score_template(p1: str, p2: str) -> str:
            return "[QUERY]" + p1 + "[SEP][DOC]" + p2

        @staticmethod
        def post_process_tokens(prompt: TokensPrompt) -> None:
            _ScoreTemplateModel.post_process_called.append(prompt)

    return _ScoreTemplateModel
|
||||
|
||||
@pytest.fixture
def mock_model_no_score_template():
    """Model stub without score-template support."""

    class _PlainModel:
        supports_score_template = False

    return _PlainModel
|
||||
|
||||
class TestGetScorePrompt:
    """Tests for the get_score_prompt function."""

    def test_tokenization_kwargs_passed_through(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
    ):
        """Test that tokenization kwargs are properly passed through."""
        data_1 = "Query text"
        data_2 = "Document text"

        # Test with truncation - custom kwargs for this test
        custom_tokenization_kwargs = {
            "add_special_tokens": True,
            "return_tensors": None,
            "truncation": True,
            "max_length": 20,
        }

        full_prompt, engine_prompt = get_score_prompt(
            llm_reranker_model_config,
            cross_encoder_tokenizer,
            custom_tokenization_kwargs,
            data_1,
            data_2,
        )

        assert isinstance(full_prompt, str)
        assert "prompt_token_ids" in engine_prompt
        # With max_length=20 and truncation, should not exceed this
        assert len(engine_prompt["prompt_token_ids"]) <= 20
        # Since truncation was applied, token_ids should be a prefix of full encoding
        full_ids = cross_encoder_tokenizer(full_prompt, add_special_tokens=True)[
            "input_ids"
        ]
        actual_ids = engine_prompt["prompt_token_ids"]
        assert full_ids[: len(actual_ids)] == actual_ids, (
            f"Token IDs are not a prefix of full encoding.\n"
            f"Full IDs: {full_ids}\n"
            f"Actual IDs: {actual_ids}"
        )

    def test_model_supports_score_template(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template (no score_template arg)."""
        # Patch model-class resolution so get_score_prompt sees the mock that
        # supports score templates.
        with patch(
            "vllm.model_executor.model_loader.get_model_cls",
            return_value=mock_model_with_score_template,
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query text",
                "document text",
            )

        # The mock's get_score_template formats the prompt as [QUERY]...[SEP][DOC]...
        assert full_prompt == "[QUERY]query text[SEP][DOC]document text"
        assert "prompt_token_ids" in engine_prompt
        assert len(engine_prompt["prompt_token_ids"]) > 0
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_model_supports_score_template_but_custom_template_provided(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template but custom template is provided."""
        template = (
            'TEMPLATE_USED {{ messages[0]["content"] }} {{ messages[1]["content"] }}'
        )
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
                score_template=template,  # Providing a template
            )

        assert "prompt_token_ids" in engine_prompt
        # The explicit template wins over the model's built-in one.
        assert full_prompt == "TEMPLATE_USED query doc"

        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_not_using_default_template(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        # FIXME: Models implementing SupportsScoreTemplate must use their custom
        # template implementation by default to preserve existing functionality.
        # Attempting to use tokenizer_config.json templates would most likely break
        # these models, as often they just inherit the template from the original LLM.
        # CLI --chat-template overrides are still supported.
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                return_value="test querytest doc",
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "test query",
                "test doc",
            )

        # The stubbed apply_hf_chat_template's return value is used verbatim.
        assert full_prompt == "test querytest doc"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_fallback_with_pad_token(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        """Test fallback path when ChatTemplateResolutionError
        and use_pad_token=True."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,  # use_pad_token=True
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        assert "prompt_token_ids" in engine_prompt
        # Should have token_type_ids from text_pair encoding
        assert "token_type_ids" in engine_prompt
        assert "query" in full_prompt
        assert "document" in full_prompt
        # With a pad token, query and document are not simply concatenated.
        assert full_prompt != "querydocument"
        assert (
            engine_prompt["prompt_token_ids"]
            == cross_encoder_tokenizer(
                "query", text_pair="document", add_special_tokens=True
            )["input_ids"]
        )

        # FIXME(?): add_special_tokens=False is needed because in this case
        # full_prompt is obtained by decoding the tokenized prompt, which includes
        # special tokens and we would get duplicated special tokens otherwise.
        # This is inconsistent with other cases.
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer,
            full_prompt,
            engine_prompt,
            add_special_tokens=False,
        )

    def test_fallback_without_pad_token(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        """Test fallback path when ChatTemplateResolutionError
        and use_pad_token=False."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,  # use_pad_token=False
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        # Without a pad token the two inputs are concatenated directly.
        assert full_prompt == "querydocument"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_post_process_tokens_called(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test that post_process_tokens is called on the engine prompt."""
        # Reset the call tracker
        mock_model_with_score_template.post_process_called.clear()

        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
            patch(
                "vllm.entrypoints.score_utils.apply_hf_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
            )

        # post_process_tokens should have been called once
        assert len(mock_model_with_score_template.post_process_called) == 1
        # It must receive the very same engine prompt object (identity check).
        assert mock_model_with_score_template.post_process_called[0] is engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mteb
|
||||
import numpy as np
|
||||
@ -19,6 +20,11 @@ from tests.models.utils import (
|
||||
get_vllm_extra_kwargs,
|
||||
)
|
||||
|
||||
template_home = (
|
||||
Path(__file__).parent.parent.parent.parent.parent
|
||||
/ "examples/pooling/score/template"
|
||||
)
|
||||
|
||||
# Most embedding models on the STS12 task (See #17175):
|
||||
# - Model implementation and minor changes in tensor dtype
|
||||
# results in differences less than 1e-4
|
||||
@ -102,30 +108,6 @@ class VllmMtebEncoder(mteb.EncoderProtocol):
|
||||
return sim
|
||||
|
||||
|
||||
class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
|
||||
mteb_model_meta = _empty_model_meta
|
||||
|
||||
def __init__(self, vllm_model):
|
||||
self.llm = vllm_model
|
||||
self.rng = np.random.default_rng(seed=42)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
inputs1: DataLoader[mteb.types.BatchedInput],
|
||||
inputs2: DataLoader[mteb.types.BatchedInput],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
queries = [text for batch in inputs1 for text in batch["text"]]
|
||||
corpus = [text for batch in inputs2 for text in batch["text"]]
|
||||
|
||||
outputs = self.llm.score(
|
||||
queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
|
||||
)
|
||||
scores = np.array(outputs)
|
||||
return scores
|
||||
|
||||
|
||||
class OpenAIClientMtebEncoder(VllmMtebEncoder):
|
||||
def __init__(self, model_name: str, client):
|
||||
self.model_name = model_name
|
||||
@ -153,6 +135,35 @@ class OpenAIClientMtebEncoder(VllmMtebEncoder):
|
||||
return embeds
|
||||
|
||||
|
||||
class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
|
||||
mteb_model_meta = _empty_model_meta
|
||||
|
||||
def __init__(self, vllm_model):
|
||||
self.llm = vllm_model
|
||||
self.rng = np.random.default_rng(seed=42)
|
||||
self.chat_template: str | None = getattr(vllm_model, "chat_template", None)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
inputs1: DataLoader[mteb.types.BatchedInput],
|
||||
inputs2: DataLoader[mteb.types.BatchedInput],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
queries = [text for batch in inputs1 for text in batch["text"]]
|
||||
corpus = [text for batch in inputs2 for text in batch["text"]]
|
||||
|
||||
outputs = self.llm.score(
|
||||
queries,
|
||||
corpus,
|
||||
truncate_prompt_tokens=-1,
|
||||
use_tqdm=False,
|
||||
chat_template=self.chat_template,
|
||||
)
|
||||
scores = np.array(outputs)
|
||||
return scores
|
||||
|
||||
|
||||
class ScoreClientMtebEncoder(mteb.CrossEncoderProtocol):
|
||||
mteb_model_meta = _empty_model_meta
|
||||
|
||||
@ -387,6 +398,11 @@ def mteb_test_rerank_models(
|
||||
== model_info.default_pooling_type
|
||||
)
|
||||
|
||||
chat_template: str | None = None
|
||||
if model_info.chat_template_name is not None:
|
||||
chat_template = (template_home / model_info.chat_template_name).read_text()
|
||||
vllm_model.chat_template = chat_template
|
||||
|
||||
vllm_main_score = run_mteb_rerank(
|
||||
vllm_mteb_encoder(vllm_model),
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
|
||||
42
tests/models/language/pooling_mteb_test/test_nemotron.py
Normal file
42
tests/models/language/pooling_mteb_test/test_nemotron.py
Normal file
@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.utils import (
|
||||
EmbedModelInfo,
|
||||
LASTPoolingEmbedModelInfo,
|
||||
LASTPoolingRerankModelInfo,
|
||||
RerankModelInfo,
|
||||
)
|
||||
|
||||
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
||||
|
||||
# Embedding models under test; mteb_score is presumably the expected MTEB
# baseline for regression checking — confirm against mteb_utils.
EMBEDDING_MODELS = [
    LASTPoolingEmbedModelInfo(
        "nvidia/llama-nemotron-embed-1b-v2",
        architecture="LlamaBidirectionalModel",
        mteb_score=0.689164662128673,
    )
]

# Rerank models under test; chat_template_name names a template file that the
# shared mteb test utilities resolve and pass to scoring.
RERANK_MODELS = [
    LASTPoolingRerankModelInfo(
        "nvidia/llama-nemotron-rerank-1b-v2",
        architecture="LlamaBidirectionalForSequenceClassification",
        chat_template_name="nemotron-rerank.jinja",
        mteb_score=0.33994,
    ),
]
|
||||
|
||||
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
    # Delegate to the shared MTEB embedding test harness for each model.
    mteb_test_embed_models(hf_runner, vllm_runner, model_info)
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
    hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
    # Delegate to the shared MTEB rerank test harness for each model.
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
@ -488,6 +488,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
),
|
||||
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
|
||||
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
||||
"LlamaBidirectionalModel": _HfExamplesInfo(
|
||||
"nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True
|
||||
),
|
||||
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
||||
"ModernBertModel": _HfExamplesInfo(
|
||||
"Alibaba-NLP/gte-modernbert-base", trust_remote_code=True
|
||||
@ -554,6 +557,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
|
||||
),
|
||||
"LlamaBidirectionalForSequenceClassification": _HfExamplesInfo(
|
||||
"nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True
|
||||
),
|
||||
"ModernBertForSequenceClassification": _HfExamplesInfo(
|
||||
"Alibaba-NLP/gte-reranker-modernbert-base"
|
||||
),
|
||||
|
||||
@ -399,6 +399,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo):
|
||||
@dataclass
|
||||
class RerankModelInfo(ModelInfo):
|
||||
mteb_score: float | None = None
|
||||
chat_template_name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -11,7 +11,6 @@ import torch
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
@ -29,6 +28,7 @@ from vllm.transformers_utils.config import (
|
||||
get_pooling_config,
|
||||
get_sentence_transformer_tokenizer_config,
|
||||
is_encoder_decoder,
|
||||
is_rope_parameters_nested,
|
||||
try_get_dense_modules,
|
||||
try_get_generation_config,
|
||||
try_get_safetensors_metadata,
|
||||
@ -2125,9 +2125,7 @@ def _get_and_verify_max_len(
|
||||
# In Transformers v5 rope_parameters could be TypedDict or dict[str, TypedDict].
|
||||
# To simplify the verification, we convert it to dict[str, TypedDict].
|
||||
rope_parameters = getattr(hf_config, "rope_parameters", None)
|
||||
if rope_parameters and not set(rope_parameters.keys()).issubset(
|
||||
ALLOWED_LAYER_TYPES
|
||||
):
|
||||
if rope_parameters and not is_rope_parameters_nested(rope_parameters):
|
||||
rope_parameters = {"": rope_parameters}
|
||||
|
||||
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
|
||||
|
||||
@ -67,6 +67,15 @@ else:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ChatTemplateResolutionError(ValueError):
    """Signals that no usable chat template could be resolved.

    Derives from ValueError so pre-existing ``except ValueError`` handlers
    keep catching this failure mode (backward compatibility).
    """
|
||||
|
||||
MODALITY_PLACEHOLDERS_MAP = {
|
||||
"image": "<##IMAGE##>",
|
||||
"audio": "<##AUDIO##>",
|
||||
@ -1814,7 +1823,7 @@ def apply_hf_chat_template(
|
||||
)
|
||||
|
||||
if hf_chat_template is None:
|
||||
raise ValueError(
|
||||
raise ChatTemplateResolutionError(
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
|
||||
@ -1280,6 +1280,7 @@ class LLM:
|
||||
pooling_params: PoolingParams | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
score_template: str | None = None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
model_config = self.model_config
|
||||
|
||||
@ -1313,6 +1314,7 @@ class LLM:
|
||||
data_2=d,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
score_template=score_template,
|
||||
)
|
||||
|
||||
if token_type_ids := engine_prompt.pop("token_type_ids", None):
|
||||
@ -1347,6 +1349,7 @@ class LLM:
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
pooling_params: PoolingParams | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
chat_template: str | None = None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
"""Generate similarity scores for all pairs `<text,text_pair>` or
|
||||
`<multi-modal data, multi-modal data pair>`.
|
||||
@ -1379,6 +1382,8 @@ class LLM:
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
chat_template: The chat template to use for the scoring. If None, we
|
||||
use the model's default chat template.
|
||||
Returns:
|
||||
A list of `ScoringRequestOutput` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
@ -1406,6 +1411,11 @@ class LLM:
|
||||
):
|
||||
raise ValueError("Score API is only enabled for num_labels == 1.")
|
||||
|
||||
if not model_config.is_cross_encoder and chat_template is not None:
|
||||
raise ValueError(
|
||||
"chat_template is only supported for cross-encoder models."
|
||||
)
|
||||
|
||||
# the tokenizer for models such as
|
||||
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
|
||||
# lists of tokens to the `text` and `text_pair` kwargs
|
||||
@ -1475,6 +1485,7 @@ class LLM:
|
||||
use_tqdm,
|
||||
pooling_params,
|
||||
lora_request,
|
||||
score_template=chat_template,
|
||||
)
|
||||
else:
|
||||
return self._embedding_score(
|
||||
|
||||
@ -909,6 +909,16 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_: Request, exc: RequestValidationError):
|
||||
from vllm.entrypoints.openai.protocol import VLLMValidationError
|
||||
|
||||
param = None
|
||||
for error in exc.errors():
|
||||
if "ctx" in error and "error" in error["ctx"]:
|
||||
ctx_error = error["ctx"]["error"]
|
||||
if isinstance(ctx_error, VLLMValidationError):
|
||||
param = ctx_error.parameter
|
||||
break
|
||||
|
||||
exc_str = str(exc)
|
||||
errors_str = str(exc.errors())
|
||||
|
||||
@ -922,6 +932,7 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
message=message,
|
||||
type=HTTPStatus.BAD_REQUEST.phrase,
|
||||
code=HTTPStatus.BAD_REQUEST,
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
|
||||
@ -1145,6 +1156,7 @@ async def init_app_state(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
score_template=resolved_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if ("embed" in supported_tasks or "score" in supported_tasks)
|
||||
|
||||
@ -131,6 +131,36 @@ class ErrorResponse(OpenAIBaseModel):
|
||||
error: ErrorInfo
|
||||
|
||||
|
||||
class VLLMValidationError(ValueError):
|
||||
"""vLLM-specific validation error for request validation failures.
|
||||
|
||||
Args:
|
||||
message: The error message describing the validation failure.
|
||||
parameter: Optional parameter name that failed validation.
|
||||
value: Optional value that was rejected during validation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
parameter: str | None = None,
|
||||
value: Any = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.parameter = parameter
|
||||
self.value = value
|
||||
|
||||
def __str__(self):
|
||||
base = super().__str__()
|
||||
extras = []
|
||||
if self.parameter is not None:
|
||||
extras.append(f"parameter={self.parameter}")
|
||||
if self.value is not None:
|
||||
extras.append(f"value={self.value}")
|
||||
return f"{base} ({', '.join(extras)})" if extras else base
|
||||
|
||||
|
||||
class ModelPermission(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
|
||||
object: str = "model_permission"
|
||||
@ -466,7 +496,9 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
@model_validator(mode="before")
|
||||
def validate_prompt(cls, data):
|
||||
if data.get("prompt") is not None:
|
||||
raise ValueError("prompt template is not supported")
|
||||
raise VLLMValidationError(
|
||||
"prompt template is not supported", parameter="prompt"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@ -850,7 +882,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise ValueError("Stream options can only be defined when `stream=True`.")
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter="stream_options",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@ -859,19 +894,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
def check_logprobs(cls, data):
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
|
||||
raise ValueError(
|
||||
"`prompt_logprobs` are not available when `stream=True`."
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.",
|
||||
parameter="prompt_logprobs",
|
||||
)
|
||||
|
||||
if prompt_logprobs < 0 and prompt_logprobs != -1:
|
||||
raise ValueError("`prompt_logprobs` must be a positive value or -1.")
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` must be a positive value or -1.",
|
||||
parameter="prompt_logprobs",
|
||||
value=prompt_logprobs,
|
||||
)
|
||||
if (top_logprobs := data.get("top_logprobs")) is not None:
|
||||
if top_logprobs < 0 and top_logprobs != -1:
|
||||
raise ValueError("`top_logprobs` must be a positive value or -1.")
|
||||
raise VLLMValidationError(
|
||||
"`top_logprobs` must be a positive value or -1.",
|
||||
parameter="top_logprobs",
|
||||
value=top_logprobs,
|
||||
)
|
||||
|
||||
if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
|
||||
raise ValueError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true."
|
||||
raise VLLMValidationError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true.",
|
||||
parameter="top_logprobs",
|
||||
)
|
||||
|
||||
return data
|
||||
@ -1285,9 +1330,10 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
for k in ("json", "regex", "choice")
|
||||
)
|
||||
if count > 1:
|
||||
raise ValueError(
|
||||
raise VLLMValidationError(
|
||||
"You can only use one kind of constraints for structured "
|
||||
"outputs ('json', 'regex' or 'choice')."
|
||||
"outputs ('json', 'regex' or 'choice').",
|
||||
parameter="structured_outputs",
|
||||
)
|
||||
return data
|
||||
|
||||
@ -1296,14 +1342,23 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
def check_logprobs(cls, data):
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
|
||||
raise ValueError(
|
||||
"`prompt_logprobs` are not available when `stream=True`."
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.",
|
||||
parameter="prompt_logprobs",
|
||||
)
|
||||
|
||||
if prompt_logprobs < 0 and prompt_logprobs != -1:
|
||||
raise ValueError("`prompt_logprobs` must be a positive value or -1.")
|
||||
raise VLLMValidationError(
|
||||
"`prompt_logprobs` must be a positive value or -1.",
|
||||
parameter="prompt_logprobs",
|
||||
value=prompt_logprobs,
|
||||
)
|
||||
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
|
||||
raise ValueError("`logprobs` must be a positive value.")
|
||||
raise VLLMValidationError(
|
||||
"`logprobs` must be a positive value.",
|
||||
parameter="logprobs",
|
||||
value=logprobs,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@ -1311,7 +1366,10 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise ValueError("Stream options can only be defined when `stream=True`.")
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter="stream_options",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@ -2138,7 +2196,15 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
raise ValueError("Stream options can only be defined when `stream=True`.")
|
||||
# Find which specific stream option was set
|
||||
invalid_param = next(
|
||||
(so for so in stream_opts if data.get(so, False)),
|
||||
"stream_include_usage",
|
||||
)
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter=invalid_param,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@ -2351,7 +2417,15 @@ class TranslationRequest(OpenAIBaseModel):
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
raise ValueError("Stream options can only be defined when `stream=True`.")
|
||||
# Find which specific stream option was set
|
||||
invalid_param = next(
|
||||
(so for so in stream_opts if data.get(so, False)),
|
||||
"stream_include_usage",
|
||||
)
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter=invalid_param,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@ -495,6 +495,7 @@ async def run_batch(
|
||||
engine_client,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
score_template=None,
|
||||
)
|
||||
if ("embed" in supported_tasks or enable_serving_reranking)
|
||||
else None
|
||||
|
||||
@ -417,8 +417,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
assert len(generators) == 1
|
||||
(result_generator,) = generators
|
||||
@ -448,8 +447,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
except GenerationError as e:
|
||||
return self._convert_generation_error_to_response(e)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
|
||||
if request.add_generation_prompt:
|
||||
@ -682,7 +680,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tool_parsers = [None] * num_choices
|
||||
except Exception as e:
|
||||
logger.exception("Error in tool parser creation.")
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
data = self.create_streaming_error_response(e)
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
@ -1328,9 +1326,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
except GenerationError as e:
|
||||
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in chat completion stream generator.")
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
data = self.create_streaming_error_response(e)
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
@ -1354,8 +1351,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
assert final_res is not None
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo,
|
||||
VLLMValidationError,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
GenerationError,
|
||||
@ -247,8 +248,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
@ -308,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
except GenerationError as e:
|
||||
return self._convert_generation_error_to_response(e)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
@ -510,9 +509,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
except GenerationError as e:
|
||||
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in completion stream generator.")
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
data = self.create_streaming_error_response(e)
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@ -660,8 +658,11 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
token = f"token_id:{token_id}"
|
||||
else:
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
raise VLLMValidationError(
|
||||
"Unable to get tokenizer because "
|
||||
"`skip_tokenizer_init=True`",
|
||||
parameter="skip_tokenizer_init",
|
||||
value=True,
|
||||
)
|
||||
|
||||
token = tokenizer.decode(token_id)
|
||||
@ -720,6 +721,15 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request: CompletionRequest,
|
||||
max_input_length: int | None = None,
|
||||
) -> RenderConfig:
|
||||
# Validate max_tokens before using it
|
||||
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
|
||||
raise VLLMValidationError(
|
||||
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
|
||||
f"the model's maximum context length ({self.max_model_len}).",
|
||||
parameter="max_tokens",
|
||||
value=request.max_tokens,
|
||||
)
|
||||
|
||||
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
|
||||
return RenderConfig(
|
||||
max_length=max_input_tokens_len,
|
||||
|
||||
@ -57,6 +57,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranslationRequest,
|
||||
VLLMValidationError,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
@ -322,8 +323,10 @@ class OpenAIServing:
|
||||
input_processor = self.input_processor
|
||||
tokenizer = input_processor.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"You cannot use beam search when `skip_tokenizer_init=True`"
|
||||
raise VLLMValidationError(
|
||||
"You cannot use beam search when `skip_tokenizer_init=True`",
|
||||
parameter="skip_tokenizer_init",
|
||||
value=True,
|
||||
)
|
||||
|
||||
eos_token_id: int = tokenizer.eos_token_id # type: ignore
|
||||
@ -706,8 +709,7 @@ class OpenAIServing:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
async def _collect_batch(
|
||||
self,
|
||||
@ -738,14 +740,43 @@ class OpenAIServing:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
message: str | Exception,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
param: str | None = None,
|
||||
) -> ErrorResponse:
|
||||
exc: Exception | None = None
|
||||
|
||||
if isinstance(message, Exception):
|
||||
exc = message
|
||||
|
||||
from vllm.entrypoints.openai.protocol import VLLMValidationError
|
||||
|
||||
if isinstance(exc, VLLMValidationError):
|
||||
err_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
param = exc.parameter
|
||||
elif isinstance(exc, (ValueError, TypeError, RuntimeError)):
|
||||
# Common validation errors from user input
|
||||
err_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
param = None
|
||||
elif exc.__class__.__name__ == "TemplateError":
|
||||
# jinja2.TemplateError (avoid importing jinja2)
|
||||
err_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
param = None
|
||||
else:
|
||||
err_type = "InternalServerError"
|
||||
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
param = None
|
||||
|
||||
message = str(exc)
|
||||
|
||||
if self.log_error_stack:
|
||||
exc_type, _, _ = sys.exc_info()
|
||||
if exc_type is not None:
|
||||
@ -753,18 +784,27 @@ class OpenAIServing:
|
||||
else:
|
||||
traceback.print_stack()
|
||||
return ErrorResponse(
|
||||
error=ErrorInfo(message=message, type=err_type, code=status_code.value)
|
||||
error=ErrorInfo(
|
||||
message=message,
|
||||
type=err_type,
|
||||
code=status_code.value,
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
|
||||
def create_streaming_error_response(
|
||||
self,
|
||||
message: str,
|
||||
message: str | Exception,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
param: str | None = None,
|
||||
) -> str:
|
||||
json_str = json.dumps(
|
||||
self.create_error_response(
|
||||
message=message, err_type=err_type, status_code=status_code
|
||||
message=message,
|
||||
err_type=err_type,
|
||||
status_code=status_code,
|
||||
param=param,
|
||||
).model_dump()
|
||||
)
|
||||
return json_str
|
||||
@ -825,6 +865,7 @@ class OpenAIServing:
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
param="model",
|
||||
)
|
||||
|
||||
def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
|
||||
@ -991,11 +1032,13 @@ class OpenAIServing:
|
||||
ClassificationChatRequest: "classification",
|
||||
}
|
||||
operation = operations.get(type(request), "embedding generation")
|
||||
raise ValueError(
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for {operation}. "
|
||||
f"Please reduce the length of the input."
|
||||
f"Please reduce the length of the input.",
|
||||
parameter="input_tokens",
|
||||
value=token_num,
|
||||
)
|
||||
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
@ -1017,20 +1060,24 @@ class OpenAIServing:
|
||||
# Note: input length can be up to model context length - 1 for
|
||||
# completion-like requests.
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, your request has "
|
||||
f"{token_num} input tokens. Please reduce the length of "
|
||||
"the input messages."
|
||||
"the input messages.",
|
||||
parameter="input_tokens",
|
||||
value=token_num,
|
||||
)
|
||||
|
||||
if max_tokens is not None and token_num + max_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
raise VLLMValidationError(
|
||||
"'max_tokens' or 'max_completion_tokens' is too large: "
|
||||
f"{max_tokens}. This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens and your request has "
|
||||
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
|
||||
f" - {token_num})."
|
||||
f" - {token_num}).",
|
||||
parameter="max_tokens",
|
||||
value=max_tokens,
|
||||
)
|
||||
|
||||
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
@ -94,6 +94,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ResponsesResponse,
|
||||
ResponseUsage,
|
||||
StreamingResponsesResponse,
|
||||
VLLMValidationError,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
GenerationError,
|
||||
@ -271,6 +272,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
err_type="invalid_request_error",
|
||||
message=error_message,
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
param="input",
|
||||
)
|
||||
return None
|
||||
|
||||
@ -282,6 +284,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
err_type="invalid_request_error",
|
||||
message="logprobs are not supported with gpt-oss models",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
param="logprobs",
|
||||
)
|
||||
if request.store and not self.enable_store and request.background:
|
||||
return self.create_error_response(
|
||||
@ -294,6 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
"the vLLM server."
|
||||
),
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
param="background",
|
||||
)
|
||||
if request.previous_input_messages and request.previous_response_id:
|
||||
return self.create_error_response(
|
||||
@ -301,6 +305,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
message="Only one of `previous_input_messages` and "
|
||||
"`previous_response_id` can be set.",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
param="previous_response_id",
|
||||
)
|
||||
return None
|
||||
|
||||
@ -457,8 +462,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
)
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
assert len(generators) == 1
|
||||
(result_generator,) = generators
|
||||
@ -546,7 +550,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
except GenerationError as e:
|
||||
return self._convert_generation_error_to_response(e)
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
@ -630,8 +634,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
# NOTE: Implementation of stauts is still WIP, but for now
|
||||
# we guarantee that if the status is not "completed", it is accurate.
|
||||
@ -1074,7 +1077,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
response = self._convert_generation_error_to_response(e)
|
||||
except Exception as e:
|
||||
logger.exception("Background request failed for %s", request.request_id)
|
||||
response = self.create_error_response(str(e))
|
||||
response = self.create_error_response(e)
|
||||
finally:
|
||||
new_event_signal.set()
|
||||
|
||||
@ -1099,7 +1102,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
response = self._convert_generation_error_to_response(e)
|
||||
except Exception as e:
|
||||
logger.exception("Background request failed for %s", request.request_id)
|
||||
response = self.create_error_response(str(e))
|
||||
response = self.create_error_response(e)
|
||||
|
||||
if isinstance(response, ErrorResponse):
|
||||
# If the request has failed, update the status to "failed".
|
||||
@ -1116,7 +1119,11 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
starting_after: int | None = None,
|
||||
) -> AsyncGenerator[StreamingResponsesResponse, None]:
|
||||
if response_id not in self.event_store:
|
||||
raise ValueError(f"Unknown response_id: {response_id}")
|
||||
raise VLLMValidationError(
|
||||
f"Unknown response_id: {response_id}",
|
||||
parameter="response_id",
|
||||
value=response_id,
|
||||
)
|
||||
|
||||
event_deque, new_event_signal = self.event_store[response_id]
|
||||
start_index = 0 if starting_after is None else starting_after + 1
|
||||
@ -1172,6 +1179,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
err_type="invalid_request_error",
|
||||
message="Cannot cancel a synchronous response.",
|
||||
param="response_id",
|
||||
)
|
||||
|
||||
# Update the status to "cancelled".
|
||||
@ -1191,6 +1199,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
err_type="invalid_request_error",
|
||||
message=f"Response with id '{response_id}' not found.",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
param="response_id",
|
||||
)
|
||||
|
||||
def _make_store_not_supported_error(self) -> ErrorResponse:
|
||||
@ -1203,6 +1212,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
"starting the vLLM server."
|
||||
),
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
param="store",
|
||||
)
|
||||
|
||||
async def _process_simple_streaming_events(
|
||||
|
||||
@ -30,6 +30,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
TranslationSegment,
|
||||
TranslationStreamResponse,
|
||||
UsageInfo,
|
||||
VLLMValidationError,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
@ -259,7 +260,11 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
)
|
||||
|
||||
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
raise VLLMValidationError(
|
||||
"Maximum file size exceeded",
|
||||
parameter="audio_filesize_mb",
|
||||
value=len(audio_data) / 1024**2,
|
||||
)
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
# NOTE resample to model SR here for efficiency. This is also a
|
||||
@ -287,12 +292,18 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
)
|
||||
if request.response_format == "verbose_json":
|
||||
if not isinstance(prompt, dict):
|
||||
raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}")
|
||||
raise VLLMValidationError(
|
||||
"Expected prompt to be a dict",
|
||||
parameter="prompt",
|
||||
value=type(prompt).__name__,
|
||||
)
|
||||
prompt_dict = cast(dict, prompt)
|
||||
decoder_prompt = prompt.get("decoder_prompt")
|
||||
if not isinstance(decoder_prompt, str):
|
||||
raise ValueError(
|
||||
f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}"
|
||||
raise VLLMValidationError(
|
||||
"Expected decoder_prompt to be str",
|
||||
parameter="decoder_prompt",
|
||||
value=type(decoder_prompt).__name__,
|
||||
)
|
||||
prompt_dict["decoder_prompt"] = decoder_prompt.replace(
|
||||
"<|notimestamps|>", "<|0.00|>"
|
||||
@ -412,7 +423,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
|
||||
try:
|
||||
@ -448,8 +459,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
for i, prompt in enumerate(prompts)
|
||||
]
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
if request.stream:
|
||||
return stream_generator_method(
|
||||
@ -523,8 +533,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return self.create_error_response(e)
|
||||
|
||||
async def _speech_to_text_stream_generator(
|
||||
self,
|
||||
@ -634,9 +643,8 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in %s stream generator.", self.task_type)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
data = self.create_streaming_error_response(e)
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@ -52,6 +52,7 @@ class ServingScores(OpenAIServing):
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
score_template: str | None = None,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@ -60,6 +61,7 @@ class ServingScores(OpenAIServing):
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
self.score_template = score_template
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
@ -169,6 +171,7 @@ class ServingScores(OpenAIServing):
|
||||
data_2=data_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
score_template=self.score_template,
|
||||
)
|
||||
self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt)
|
||||
if request.mm_processor_kwargs is not None:
|
||||
|
||||
@ -12,6 +12,7 @@ import torch
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.protocol import VLLMValidationError
|
||||
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@ -162,8 +163,9 @@ class BaseRenderer(ABC):
|
||||
) -> list[EmbedsPrompt]:
|
||||
"""Load and validate base64-encoded embeddings into prompt objects."""
|
||||
if not self.model_config.enable_prompt_embeds:
|
||||
raise ValueError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
|
||||
raise VLLMValidationError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
|
||||
parameter="prompt_embeds",
|
||||
)
|
||||
|
||||
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
|
||||
@ -396,10 +398,12 @@ class CompletionRenderer(BaseRenderer):
|
||||
) -> TokensPrompt:
|
||||
"""Create validated TokensPrompt."""
|
||||
if max_length is not None and len(token_ids) > max_length:
|
||||
raise ValueError(
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is {max_length} tokens. "
|
||||
f"However, your request has {len(token_ids)} input tokens. "
|
||||
"Please reduce the length of the input messages."
|
||||
"Please reduce the length of the input messages.",
|
||||
parameter="input_tokens",
|
||||
value=len(token_ids),
|
||||
)
|
||||
|
||||
tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
|
||||
|
||||
@ -11,9 +11,11 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatTemplateResolutionError,
|
||||
MultiModalItemTracker,
|
||||
_ContentPart,
|
||||
_parse_chat_message_content_part,
|
||||
apply_hf_chat_template,
|
||||
)
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.model_executor.models.interfaces import supports_score_template
|
||||
@ -139,10 +141,8 @@ def _parse_score_content(
|
||||
return next(iter(mm_placeholder_storage.values()))[0]
|
||||
|
||||
|
||||
def apply_score_template(
|
||||
model_config: ModelConfig,
|
||||
prompt_1: str,
|
||||
prompt_2: str,
|
||||
def _apply_model_score_template(
|
||||
model_config: ModelConfig, prompt_1: str, prompt_2: str
|
||||
) -> str:
|
||||
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
@ -181,6 +181,7 @@ def get_score_prompt(
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: str | ScoreContentPartParam,
|
||||
data_2: str | ScoreContentPartParam,
|
||||
score_template: str | None = None,
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
prompt_1, prompt_2, mm_data = parse_score_data(
|
||||
data_1,
|
||||
@ -190,19 +191,48 @@ def get_score_prompt(
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
|
||||
model = get_model_cls(model_config)
|
||||
if supports_score_template(model):
|
||||
full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
|
||||
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
|
||||
elif model_config.use_pad_token:
|
||||
# cross_encoder models defaults to using pad_token.
|
||||
prompt_inputs = tokenizer(
|
||||
text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
|
||||
)
|
||||
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
|
||||
|
||||
def default_tokenizer_encode():
|
||||
if supports_score_template(model):
|
||||
full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
|
||||
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
|
||||
else:
|
||||
if model_config.use_pad_token:
|
||||
# cross_encoder models defaults to using pad_token.
|
||||
prompt_inputs = tokenizer(
|
||||
text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
|
||||
)
|
||||
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
|
||||
else:
|
||||
# `llm as reranker` models defaults to not using pad_token.
|
||||
full_prompt = prompt_1 + prompt_2
|
||||
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
|
||||
return full_prompt, prompt_inputs
|
||||
|
||||
# FIXME: For now, we only apply a template when one is explicitly provided.
|
||||
# We cannot rely on the tokenizer's chat template because many models
|
||||
# inherit junk templates from their base LLM, which breaks both the models
|
||||
# and the tests that use them.
|
||||
if score_template is None:
|
||||
full_prompt, prompt_inputs = default_tokenizer_encode()
|
||||
else:
|
||||
# `llm as reranker` models defaults to not using pad_token.
|
||||
full_prompt = prompt_1 + prompt_2
|
||||
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
|
||||
# FIXME: Try applying a score template from the CLI arg or tokenizer_config.json
|
||||
# If that fails because there is no such template,
|
||||
# fall back to the default implementation.
|
||||
try:
|
||||
full_prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
[
|
||||
{"role": "query", "content": prompt_1},
|
||||
{"role": "document", "content": prompt_2},
|
||||
],
|
||||
score_template,
|
||||
tools=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
|
||||
except ChatTemplateResolutionError:
|
||||
full_prompt, prompt_inputs = default_tokenizer_encode()
|
||||
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ from vllm.lora.peft_helper import PEFTHelper
|
||||
from vllm.lora.utils import (
|
||||
get_lora_id,
|
||||
is_base_embeddding_weights,
|
||||
is_regex_target_modules,
|
||||
parse_fine_tuned_lora_name,
|
||||
)
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
@ -201,37 +200,13 @@ class LoRAModel:
|
||||
for module in f.keys(): # noqa
|
||||
tensors[module] = f.get_tensor(module)
|
||||
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
|
||||
# When a bin/pt file is provided, we rely on config to find
|
||||
# unexpected modules.
|
||||
unexpected_modules = []
|
||||
target_modules = peft_helper.target_modules
|
||||
if not isinstance(target_modules, list):
|
||||
target_modules = [target_modules]
|
||||
for module in target_modules:
|
||||
# Compatible with more modules,
|
||||
# such as:layers.11.self_attn.k_proj
|
||||
part_name = module.split(".")[-1]
|
||||
if part_name not in expected_lora_modules:
|
||||
unexpected_modules.append(module)
|
||||
# loaded lora's target modules must be a subset of
|
||||
# expected_lora_modules. It is not reliable. See
|
||||
# https://github.com/vllm-project/vllm/pull/5909. But there's no
|
||||
# other better mechanism.
|
||||
if unexpected_modules and not is_regex_target_modules(
|
||||
peft_helper.target_modules, expected_lora_modules
|
||||
):
|
||||
raise ValueError(
|
||||
f"While loading {lora_dir}, expected"
|
||||
f" target modules in {expected_lora_modules}"
|
||||
f" but received {unexpected_modules}."
|
||||
f" Please verify that the loaded LoRA module is correct"
|
||||
)
|
||||
lora_file_path = (
|
||||
lora_bin_file_path
|
||||
if os.path.isfile(lora_bin_file_path)
|
||||
else lora_pt_file_path
|
||||
)
|
||||
tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
|
||||
check_unexpected_modules(tensors)
|
||||
else:
|
||||
raise ValueError(f"{lora_dir} doesn't contain tensors")
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import huggingface_hub
|
||||
import regex as re
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
HfHubHTTPError,
|
||||
@ -186,39 +185,6 @@ def is_base_embeddding_weights(name: str) -> bool:
|
||||
return name.endswith(embedding_suffixes)
|
||||
|
||||
|
||||
def is_regex_target_modules(
|
||||
load_modules: str | list[str], expected_lora_modules: set[str]
|
||||
) -> bool:
|
||||
"""
|
||||
PEFT supports passing `target_modules` in the form of regular expressions,
|
||||
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
|
||||
determine whether the suffix in the regular expression is present in the
|
||||
`expected_lora_modules`.
|
||||
"""
|
||||
|
||||
def is_valid_regex(pattern):
|
||||
try:
|
||||
re.compile(pattern)
|
||||
return True
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
def is_subset(sub_list, full_set):
|
||||
return set(sub_list).issubset(full_set)
|
||||
|
||||
# Similar to PEFT's processing logic, regex-related operations are only
|
||||
# executed when the load_modules is a `str`.
|
||||
if not isinstance(load_modules, str):
|
||||
return False
|
||||
|
||||
if is_valid_regex(load_modules):
|
||||
match = re.search(r"\((.*?)\)\$?$", load_modules)
|
||||
if match:
|
||||
suffix = match.group(1).split("|")
|
||||
return is_subset(suffix, expected_lora_modules)
|
||||
return False
|
||||
|
||||
|
||||
def get_supported_lora_modules(model: nn.Module) -> list[str]:
|
||||
"""
|
||||
In vLLM, all linear layers support LoRA.
|
||||
|
||||
@ -88,6 +88,26 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
}
|
||||
|
||||
|
||||
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
hf_config.is_causal = False
|
||||
|
||||
pooling_type_map: dict[str, PoolingTypeStr] = {
|
||||
"avg": "MEAN",
|
||||
"cls": "CLS",
|
||||
"last": "LAST",
|
||||
}
|
||||
|
||||
pooling_type = pooling_type_map.get(hf_config.pooling, None)
|
||||
if pooling_type is None:
|
||||
raise ValueError(f"pool_type {hf_config.pooling} not supported")
|
||||
vllm_config.model_config.pooler_config.pooling_type = pooling_type
|
||||
|
||||
|
||||
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
@ -509,6 +529,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||
"Gemma3TextModel": Gemma3TextModelConfig,
|
||||
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
|
||||
"LlamaBidirectionalModel": LlamaBidirectionalConfig,
|
||||
"NomicBertModel": NomicBertModelConfig,
|
||||
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
|
||||
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
|
||||
|
||||
@ -57,7 +57,14 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .adapters import as_embedding_model, as_seq_cls_model
|
||||
from .interfaces import (
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .interfaces_base import attn_type
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@ -698,3 +705,17 @@ class LlamaForCausalLM(
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
|
||||
# This class sets the correct attention type and pooling type
|
||||
# through LlamaBidirectionalConfig.
|
||||
pass
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
|
||||
# This class sets the correct attention type and pooling type
|
||||
# through LlamaBidirectionalConfig.
|
||||
pass
|
||||
|
||||
@ -203,6 +203,7 @@ _EMBEDDING_MODELS = {
|
||||
"GteNewModel": ("bert_with_rope", "GteNewModel"),
|
||||
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
**{
|
||||
# Multiple models share the same architecture, so we include them all
|
||||
@ -246,6 +247,11 @@ _CROSS_ENCODER_MODELS = {
|
||||
"bert_with_rope",
|
||||
"GteNewForSequenceClassification",
|
||||
),
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
|
||||
"LlamaBidirectionalForSequenceClassification": (
|
||||
"llama",
|
||||
"LlamaBidirectionalForSequenceClassification",
|
||||
),
|
||||
"ModernBertForSequenceClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForSequenceClassification",
|
||||
@ -259,8 +265,6 @@ _CROSS_ENCODER_MODELS = {
|
||||
"roberta",
|
||||
"RobertaForSequenceClassification",
|
||||
),
|
||||
# [Auto-converted (see adapters.py)]
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
|
||||
@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
@ -32,6 +31,7 @@ from vllm.model_executor.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.transformers_utils.config import is_rope_parameters_nested
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
@ -207,7 +207,7 @@ def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
|
||||
rope_parameters: dict | None = getattr(text_config, "rope_parameters", None) or {}
|
||||
if rope_parameters:
|
||||
# Nest rope_parameters if not nested already to simplify logic
|
||||
if not set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
|
||||
if not is_rope_parameters_nested(rope_parameters):
|
||||
rope_parameters = {"": rope_parameters}
|
||||
return all(rp["rope_type"] != "dynamic" for rp in rope_parameters.values())
|
||||
return True
|
||||
|
||||
@ -15,7 +15,6 @@ from huggingface_hub import (
|
||||
)
|
||||
from packaging.version import Version
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||
from transformers.models.auto.image_processing_auto import get_image_processor_config
|
||||
from transformers.models.auto.modeling_auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
@ -44,6 +43,16 @@ from .repo_utils import (
|
||||
with_retry,
|
||||
)
|
||||
|
||||
try:
|
||||
# Transformers v5
|
||||
from transformers.configuration_utils import ALLOWED_ATTENTION_LAYER_TYPES
|
||||
except ImportError:
|
||||
# Transformers v4
|
||||
from transformers.configuration_utils import (
|
||||
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
|
||||
)
|
||||
|
||||
|
||||
if envs.VLLM_USE_MODELSCOPE:
|
||||
from modelscope import AutoConfig
|
||||
else:
|
||||
@ -104,6 +113,14 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
|
||||
}
|
||||
|
||||
|
||||
def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
|
||||
"""Check if rope_parameters is nested by layer types."""
|
||||
# Cannot be nested if rope_parameters is empty
|
||||
if not rope_parameters:
|
||||
return False
|
||||
return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES)
|
||||
|
||||
|
||||
class HFConfigParser(ConfigParserBase):
|
||||
def parse(
|
||||
self,
|
||||
@ -346,7 +363,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
|
||||
config.rope_parameters["original_max_position_embeddings"] = ompe
|
||||
|
||||
# Handle nested rope_parameters in interleaved sliding attention models
|
||||
if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
|
||||
if is_rope_parameters_nested(config.rope_parameters):
|
||||
for rope_parameters_layer_type in config.rope_parameters.values():
|
||||
patch_rope_parameters_dict(rope_parameters_layer_type)
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user