From cbcdf2c6091b162b94cbecf41312270e4b5d6ff2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 24 Mar 2025 21:50:09 +0800 Subject: [PATCH] [Bugfix] Fix chat template loading (#15143) Signed-off-by: DarkLight1337 Signed-off-by: Roger Wang Co-authored-by: chaunceyjiang Co-authored-by: Roger Wang --- .../entrypoints/openai/test_chat_template.py | 2 + tests/entrypoints/openai/test_video.py | 4 +- tests/entrypoints/test_chat_utils.py | 84 ++++++++-- tests/tool_use/utils.py | 5 +- vllm/entrypoints/chat_utils.py | 143 +++++++++++++----- vllm/entrypoints/llm.py | 7 +- vllm/entrypoints/openai/serving_engine.py | 7 +- 7 files changed, 196 insertions(+), 56 deletions(-) diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 255aba139ad32..78e40eeecde13 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Call the function and get the result result = apply_hf_chat_template( tokenizer, + trust_remote_code=True, conversation=mock_request.messages, chat_template=mock_request.chat_template or template_content, + tools=None, add_generation_prompt=mock_request.add_generation_prompt, continue_final_message=mock_request.continue_final_message, ) diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index 8c7564ba9dced..f9ccce9c1c332 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6299, total_tokens=6309) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297) message = choice.message message = chat_completion.choices[0].message @@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6299, total_tokens=6309) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index e3b7b660ee270..6efed990b1893 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -4,10 +4,13 @@ import warnings from typing import Optional import pytest +from packaging.version import Version +from transformers import __version__ as TRANSFORMERS_VERSION from vllm.assets.image import ImageAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, +from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template, + _try_extract_ast, load_chat_template, parse_chat_messages, parse_chat_messages_futures, resolve_chat_template_content_format) @@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples" PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" +QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" 
+HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" @pytest.fixture(scope="function") @@ -703,25 +708,27 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): vllm_result = apply_hf_chat_template( tokenizer, + trust_remote_code=model_config.trust_remote_code, conversation=conversation, chat_template=None, + tools=None, add_generation_prompt=True, ) assert hf_result == vllm_result -# yapf: disable @pytest.mark.parametrize( - ("model", "expected_format"), - [(PHI3V_MODEL_ID, "string"), - (QWEN2VL_MODEL_ID, "openai"), - (ULTRAVOX_MODEL_ID, "string"), - (MLLAMA_MODEL_ID, "openai"), - (LLAMA_GUARD_MODEL_ID, "openai")], -) -# yapf: enable -def test_resolve_content_format_hf_defined(model, expected_format): + "model", + [ + QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str + HERMES_MODEL_ID, # tokenizer.chat_template is of type dict + ]) +@pytest.mark.parametrize("use_tools", [True, False]) +def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): + """checks that chat_template is a dict type for HF models.""" + + # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( model, enable_lora=False, @@ -730,7 +737,56 @@ def test_resolve_content_format_hf_defined(model, expected_format): ) tokenizer = tokenizer_group.tokenizer - chat_template = tokenizer.chat_template + tools = [{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema + } + }] if use_tools else None + + # Test detecting the tokenizer's chat_template + chat_template = _resolve_hf_chat_template( + tokenizer, + chat_template=None, + tools=tools, + trust_remote_code=True, + ) + assert isinstance(chat_template, str) + + +# yapf: disable +@pytest.mark.parametrize( + ("model", "expected_format"), + [(PHI3V_MODEL_ID, "string"), + (QWEN2VL_MODEL_ID, "openai"), + (QWEN25VL_MODEL_ID, "openai"), + (ULTRAVOX_MODEL_ID, "string"), + (MLLAMA_MODEL_ID, "openai"), + (LLAMA_GUARD_MODEL_ID, "openai")], +) +# yapf: enable +def test_resolve_content_format_hf_defined(model, expected_format): + if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version( + "4.49.0"): + pytest.skip("Qwen2.5-VL requires transformers>=4.49.0") + + tokenizer_group = TokenizerGroup( + model, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + tokenizer = tokenizer_group.tokenizer + + # Test detecting the tokenizer's chat_template + chat_template = _resolve_hf_chat_template( + tokenizer, + chat_template=None, + tools=None, + trust_remote_code=True, + ) assert isinstance(chat_template, str) print("[TEXT]") @@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format): resolved_format = resolve_chat_template_content_format( None, # Test detecting the tokenizer's chat_template + None, "auto", tokenizer, + trust_remote_code=True, ) assert resolved_format == expected_format @@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format): resolved_format = resolve_chat_template_content_format( chat_template, + None, "auto", dummy_tokenizer, + trust_remote_code=True, ) assert resolved_format == expected_format diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index df117b96cd07b..231e4aad8c336 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]], # universal args for all models go here. 
also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] +ARGS: list[str] = [ + "--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs", + "256" +] CONFIGS: dict[str, ServerConfig] = { "hermes": { diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 61a91fe03d2e0..988fa01446076 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import codecs import json from abc import ABC, abstractmethod from collections import defaultdict, deque @@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import ( InputAudio) # yapf: enable # pydantic needs the TypedDict from typing_extensions -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin) from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -306,24 +306,63 @@ def _detect_content_format( return "openai" +def _resolve_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + *, + trust_remote_code: bool, +) -> Optional[str]: + # 1st priority: The given chat template + if chat_template is not None: + return chat_template + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin), + trust_remote_code=trust_remote_code, + ) + if isinstance(processor, ProcessorMixin) and \ + processor.chat_template is not None: + return processor.chat_template + except Exception: + logger.debug("Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, exc_info=True) + + # 3rd priority: AutoTokenizer chat template + try: + return tokenizer.get_chat_template(chat_template, tools=tools) + except Exception: + logger.debug("Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, exc_info=True) + + return None + + def _resolve_chat_template_content_format( chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, + *, + trust_remote_code: bool, ) -> _ChatTemplateContentFormat: if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - tokenizer_chat_template = tokenizer.chat_template + hf_chat_template = _resolve_hf_chat_template( + tokenizer, + chat_template=chat_template, + trust_remote_code=trust_remote_code, + tools=tools, + ) else: - tokenizer_chat_template = None + hf_chat_template = None - jinja_text: Optional[str] - if isinstance(tokenizer_chat_template, str) and chat_template is None: - jinja_text = tokenizer_chat_template - elif (isinstance(tokenizer_chat_template, dict) - and chat_template in tokenizer_chat_template): - jinja_text = tokenizer_chat_template[chat_template] - else: - jinja_text = load_chat_template(chat_template, is_literal=True) + jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True)) detected_format = ("string" if jinja_text is None else _detect_content_format(jinja_text, default="string")) @@ -332,17 +371,11 @@ 
def _resolve_chat_template_content_format( @lru_cache -def resolve_chat_template_content_format( +def _log_chat_template_content_format( chat_template: Optional[str], given_format: ChatTemplateContentFormatOption, - tokenizer: AnyTokenizer, -) -> _ChatTemplateContentFormat: - detected_format = _resolve_chat_template_content_format( - chat_template, - given_format, - tokenizer, - ) - + detected_format: ChatTemplateContentFormatOption, +): logger.info( "Detected the chat template content format to be '%s'. " "You can set `--chat-template-content-format` to override this.", @@ -360,6 +393,29 @@ def resolve_chat_template_content_format( detected_format, ) + +def resolve_chat_template_content_format( + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, + *, + trust_remote_code: bool = False, +) -> _ChatTemplateContentFormat: + detected_format = _resolve_chat_template_content_format( + chat_template, + tools, + given_format, + tokenizer, + trust_remote_code=trust_remote_code, + ) + + _log_chat_template_content_format( + chat_template, + given_format=given_format, + detected_format=detected_format, + ) + return detected_format @@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): f"{type(chat_template)} is not a valid chat template type") -def load_chat_template( +def _load_chat_template( chat_template: Optional[Union[Path, str]], *, is_literal: bool = False, @@ -724,7 +780,7 @@ def load_chat_template( raise TypeError("chat_template is expected to be read directly " "from its value") - return codecs.decode(chat_template, "unicode_escape") + return chat_template try: with open(chat_template) as f: @@ -742,7 +798,18 @@ def load_chat_template( # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly - return load_chat_template(chat_template, is_literal=True) + return _load_chat_template(chat_template, is_literal=True) + + +_cached_load_chat_template = lru_cache(_load_chat_template) + + +def load_chat_template( + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: + return _cached_load_chat_template(chat_template, is_literal=is_literal) # TODO: Let user specify how to insert multimodal tokens into prompt @@ -1067,23 +1134,20 @@ def apply_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: list[ConversationMessage], chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], *, + trust_remote_code: bool = False, tokenize: bool = False, # Different from HF's default **kwargs: Any, ) -> str: - if chat_template is None: - chat_template = tokenizer.chat_template + hf_chat_template = _resolve_hf_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + trust_remote_code=trust_remote_code, + ) - # FIXME: Temporary workaround for - # https://huggingface.co/mistral-community/pixtral-12b/discussions/31 - if chat_template is None: - try: - processor = cached_get_processor(tokenizer.name_or_path) - chat_template = processor.chat_template - except Exception: - pass - - if chat_template is None: + if hf_chat_template is None: raise ValueError( "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " @@ -1091,7 +1155,8 @@ def apply_hf_chat_template( return tokenizer.apply_chat_template( conversation=conversation, # type: ignore[arg-type] 
- chat_template=chat_template, + tools=tools, # type: ignore[arg-type] + chat_template=hf_chat_template, tokenize=tokenize, **kwargs, ) @@ -1100,7 +1165,8 @@ def apply_hf_chat_template( def apply_mistral_chat_template( tokenizer: MistralTokenizer, messages: list[ChatCompletionMessageParam], - chat_template: Optional[str] = None, + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], **kwargs: Any, ) -> list[int]: if chat_template is not None: @@ -1117,5 +1183,6 @@ def apply_mistral_chat_template( return tokenizer.apply_chat_template( messages=messages, + tools=tools, **kwargs, ) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 84b0093ce4caf..1887caf25a30f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -690,8 +690,10 @@ class LLM: model_config = self.llm_engine.get_model_config() resolved_content_format = resolve_chat_template_content_format( chat_template, + tools, chat_template_content_format, tokenizer, + trust_remote_code=model_config.trust_remote_code, ) prompts: list[Union[TokensPrompt, TextPrompt]] = [] @@ -713,18 +715,19 @@ class LLM: tokenizer, messages=msgs, chat_template=chat_template, + tools=tools, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, - tools=tools, ) else: prompt_data = apply_hf_chat_template( tokenizer, + trust_remote_code=model_config.trust_remote_code, conversation=conversation, chat_template=chat_template, + tools=tools, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, - tools=tools, ) prompt: Union[TokensPrompt, TextPrompt] diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 125812d2cc019..7cb4a2dce1dc0 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -379,14 +379,18 @@ class OpenAIServing: add_special_tokens: bool = False, ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], list[TokensPrompt]]: + model_config = self.model_config + resolved_content_format = resolve_chat_template_content_format( chat_template, + tool_dicts, chat_template_content_format, tokenizer, + trust_remote_code=model_config.trust_remote_code, ) conversation, mm_data_future = parse_chat_messages_futures( messages, - self.model_config, + model_config, tokenizer, content_format=resolved_content_format, ) @@ -410,6 +414,7 @@ class OpenAIServing: else: request_prompt = apply_hf_chat_template( tokenizer, + trust_remote_code=model_config.trust_remote_code, conversation=conversation, **_chat_template_kwargs, )
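
Usage sketch (assumes a vLLM build that includes this patch; the model id and the expected "openai" result come from the updated tests, everything else is illustrative): resolve_chat_template_content_format now also takes the tool definitions and a trust_remote_code flag, so the template used for content-format detection matches the one later used for rendering. Resolution priority inside _resolve_hf_chat_template is: an explicitly supplied template, then the AutoProcessor template when no tools are given, then the AutoTokenizer template.

    # Minimal sketch: detect the content format of a model's own chat template.
    from transformers import AutoTokenizer

    from vllm.entrypoints.chat_utils import resolve_chat_template_content_format

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    content_format = resolve_chat_template_content_format(
        None,    # chat_template: none given, fall back to the model's own template
        None,    # tools: no tool definitions, so the AutoProcessor template may be used
        "auto",  # given_format: let vLLM detect "string" vs "openai"
        tokenizer,
        trust_remote_code=True,
    )
    print(content_format)  # the updated test expects "openai" for Qwen2-VL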
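
Similarly, apply_hf_chat_template (and apply_mistral_chat_template) now take the tool definitions explicitly and forward them to the underlying tokenizer. A short sketch of the new call shape, mirroring the updated call sites in llm.py and serving_engine.py; the Hermes model id comes from the new test, while the get_weather tool and the message content are made up for illustration:

    # Minimal sketch: render a tool-calling prompt with the updated signature.
    from transformers import AutoTokenizer

    from vllm.entrypoints.chat_utils import apply_hf_chat_template

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")

    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    prompt = apply_hf_chat_template(
        tokenizer,
        trust_remote_code=False,
        conversation=[{"role": "user", "content": "What is the weather in Paris?"}],
        chat_template=None,  # resolved via _resolve_hf_chat_template
        tools=tools,         # now passed through to tokenizer.apply_chat_template
        add_generation_prompt=True,
    )
    print(prompt)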