mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 12:54:28 +08:00
Add filtering for chat template kwargs (#25794)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
3f5d902d2a
commit
7977e5027c
@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
|
|||||||
parse_chat_messages,
|
parse_chat_messages,
|
||||||
parse_chat_messages_futures,
|
parse_chat_messages_futures,
|
||||||
resolve_chat_template_content_format,
|
resolve_chat_template_content_format,
|
||||||
|
resolve_chat_template_kwargs,
|
||||||
resolve_hf_chat_template)
|
resolve_hf_chat_template)
|
||||||
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
|
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
|
||||||
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
||||||
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
|
|||||||
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
|
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
|
||||||
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
|
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
|
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
|
||||||
|
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
|
||||||
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
|
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
|
||||||
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||||
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
|
|||||||
assert isinstance(chat_template, str)
|
assert isinstance(chat_template, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model, expected_kwargs",
    [
        (
            QWEN2VL_MODEL_ID,
            {
                "add_vision_id", "add_generation_prompt",
                "continue_final_message", "tools"
            },
        ),
        (
            QWEN3_MODEL_ID,
            {
                "enable_thinking", "add_generation_prompt",
                "continue_final_message", "tools"
            },
        ),
    ],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
                                         expected_kwargs):
    """Check which chat template kwargs survive filtering for HF models.

    A fixed dictionary of candidate kwargs is passed through
    `resolve_chat_template_kwargs` together with the chat template resolved
    for `model`. The keys of the result must equal `expected_kwargs`, which
    means the unknown keys and the `chat_template` override are dropped.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")

    tools = [{
        "type": "function",
        "function": {
            "name": "dummy_function_name",
            "description": "This is a dummy function",
            "parameters": sample_json_schema,
        },
    }]

    chat_template_kwargs = {
        # both unused
        "unsed_kwargs_1": 123,
        "unsed_kwargs_2": "abc",
        # should not appear
        "chat_template": "{% Hello world! %}",
        # used by tokenizer
        "continue_final_message": True,
        "tools": tools,
        # both used by Qwen2-VL and Qwen3
        "add_generation_prompt": True,
        # only used by Qwen2-VL
        "add_vision_id": True,
        # only used by Qwen3
        "enable_thinking": True,
    }

    model_config = ModelConfig(
        model,
        tokenizer=model_info.tokenizer or model,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
        skip_tokenizer_init=model_info.skip_tokenizer_init,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype)

    # Build the tokenizer
    tokenizer = get_tokenizer(
        model,
        trust_remote_code=model_config.trust_remote_code,
    )

    # Resolve the chat template that the kwargs will be checked against
    chat_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=None,
        tools=tools,
        model_config=model_config,
    )
    resolved_chat_template_kwargs = resolve_chat_template_kwargs(
        tokenizer,
        chat_template=chat_template,
        chat_template_kwargs=chat_template_kwargs,
    )
    assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
|
||||||
|
|
||||||
|
|
||||||
# NOTE: Qwen2-Audio default chat template is specially defined inside
|
# NOTE: Qwen2-Audio default chat template is specially defined inside
|
||||||
# processor class instead of using `tokenizer_config.json`
|
# processor class instead of using `tokenizer_config.json`
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
|
|||||||
@ -11,7 +11,12 @@ from pathlib import Path
|
|||||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||||
cast)
|
cast)
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
import jinja2.ext
|
||||||
|
import jinja2.meta
|
||||||
import jinja2.nodes
|
import jinja2.nodes
|
||||||
|
import jinja2.parser
|
||||||
|
import jinja2.sandbox
|
||||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid, supports_kw
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -1554,6 +1559,46 @@ def parse_chat_messages_futures(
|
|||||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||||
|
|
||||||
|
|
||||||
|
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
|
||||||
|
# only preserve the parse function used to resolve chat template kwargs
|
||||||
|
class AssistantTracker(jinja2.ext.Extension):
    """Jinja2 extension that recognises the `generation` block tag.

    Registering this extension lets an environment parse templates that
    contain `{% generation %} ... {% endgeneration %}`. Only `parse` is
    implemented in this class.
    """

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Turn a `generation` block into a `CallBlock` node.

        Args:
            parser: The Jinja2 parser, positioned on the opening tag token.

        Returns:
            A `CallBlock` wrapping the statements up to `endgeneration`,
            with its line number set to that of the opening tag.
        """
        # The first token in the stream is the tag itself; keep its line.
        opening_line = next(parser.stream).lineno
        # Collect everything up to the closing tag and consume that tag too.
        block_body = parser.parse_statements(["name:endgeneration"],
                                             drop_needle=True)
        node = jinja2.nodes.CallBlock(
            self.call_method("_generation_support"), [], [], block_body)
        return node.set_lineno(opening_line)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_chat_template_kwargs(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Keep only the chat template kwargs that will actually be consumed.

    A key is kept when `supports_kw` reports it as accepted by
    `tokenizer.apply_chat_template` (checked with `allow_var_kwargs=False`),
    or when it is an undeclared variable of `chat_template` according to
    `jinja2.meta.find_undeclared_variables`. The key `"chat_template"` is
    always dropped.

    Args:
        tokenizer: Tokenizer whose `apply_chat_template` signature is
            inspected.
        chat_template: Source text of the already-resolved Jinja template.
        chat_template_kwargs: Candidate keyword arguments to filter.

    Returns:
        A new dict holding the accepted entries of `chat_template_kwargs`;
        the input dict is not modified.
    """
    # With no candidates the result is necessarily empty, so skip the
    # signature inspection and the Jinja parse of the template.
    if not chat_template_kwargs:
        return {}

    fn_kw = {
        k
        for k in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, k,
                       allow_var_kwargs=False)
    }

    # The template is only parsed here, never rendered.
    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    parsed_content = env.parse(chat_template)
    template_vars = jinja2.meta.find_undeclared_variables(parsed_content)

    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    unexpected_vars = {"chat_template"}
    accept_vars = (fn_kw | template_vars) - unexpected_vars
    return {
        k: v
        for k, v in chat_template_kwargs.items() if k in accept_vars
    }
|
||||||
|
|
||||||
|
|
||||||
def apply_hf_chat_template(
|
def apply_hf_chat_template(
|
||||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
conversation: list[ConversationMessage],
|
conversation: list[ConversationMessage],
|
||||||
@ -1579,12 +1624,17 @@ def apply_hf_chat_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
resolved_kwargs = resolve_chat_template_kwargs(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
chat_template=hf_chat_template,
|
||||||
|
chat_template_kwargs=kwargs,
|
||||||
|
)
|
||||||
return tokenizer.apply_chat_template(
|
return tokenizer.apply_chat_template(
|
||||||
conversation=conversation, # type: ignore[arg-type]
|
conversation=conversation, # type: ignore[arg-type]
|
||||||
tools=tools, # type: ignore[arg-type]
|
tools=tools, # type: ignore[arg-type]
|
||||||
chat_template=hf_chat_template,
|
chat_template=hf_chat_template,
|
||||||
tokenize=tokenize,
|
tokenize=tokenize,
|
||||||
**kwargs,
|
**resolved_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# External library exceptions can sometimes occur despite the framework's
|
# External library exceptions can sometimes occur despite the framework's
|
||||||
|
|||||||
@ -1716,6 +1716,7 @@ async def init_app_state(
|
|||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
chat_template=resolved_chat_template,
|
chat_template=resolved_chat_template,
|
||||||
chat_template_content_format=args.chat_template_content_format,
|
chat_template_content_format=args.chat_template_content_format,
|
||||||
|
trust_request_chat_template=args.trust_request_chat_template,
|
||||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
enable_auto_tools=args.enable_auto_tool_choice,
|
enable_auto_tools=args.enable_auto_tool_choice,
|
||||||
exclude_tools_when_tool_choice_none=args.
|
exclude_tools_when_tool_choice_none=args.
|
||||||
|
|||||||
@ -103,9 +103,13 @@ class FrontendArgs:
|
|||||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||||
"""The format to render message content within a chat template.
|
"""The format to render message content within a chat template.
|
||||||
|
|
||||||
* "string" will render the content as a string. Example: `"Hello World"`
|
* "string" will render the content as a string. Example: `"Hello World"`
|
||||||
* "openai" will render the content as a list of dictionaries, similar to OpenAI
|
* "openai" will render the content as a list of dictionaries, similar to
|
||||||
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||||
|
trust_request_chat_template: bool = False
|
||||||
|
"""Whether to trust the chat template provided in the request. If False,
|
||||||
|
the server will always use the chat template specified by `--chat-template`
|
||||||
|
or the ones from tokenizer."""
|
||||||
response_role: str = "assistant"
|
response_role: str = "assistant"
|
||||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||||
ssl_keyfile: Optional[str] = None
|
ssl_keyfile: Optional[str] = None
|
||||||
|
|||||||
@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
request_logger: Optional[RequestLogger],
|
request_logger: Optional[RequestLogger],
|
||||||
chat_template: Optional[str],
|
chat_template: Optional[str],
|
||||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||||
|
trust_request_chat_template: bool = False,
|
||||||
return_tokens_as_token_ids: bool = False,
|
return_tokens_as_token_ids: bool = False,
|
||||||
reasoning_parser: str = "",
|
reasoning_parser: str = "",
|
||||||
enable_auto_tools: bool = False,
|
enable_auto_tools: bool = False,
|
||||||
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
self.response_role = response_role
|
self.response_role = response_role
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.chat_template_content_format: Final = chat_template_content_format
|
self.chat_template_content_format: Final = chat_template_content_format
|
||||||
|
self.trust_request_chat_template = trust_request_chat_template
|
||||||
self.enable_log_outputs = enable_log_outputs
|
self.enable_log_outputs = enable_log_outputs
|
||||||
|
|
||||||
# set up tool use
|
# set up tool use
|
||||||
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
if not self.use_harmony:
|
if not self.use_harmony:
|
||||||
# Common case.
|
# Common case.
|
||||||
|
request_chat_template = request.chat_template
|
||||||
|
chat_template_kwargs = request.chat_template_kwargs
|
||||||
|
if not self.trust_request_chat_template and (
|
||||||
|
request_chat_template is not None or
|
||||||
|
(chat_template_kwargs and
|
||||||
|
chat_template_kwargs.get("chat_template") is not None)):
|
||||||
|
return self.create_error_response(
|
||||||
|
"Chat template is passed with request, but "
|
||||||
|
"--trust-request-chat-template is not set. "
|
||||||
|
"Refused request with untrusted chat template.")
|
||||||
(
|
(
|
||||||
conversation,
|
conversation,
|
||||||
request_prompts,
|
request_prompts,
|
||||||
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
request,
|
request,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
request.messages,
|
request.messages,
|
||||||
chat_template=request.chat_template or self.chat_template,
|
chat_template=request_chat_template or self.chat_template,
|
||||||
chat_template_content_format=self.
|
chat_template_content_format=self.
|
||||||
chat_template_content_format,
|
chat_template_content_format,
|
||||||
add_generation_prompt=request.add_generation_prompt,
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user