Add filtering for chat template kwargs (#25794)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Russell Bryant 2025-09-27 06:46:49 -04:00 committed by GitHub
parent 3f5d902d2a
commit 7977e5027c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 158 additions and 6 deletions

View File

@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages, parse_chat_messages,
parse_chat_messages_futures, parse_chat_messages_futures,
resolve_chat_template_content_format, resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template) resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str) assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
# NOTE: Qwen2-Audio default chat template is specially defined inside # NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json` # processor class instead of using `tokenizer_config.json`
# yapf: disable # yapf: disable

View File

@ -11,7 +11,12 @@ from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast) cast)
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
# yapf: enable # yapf: enable
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid, supports_kw
logger = init_logger(__name__) logger = init_logger(__name__)
@ -1554,6 +1559,46 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
# We exclude chat_template from kwargs here, because
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
def apply_hf_chat_template( def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
@ -1579,12 +1624,17 @@ def apply_hf_chat_template(
) )
try: try:
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template, chat_template=hf_chat_template,
tokenize=tokenize, tokenize=tokenize,
**kwargs, **resolved_kwargs,
) )
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's

View File

@ -1716,6 +1716,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args. exclude_tools_when_tool_choice_none=args.

View File

@ -103,9 +103,13 @@ class FrontendArgs:
chat_template_content_format: ChatTemplateContentFormatOption = "auto" chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template. """The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"` * "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to OpenAI * "openai" will render the content as a list of dictionaries, similar to
schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer."""
response_role: str = "assistant" response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`.""" """The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None ssl_keyfile: Optional[str] = None

View File

@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
chat_template: Optional[str], chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "", reasoning_parser: str = "",
enable_auto_tools: bool = False, enable_auto_tools: bool = False,
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role self.response_role = response_role
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
# set up tool use # set up tool use
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
if not self.use_harmony: if not self.use_harmony:
# Common case. # Common case.
request_chat_template = request.chat_template
chat_template_kwargs = request.chat_template_kwargs
if not self.trust_request_chat_template and (
request_chat_template is not None or
(chat_template_kwargs and
chat_template_kwargs.get("chat_template") is not None)):
return self.create_error_response(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template.")
( (
conversation, conversation,
request_prompts, request_prompts,
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
request, request,
tokenizer, tokenizer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, chat_template=request_chat_template or self.chat_template,
chat_template_content_format=self. chat_template_content_format=self.
chat_template_content_format, chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,