[Frontend] Cache chat template kwargs resolution (#26227)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py, 2025-10-04 23:32:30 +08:00 (committed by GitHub)
parent 5c057e068f
commit a42d2df75f
7 changed files with 81 additions and 18 deletions

vllm/entrypoints/chat_utils.py

@@ -1572,6 +1572,22 @@ class AssistantTracker(jinja2.ext.Extension):
         return call_block.set_lineno(lineno)
 
 
+def _resolve_chat_template_kwargs(
+    chat_template: str,
+):
+    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
+        trim_blocks=True,
+        lstrip_blocks=True,
+        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
+    )
+    parsed_content = env.parse(chat_template)
+    template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
+    return template_vars
+
+
+_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
+
+
 def resolve_chat_template_kwargs(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     chat_template: str,
@@ -1582,13 +1598,7 @@ def resolve_chat_template_kwargs(
         if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
     }
-    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
-        trim_blocks=True,
-        lstrip_blocks=True,
-        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
-    )
-    parsed_content = env.parse(chat_template)
-    template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
+    template_vars = _cached_resolve_chat_template_kwargs(chat_template)
 
     # We exclude chat_template from kwargs here, because
     # chat template has been already resolved at this stage
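The extracted helper is cache-friendly because the set of undeclared variables depends only on the template string, not on the tokenizer, so a plain functools.lru_cache keyed on that string suffices. A minimal standalone sketch of the same idea, omitting vLLM's AssistantTracker extension and using an illustrative template:

from functools import lru_cache

import jinja2.meta
import jinja2.sandbox


@lru_cache
def undeclared_template_vars(chat_template: str) -> frozenset[str]:
    """Parse a Jinja2 chat template once; remember which variables it reads."""
    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
    )
    parsed = env.parse(chat_template)
    # find_undeclared_variables returns names the template uses but never
    # defines itself, i.e. the kwargs a caller could meaningfully pass in.
    return frozenset(jinja2.meta.find_undeclared_variables(parsed))


template = "{% for m in messages %}{{ m['content'] }}{% endfor %}{{ suffix }}"
print(undeclared_template_vars(template))   # frozenset({'messages', 'suffix'})
print(undeclared_template_vars(template))   # repeat call is served from cache
print(undeclared_template_vars.cache_info())  # hits=1, misses=1

Parsing a large Jinja2 template on every request is the costly step; the parse result for a given template string never changes, which is what makes the memoization safe.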

vllm/entrypoints/openai/api_server.py

@@ -1745,6 +1745,7 @@ async def init_app_state(
             request_logger=request_logger,
             chat_template=resolved_chat_template,
             chat_template_content_format=args.chat_template_content_format,
+            trust_request_chat_template=args.trust_request_chat_template,
             log_error_stack=args.log_error_stack,
         ) if "encode" in supported_tasks else None
         state.openai_serving_embedding = OpenAIServingEmbedding(
@@ -1754,6 +1755,7 @@ async def init_app_state(
             request_logger=request_logger,
             chat_template=resolved_chat_template,
             chat_template_content_format=args.chat_template_content_format,
+            trust_request_chat_template=args.trust_request_chat_template,
             log_error_stack=args.log_error_stack,
         ) if "embed" in supported_tasks else None
         state.openai_serving_classification = ServingClassification(
@@ -1777,6 +1779,7 @@ async def init_app_state(
             request_logger=request_logger,
             chat_template=resolved_chat_template,
             chat_template_content_format=args.chat_template_content_format,
+            trust_request_chat_template=args.trust_request_chat_template,
             log_error_stack=args.log_error_stack,
         )
         state.openai_serving_transcription = OpenAIServingTranscription(

vllm/entrypoints/openai/serving_chat.py

@@ -222,16 +222,14 @@ class OpenAIServingChat(OpenAIServing):
         if not self.use_harmony:
             # Common case.
-            request_chat_template = request.chat_template
-            chat_template_kwargs = request.chat_template_kwargs
-            if not self.trust_request_chat_template and (
-                    request_chat_template is not None or
-                    (chat_template_kwargs and
-                     chat_template_kwargs.get("chat_template") is not None)):
-                return self.create_error_response(
-                    "Chat template is passed with request, but "
-                    "--trust-request-chat-template is not set. "
-                    "Refused request with untrusted chat template.")
+            error_check_ret = self._validate_chat_template(
+                request_chat_template=request.chat_template,
+                chat_template_kwargs=request.chat_template_kwargs,
+                trust_request_chat_template=self.
+                trust_request_chat_template,
+            )
+            if error_check_ret is not None:
+                return error_check_ret
             (
                 conversation,
                 request_prompts,
@@ -240,7 +238,7 @@ class OpenAIServingChat(OpenAIServing):
                 request,
                 tokenizer,
                 request.messages,
-                chat_template=request_chat_template or self.chat_template,
+                chat_template=request.chat_template or self.chat_template,
                 chat_template_content_format=self.
                 chat_template_content_format,
                 add_generation_prompt=request.add_generation_prompt,

vllm/entrypoints/openai/serving_embedding.py

@@ -576,6 +576,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(engine_client=engine_client,
@@ -586,6 +587,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
 
     async def create_embedding(
         self,
@@ -629,3 +631,17 @@ class OpenAIServingEmbedding(EmbeddingMixin):
             return self.create_error_response(str(e))
 
         return pooling_params
+
+    async def _preprocess(
+        self,
+        ctx: ServeContext,
+    ) -> Optional[ErrorResponse]:
+        if isinstance(ctx.request, EmbeddingChatRequest):
+            error_check_ret = self._validate_chat_template(
+                request_chat_template=ctx.request.chat_template,
+                chat_template_kwargs=ctx.request.chat_template_kwargs,
+                trust_request_chat_template=self.trust_request_chat_template,
+            )
+            if error_check_ret is not None:
+                return error_check_ret
+        return await super()._preprocess(ctx)
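Here the check rides on the _preprocess override rather than being inlined into the handler, so every chat-shaped embedding request is screened before the shared preprocessing runs. A reduced sketch of that guard-then-delegate shape, using hypothetical stand-in types (the real ServeContext and ErrorResponse live in vLLM):

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Ctx:  # stand-in for vLLM's ServeContext
    chat_template: Optional[str] = None


class Base:
    async def _preprocess(self, ctx: Ctx) -> Optional[str]:
        print("shared preprocessing ran")
        return None  # no error


class EmbeddingServer(Base):
    trusted = False  # mirrors --trust-request-chat-template being unset

    async def _preprocess(self, ctx: Ctx) -> Optional[str]:
        # Guard first: refuse untrusted per-request templates outright.
        if not self.trusted and ctx.chat_template is not None:
            return "refused: untrusted chat template"
        # Only then hand off to the base implementation.
        return await super()._preprocess(ctx)


async def main() -> None:
    server = EmbeddingServer()
    print(await server._preprocess(Ctx(chat_template="{{ messages }}")))
    print(await server._preprocess(Ctx()))


asyncio.run(main())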

vllm/entrypoints/openai/serving_engine.py

@@ -751,6 +751,22 @@ class OpenAIServing:
             tokenizer=tokenizer,
         )
 
+    def _validate_chat_template(
+        self,
+        request_chat_template: Optional[str],
+        chat_template_kwargs: Optional[dict[str, Any]],
+        trust_request_chat_template: bool,
+    ) -> Optional[ErrorResponse]:
+        if not trust_request_chat_template and (
+                request_chat_template is not None or
+                (chat_template_kwargs
+                 and chat_template_kwargs.get("chat_template") is not None)):
+            return self.create_error_response(
+                "Chat template is passed with request, but "
+                "--trust-request-chat-template is not set. "
+                "Refused request with untrusted chat template.")
+        return None
+
     async def _preprocess_chat(
         self,
         request: Union[ChatLikeRequest, ResponsesRequest],
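The helper refuses a request when a template arrives by either route: the top-level chat_template field, or a chat_template entry smuggled inside chat_template_kwargs (those kwargs are ultimately forwarded to the tokenizer's apply_chat_template, which itself accepts a chat_template keyword). A standalone sketch of just that predicate with illustrative inputs; it uses explicit None checks where the original relies on truthiness, which is equivalent here:

from typing import Any, Optional


def is_untrusted_template(
    request_chat_template: Optional[str],
    chat_template_kwargs: Optional[dict[str, Any]],
    trust_request_chat_template: bool,
) -> bool:
    """True when the request supplies a template the operator has not opted into."""
    return not trust_request_chat_template and (
        request_chat_template is not None
        or (chat_template_kwargs is not None
            and chat_template_kwargs.get("chat_template") is not None))


# Top-level template without the trust flag: refused.
assert is_untrusted_template("{{ x }}", None, trust_request_chat_template=False)
# Template hidden in kwargs: also refused.
assert is_untrusted_template(None, {"chat_template": "{{ x }}"}, False)
# Ordinary template kwargs carry no template, so they pass.
assert not is_untrusted_template(None, {"add_generation_prompt": True}, False)
# With the trust flag set, per-request templates are allowed through.
assert not is_untrusted_template("{{ x }}", None, trust_request_chat_template=True)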

vllm/entrypoints/openai/serving_pooling.py

@@ -65,6 +65,7 @@ class OpenAIServingPooling(OpenAIServing):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(engine_client=engine_client,
@@ -75,6 +76,7 @@ class OpenAIServingPooling(OpenAIServing):
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
 
         io_processor_plugin = self.model_config.io_processor_plugin
         self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
@@ -129,6 +131,14 @@ class OpenAIServingPooling(OpenAIServing):
                 prompt=validated_prompt, request_id=request_id)
 
         elif isinstance(request, PoolingChatRequest):
+            error_check_ret = self._validate_chat_template(
+                request_chat_template=request.chat_template,
+                chat_template_kwargs=request.chat_template_kwargs,
+                trust_request_chat_template=self.
+                trust_request_chat_template,
+            )
+            if error_check_ret is not None:
+                return error_check_ret
             (
                 _,
                 _,

vllm/entrypoints/openai/serving_tokenization.py

@@ -40,6 +40,7 @@ class OpenAIServingTokenization(OpenAIServing):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(engine_client=engine_client,
@@ -50,6 +51,7 @@ class OpenAIServingTokenization(OpenAIServing):
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
 
     async def create_tokenize(
         self,
@@ -71,6 +73,14 @@ class OpenAIServingTokenization(OpenAIServing):
         if isinstance(request, TokenizeChatRequest):
             tool_dicts = (None if request.tools is None else
                           [tool.model_dump() for tool in request.tools])
+            error_check_ret = self._validate_chat_template(
+                request_chat_template=request.chat_template,
+                chat_template_kwargs=request.chat_template_kwargs,
+                trust_request_chat_template=self.
+                trust_request_chat_template,
+            )
+            if error_check_ret is not None:
+                return error_check_ret
             (
                 _,
                 _,
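With the flag left unset, each of these endpoints now refuses a request that ships its own template. A hedged end-to-end sketch against the tokenize endpoint, assuming a vLLM OpenAI-compatible server on localhost:8000 and the requests library; the exact status code and response body may vary by version:

import requests

resp = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "messages": [{"role": "user", "content": "hello"}],
        # A per-request template: rejected unless the server was started
        # with --trust-request-chat-template.
        "chat_template": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
    },
    timeout=30,
)
print(resp.status_code)  # expected to be a 4xx error
print(resp.text)  # "... --trust-request-chat-template is not set ..."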