diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 6efed990b189..8cc51a5d73b3 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -9,11 +9,11 @@ from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
-                                         _try_extract_ast, load_chat_template,
+from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
                                          parse_chat_messages,
                                          parse_chat_messages_futures,
-                                         resolve_chat_template_content_format)
+                                         resolve_chat_template_content_format,
+                                         resolve_hf_chat_template)
 from vllm.entrypoints.llm import apply_hf_chat_template
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import encode_image_base64
@@ -747,7 +747,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
     }] if use_tools else None
 
     # Test detecting the tokenizer's chat_template
-    chat_template = _resolve_hf_chat_template(
+    chat_template = resolve_hf_chat_template(
         tokenizer,
         chat_template=None,
         tools=tools,
@@ -781,7 +781,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     tokenizer = tokenizer_group.tokenizer
 
     # Test detecting the tokenizer's chat_template
-    chat_template = _resolve_hf_chat_template(
+    chat_template = resolve_hf_chat_template(
         tokenizer,
         chat_template=None,
         tools=None,
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index d3613384590d..73a69d3037f7 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -306,7 +306,24 @@ def _detect_content_format(
     return "openai"
 
 
-def _resolve_hf_chat_template(
+def resolve_mistral_chat_template(
+    chat_template: Optional[str],
+    **kwargs: Any,
+) -> Optional[str]:
+    if chat_template is not None:
+        logger.warning_once(
+            "'chat_template' cannot be overridden for mistral tokenizer.")
+    if "add_generation_prompt" in kwargs:
+        logger.warning_once(
+            "'add_generation_prompt' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    if "continue_final_message" in kwargs:
+        logger.warning_once(
+            "'continue_final_message' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    return None
+
+def resolve_hf_chat_template(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
@@ -352,7 +369,7 @@ def _resolve_chat_template_content_format(
     trust_remote_code: bool,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
-        hf_chat_template = _resolve_hf_chat_template(
+        hf_chat_template = resolve_hf_chat_template(
             tokenizer,
             chat_template=chat_template,
             trust_remote_code=trust_remote_code,
@@ -1140,7 +1157,7 @@ def apply_hf_chat_template(
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
-    hf_chat_template = _resolve_hf_chat_template(
+    hf_chat_template = resolve_hf_chat_template(
         tokenizer,
         chat_template=chat_template,
         tools=tools,
@@ -1169,17 +1186,12 @@ def apply_mistral_chat_template(
     tools: Optional[list[dict[str, Any]]],
     **kwargs: Any,
 ) -> list[int]:
-    if chat_template is not None:
-        logger.warning_once(
-            "'chat_template' cannot be overridden for mistral tokenizer.")
-    if "add_generation_prompt" in kwargs:
-        logger.warning_once(
-            "'add_generation_prompt' is not supported for mistral tokenizer, "
-            "so it will be ignored.")
-    if "continue_final_message" in kwargs:
-        logger.warning_once(
-            "'continue_final_message' is not supported for mistral tokenizer, "
-            "so it will be ignored.")
+    # The return value of resolve_mistral_chat_template is always None,
+    # and we won't use it.
+    resolve_mistral_chat_template(
+        chat_template=chat_template,
+        **kwargs,
+    )
 
     return tokenizer.apply_chat_template(
         messages=messages,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index f9b1d69a31d8..374e43fb1534 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -35,7 +35,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import (load_chat_template,
+                                         resolve_hf_chat_template,
+                                         resolve_mistral_chat_template)
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -84,6 +86,7 @@ from vllm.entrypoints.utils import load_aware_call, with_cancellation
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
+from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                         is_valid_ipv6_address, set_ulimit)
@@ -883,8 +886,26 @@ async def init_app_state(
 
     resolved_chat_template = load_chat_template(args.chat_template)
     if resolved_chat_template is not None:
-        logger.info("Using supplied chat template:\n%s",
-                    resolved_chat_template)
+        # Get the tokenizer to check official template
+        tokenizer = await engine_client.get_tokenizer()
+
+        if isinstance(tokenizer, MistralTokenizer):
+            # The warning is logged in resolve_mistral_chat_template.
+            resolved_chat_template = resolve_mistral_chat_template(
+                chat_template=resolved_chat_template)
+        else:
+            hf_chat_template = resolve_hf_chat_template(
+                tokenizer,
+                chat_template=None,
+                tools=None,
+                trust_remote_code=model_config.trust_remote_code)
+
+            if hf_chat_template != resolved_chat_template:
+                logger.warning(
+                    "Using supplied chat template: %s\n"
+                    "It is different from official chat template '%s'. "
+                    "This discrepancy may lead to performance degradation.",
+                    resolved_chat_template, args.model)
 
     state.openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,