[Fix] Move "model_config" as keyword args in chat_utils.py (#18098)

Signed-off-by: Linkun <github@lkchen.net>
lkchen 2025-05-13 23:27:26 -07:00 committed by GitHub
parent 33011318c2
commit 6685890d11
6 changed files with 42 additions and 23 deletions
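
This commit moves `model_config` from the first positional parameter to a keyword-only parameter on the chat-template helpers in chat_utils.py (`resolve_hf_chat_template`, `resolve_chat_template_content_format`, and `apply_hf_chat_template`) and retains a deprecated `trust_remote_code` keyword routed through the newly imported `deprecate_kwargs` decorator; every call site in the diff below is updated to pass `model_config=model_config`. A minimal sketch of the post-change calling convention, assuming `model_config`, `tokenizer`, and `conversation` come from existing vLLM setup code (they are not part of this commit):

from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                         resolve_chat_template_content_format)


def build_prompt(model_config, tokenizer, conversation,
                 chat_template=None, tools=None):
    # Resolve how message content should be formatted for the template.
    content_format = resolve_chat_template_content_format(
        chat_template,
        tools,
        "auto",
        tokenizer,
        model_config=model_config,  # keyword-only after this change
    )
    # Render the prompt; model_config is likewise keyword-only here.
    prompt = apply_hf_chat_template(
        tokenizer=tokenizer,
        conversation=conversation,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
        add_generation_prompt=True,
    )
    return content_format, prompt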

View File

@@ -122,10 +122,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Call the function and get the result
     result = apply_hf_chat_template(
-        model_config,
-        tokenizer,
+        tokenizer=tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
+        model_config=model_config,
         tools=None,
         add_generation_prompt=mock_request.add_generation_prompt,
         continue_final_message=mock_request.continue_final_message,

View File

@@ -793,10 +793,10 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
     )
     vllm_result = apply_hf_chat_template(
-        model_config,
-        tokenizer,
+        tokenizer=tokenizer,
         conversation=conversation,
         chat_template=None,
+        model_config=model_config,
         tools=None,
         add_generation_prompt=True,
     )
@@ -903,11 +903,11 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     print(_try_extract_ast(chat_template))
     resolved_format = resolve_chat_template_content_format(
-        model_config,
         None,  # Test detecting the tokenizer's chat_template
         None,
         "auto",
         tokenizer,
+        model_config=model_config,
     )
     assert resolved_format == expected_format
@@ -962,11 +962,11 @@ def test_resolve_content_format_fallbacks(model, expected_format):
     print(_try_extract_ast(chat_template))
     resolved_format = resolve_chat_template_content_format(
-        model_config,
         None,  # Test detecting the tokenizer's chat_template
         None,
         "auto",
         tokenizer,
+        model_config=model_config,
     )
     assert resolved_format == expected_format
@@ -1021,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format):
     print(_try_extract_ast(chat_template))
     resolved_format = resolve_chat_template_content_format(
-        model_config,
         chat_template,
         None,
         "auto",
         dummy_tokenizer,
+        model_config=model_config,
     )
     assert resolved_format == expected_format

View File

@@ -44,7 +44,7 @@ from vllm.transformers_utils.chat_templates import (
 # yapf: enable
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import random_uuid
+from vllm.utils import deprecate_kwargs, random_uuid
 logger = init_logger(__name__)
@@ -329,11 +329,17 @@ def resolve_mistral_chat_template(
                     "so it will be ignored.")
         return None
+@deprecate_kwargs(
+    "trust_remote_code",
+    additional_message="Please use `model_config.trust_remote_code` instead.",
+)
 def resolve_hf_chat_template(
-    model_config: ModelConfig,
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
+    *,
+    model_config: ModelConfig,
+    trust_remote_code: Optional[bool] = None,
 ) -> Optional[str]:
     # 1st priority: The given chat template
     if chat_template is not None:
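
`deprecate_kwargs` is imported from `vllm.utils` (see the import hunk above); its implementation is not shown in this diff. As a rough, hypothetical sketch of the behavior relied on here (warn when a caller still passes the named keyword, then call the wrapped function unchanged), something like the following would suffice; the name `deprecate_kwargs_sketch` and its details are assumptions, not the actual vllm.utils code:

import functools
import warnings
from typing import Any, Callable, Optional


def deprecate_kwargs_sketch(*kw_names: str,
                            additional_message: Optional[str] = None) -> Callable:
    """Hypothetical stand-in for vllm.utils.deprecate_kwargs."""

    def decorator(func: Callable) -> Callable:

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Warn for each deprecated keyword the caller still supplies.
            for name in kw_names:
                if kwargs.get(name) is not None:
                    msg = f"The keyword argument {name!r} is deprecated."
                    if additional_message is not None:
                        msg += f" {additional_message}"
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator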
@@ -379,18 +385,19 @@ def resolve_hf_chat_template(
 def _resolve_chat_template_content_format(
-    model_config: ModelConfig,
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
     given_format: ChatTemplateContentFormatOption,
     tokenizer: AnyTokenizer,
+    *,
+    model_config: ModelConfig,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
         hf_chat_template = resolve_hf_chat_template(
-            model_config,
             tokenizer,
             chat_template=chat_template,
             tools=tools,
+            model_config=model_config,
         )
     else:
         hf_chat_template = None
@@ -428,19 +435,25 @@ def _log_chat_template_content_format(
     )
+@deprecate_kwargs(
+    "trust_remote_code",
+    additional_message="Please use `model_config.trust_remote_code` instead.",
+)
 def resolve_chat_template_content_format(
-    model_config: ModelConfig,
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
     given_format: ChatTemplateContentFormatOption,
     tokenizer: AnyTokenizer,
+    *,
+    model_config: ModelConfig,
+    trust_remote_code: Optional[bool] = None,
 ) -> _ChatTemplateContentFormat:
     detected_format = _resolve_chat_template_content_format(
-        model_config,
         chat_template,
         tools,
         given_format,
         tokenizer,
+        model_config=model_config,
     )
     _log_chat_template_content_format(
@@ -1191,21 +1204,27 @@ def parse_chat_messages_futures(
     return conversation, mm_tracker.all_mm_data()
+@deprecate_kwargs(
+    "trust_remote_code",
+    additional_message="Please use `model_config.trust_remote_code` instead.",
+)
 def apply_hf_chat_template(
-    model_config: ModelConfig,
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: list[ConversationMessage],
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
     *,
+    model_config: ModelConfig,
     tokenize: bool = False,  # Different from HF's default
+    # Deprecated, explicitly capture here so it doesn't slip into kwargs.
+    trust_remote_code: Optional[bool] = None,
     **kwargs: Any,
 ) -> str:
     hf_chat_template = resolve_hf_chat_template(
-        model_config,
         tokenizer,
         chat_template=chat_template,
         tools=tools,
+        model_config=model_config,
     )
     if hf_chat_template is None:
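
Because `model_config` now sits after the bare `*`, the old positional call style (the removed lines throughout the rest of this diff) raises a `TypeError` instead of silently binding the wrong argument, which is why every caller below switches to `model_config=model_config`. A tiny, self-contained illustration of keyword-only binding (toy names, not vLLM code):

def apply_template(tokenizer, conversation, *, model_config):
    # Toy example of the keyword-only pattern adopted in this commit.
    return f"{len(conversation)} messages rendered with {model_config!r}"


print(apply_template("tok", ["hi"], model_config="cfg"))  # OK
# apply_template("tok", ["hi"], "cfg")  # TypeError: too many positional arguments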

View File

@@ -731,11 +731,11 @@ class LLM:
         tokenizer = self.get_tokenizer(lora_request)
         model_config = self.llm_engine.get_model_config()
         resolved_content_format = resolve_chat_template_content_format(
-            model_config,
             chat_template,
             tools,
             chat_template_content_format,
             tokenizer,
+            model_config=model_config,
         )
         _chat_template_kwargs: dict[str, Any] = dict(
@@ -767,9 +767,9 @@ class LLM:
             )
         else:
             prompt_str = apply_hf_chat_template(
-                model_config,
-                tokenizer,
+                tokenizer=tokenizer,
                 conversation=conversation,
+                model_config=model_config,
                 **_chat_template_kwargs,
             )
             # Special tokens are already included in chat templates so

View File

@@ -971,10 +971,10 @@ async def init_app_state(
                 chat_template=resolved_chat_template)
     else:
         hf_chat_template = resolve_hf_chat_template(
-            vllm_config.model_config,
-            tokenizer,
+            tokenizer=tokenizer,
             chat_template=None,
             tools=None,
+            model_config=vllm_config.model_config,
         )
         if hf_chat_template != resolved_chat_template:

View File

@@ -670,11 +670,11 @@ class OpenAIServing:
         model_config = self.model_config
         resolved_content_format = resolve_chat_template_content_format(
-            model_config,
             chat_template,
             tool_dicts,
             chat_template_content_format,
             tokenizer,
+            model_config=model_config,
         )
         conversation, mm_data_future = parse_chat_messages_futures(
             messages,
@@ -701,9 +701,9 @@ class OpenAIServing:
             )
         else:
             request_prompt = apply_hf_chat_template(
-                model_config,
-                tokenizer,
+                tokenizer=tokenizer,
                 conversation=conversation,
+                model_config=model_config,
                 **_chat_template_kwargs,
             )