Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 22:35:32 +08:00)
[Misc] Enhance warning information to user-defined chat template (#15408)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
parent 5ebf66748b
commit 99f536f830
@@ -9,11 +9,11 @@ from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
-                                         _try_extract_ast, load_chat_template,
+from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
                                          parse_chat_messages,
                                          parse_chat_messages_futures,
-                                         resolve_chat_template_content_format)
+                                         resolve_chat_template_content_format,
+                                         resolve_hf_chat_template)
 from vllm.entrypoints.llm import apply_hf_chat_template
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import encode_image_base64
@@ -747,7 +747,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
     }] if use_tools else None
 
     # Test detecting the tokenizer's chat_template
-    chat_template = _resolve_hf_chat_template(
+    chat_template = resolve_hf_chat_template(
         tokenizer,
         chat_template=None,
         tools=tools,
@@ -781,7 +781,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     tokenizer = tokenizer_group.tokenizer
 
     # Test detecting the tokenizer's chat_template
-    chat_template = _resolve_hf_chat_template(
+    chat_template = resolve_hf_chat_template(
         tokenizer,
         chat_template=None,
         tools=None,
@@ -306,7 +306,24 @@ def _detect_content_format(
     return "openai"
 
 
-def _resolve_hf_chat_template(
+def resolve_mistral_chat_template(
+    chat_template: Optional[str],
+    **kwargs: Any,
+) -> Optional[str]:
+    if chat_template is not None:
+        logger.warning_once(
+            "'chat_template' cannot be overridden for mistral tokenizer.")
+    if "add_generation_prompt" in kwargs:
+        logger.warning_once(
+            "'add_generation_prompt' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    if "continue_final_message" in kwargs:
+        logger.warning_once(
+            "'continue_final_message' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    return None
+
+
+def resolve_hf_chat_template(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
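
A minimal usage sketch of the new helper (the call below is hypothetical; only the import path and signature come from this diff). It never produces a template: it exists to emit the warnings and hand back None:

    from vllm.entrypoints.chat_utils import resolve_mistral_chat_template

    # Hypothetical call: a user-supplied template plus an unsupported kwarg.
    resolved = resolve_mistral_chat_template(
        chat_template="{{ messages }}",   # warns: cannot be overridden
        add_generation_prompt=True,       # warns: not supported, ignored
    )
    assert resolved is None  # the helper always returns None
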
@@ -352,7 +369,7 @@ def _resolve_chat_template_content_format(
     trust_remote_code: bool,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
-        hf_chat_template = _resolve_hf_chat_template(
+        hf_chat_template = resolve_hf_chat_template(
             tokenizer,
             chat_template=chat_template,
             trust_remote_code=trust_remote_code,
@@ -1140,7 +1157,7 @@ def apply_hf_chat_template(
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
-    hf_chat_template = _resolve_hf_chat_template(
+    hf_chat_template = resolve_hf_chat_template(
         tokenizer,
         chat_template=chat_template,
         tools=tools,
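
With the leading underscore dropped, resolve_hf_chat_template is importable by callers outside chat_utils. A minimal sketch based only on the signature visible in this diff; the model name is an arbitrary example, and the result may be None if the tokenizer ships no template:

    from transformers import AutoTokenizer
    from vllm.entrypoints.chat_utils import resolve_hf_chat_template

    tok = AutoTokenizer.from_pretrained("facebook/opt-125m")  # example model only
    template = resolve_hf_chat_template(
        tok,
        chat_template=None,   # fall back to the tokenizer's built-in template
        tools=None,
        trust_remote_code=False,
    )
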
@@ -1169,17 +1186,12 @@ def apply_mistral_chat_template(
     tools: Optional[list[dict[str, Any]]],
     **kwargs: Any,
 ) -> list[int]:
-    if chat_template is not None:
-        logger.warning_once(
-            "'chat_template' cannot be overridden for mistral tokenizer.")
-    if "add_generation_prompt" in kwargs:
-        logger.warning_once(
-            "'add_generation_prompt' is not supported for mistral tokenizer, "
-            "so it will be ignored.")
-    if "continue_final_message" in kwargs:
-        logger.warning_once(
-            "'continue_final_message' is not supported for mistral tokenizer, "
-            "so it will be ignored.")
+    # The return value of resolve_mistral_chat_template is always None,
+    # and we won't use it.
+    resolve_mistral_chat_template(
+        chat_template=chat_template,
+        **kwargs,
+    )
 
     return tokenizer.apply_chat_template(
         messages=messages,
@@ -35,7 +35,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import (load_chat_template,
+                                         resolve_hf_chat_template,
+                                         resolve_mistral_chat_template)
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -84,6 +86,7 @@ from vllm.entrypoints.utils import load_aware_call, with_cancellation
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
+from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                         is_valid_ipv6_address, set_ulimit)
@@ -883,8 +886,26 @@ async def init_app_state(
 
     resolved_chat_template = load_chat_template(args.chat_template)
     if resolved_chat_template is not None:
-        logger.info("Using supplied chat template:\n%s",
-                    resolved_chat_template)
+        # Get the tokenizer to check official template
+        tokenizer = await engine_client.get_tokenizer()
+
+        if isinstance(tokenizer, MistralTokenizer):
+            # The warning is logged in resolve_mistral_chat_template.
+            resolved_chat_template = resolve_mistral_chat_template(
+                chat_template=resolved_chat_template)
+        else:
+            hf_chat_template = resolve_hf_chat_template(
+                tokenizer,
+                chat_template=None,
+                tools=None,
+                trust_remote_code=model_config.trust_remote_code)
+
+            if hf_chat_template != resolved_chat_template:
+                logger.warning(
+                    "Using supplied chat template: %s\n"
+                    "It is different from official chat template '%s'. "
+                    "This discrepancy may lead to performance degradation.",
+                    resolved_chat_template, args.model)
 
     state.openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,
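
Taken together, embedding code could mirror this startup check before handing a custom template to the engine. A minimal sketch assuming an already-loaded tokenizer and a user_template string (both hypothetical; only the imported helpers come from this commit):

    from vllm.entrypoints.chat_utils import (resolve_hf_chat_template,
                                             resolve_mistral_chat_template)
    from vllm.transformers_utils.tokenizer import MistralTokenizer

    def check_user_template(tokenizer, user_template, trust_remote_code=False):
        # Mirrors the check added to init_app_state, in isolation.
        if isinstance(tokenizer, MistralTokenizer):
            # Always returns None; warnings are emitted inside the helper.
            return resolve_mistral_chat_template(chat_template=user_template)
        official = resolve_hf_chat_template(tokenizer,
                                            chat_template=None,
                                            tools=None,
                                            trust_remote_code=trust_remote_code)
        if official != user_template:
            print("warning: supplied template differs from the official one")
        return user_template
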