diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 6b0ed23277d3..f8119d89ac49 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1572,6 +1572,22 @@ class AssistantTracker(jinja2.ext.Extension):
         return call_block.set_lineno(lineno)
 
 
+def _resolve_chat_template_kwargs(
+    chat_template: str,
+):
+    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
+        trim_blocks=True,
+        lstrip_blocks=True,
+        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
+    )
+    parsed_content = env.parse(chat_template)
+    template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
+    return template_vars
+
+
+_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
+
+
 def resolve_chat_template_kwargs(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     chat_template: str,
@@ -1582,13 +1598,7 @@ def resolve_chat_template_kwargs(
         if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
     }
 
-    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
-        trim_blocks=True,
-        lstrip_blocks=True,
-        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
-    )
-    parsed_content = env.parse(chat_template)
-    template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
+    template_vars = _cached_resolve_chat_template_kwargs(chat_template)
 
     # We exclude chat_template from kwargs here, because
     # chat template has been already resolved at this stage
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 15844d3162fe..2f05e10639f5 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1745,6 +1745,7 @@ async def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
+        trust_request_chat_template=args.trust_request_chat_template,
         log_error_stack=args.log_error_stack,
     ) if "encode" in supported_tasks else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
@@ -1754,6 +1755,7 @@ async def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
+        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
     ) if "embed" in supported_tasks else None
     state.openai_serving_classification = ServingClassification(
@@ -1777,6 +1779,7 @@ async def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
+        trust_request_chat_template=args.trust_request_chat_template,
         log_error_stack=args.log_error_stack,
     )
     state.openai_serving_transcription = OpenAIServingTranscription(
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index a646b16da82c..0a04d76b2d69 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -222,16 +222,14 @@ class OpenAIServingChat(OpenAIServing):
 
         if not self.use_harmony:
             # Common case.
-            request_chat_template = request.chat_template
-            chat_template_kwargs = request.chat_template_kwargs
-            if not self.trust_request_chat_template and (
-                    request_chat_template is not None or
-                (chat_template_kwargs and
-                 chat_template_kwargs.get("chat_template") is not None)):
-                return self.create_error_response(
-                    "Chat template is passed with request, but "
-                    "--trust-request-chat-template is not set. "
-                    "Refused request with untrusted chat template.")
+            error_check_ret = self._validate_chat_template(
+                request_chat_template=request.chat_template,
+                chat_template_kwargs=request.chat_template_kwargs,
+                trust_request_chat_template=self.
+                trust_request_chat_template,
+            )
+            if error_check_ret is not None:
+                return error_check_ret
             (
                 conversation,
                 request_prompts,
@@ -240,7 +238,7 @@ class OpenAIServingChat(OpenAIServing):
                 request,
                 tokenizer,
                 request.messages,
-                chat_template=request_chat_template or self.chat_template,
+                chat_template=request.chat_template or self.chat_template,
                 chat_template_content_format=self.
                 chat_template_content_format,
                 add_generation_prompt=request.add_generation_prompt,
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 647e7daed659..85493f121993 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -576,6 +576,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(engine_client=engine_client,
@@ -586,6 +587,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
 
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
 
     async def create_embedding(
         self,
@@ -629,3 +631,17 @@ class OpenAIServingEmbedding(EmbeddingMixin):
                 return self.create_error_response(str(e))
 
         return pooling_params
+
+    async def _preprocess(
+        self,
+        ctx: ServeContext,
+    ) -> Optional[ErrorResponse]:
+        if isinstance(ctx.request, EmbeddingChatRequest):
+            error_check_ret = self._validate_chat_template(
+                request_chat_template=ctx.request.chat_template,
+                chat_template_kwargs=ctx.request.chat_template_kwargs,
+                trust_request_chat_template=self.trust_request_chat_template,
+            )
+            if error_check_ret is not None:
+                return error_check_ret
+        return await super()._preprocess(ctx)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index e58d943d3f7f..151888afd8da 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -751,6 +751,22 @@ class OpenAIServing:
             tokenizer=tokenizer,
         )
 
+    def _validate_chat_template(
+        self,
+        request_chat_template: Optional[str],
+        chat_template_kwargs: Optional[dict[str, Any]],
+        trust_request_chat_template: bool,
+    ) -> Optional[ErrorResponse]:
+        if not trust_request_chat_template and (
+                request_chat_template is not None or
+            (chat_template_kwargs
+             and chat_template_kwargs.get("chat_template") is not None)):
+            return self.create_error_response(
+                "Chat template is passed with request, but "
+                "--trust-request-chat-template is not set. "
+                "Refused request with untrusted chat template.")
+        return None
+
     async def _preprocess_chat(
         self,
         request: Union[ChatLikeRequest, ResponsesRequest],
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index 0750c7ec3e9f..3a41c2613624 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -65,6 +65,7 @@ class OpenAIServingPooling(OpenAIServing):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(engine_client=engine_client,
@@ -75,6 +76,7 @@ class OpenAIServingPooling(OpenAIServing):
 
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
 
         io_processor_plugin = self.model_config.io_processor_plugin
         self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
@@ -129,6 +131,14 @@ class OpenAIServingPooling(OpenAIServing):
                     prompt=validated_prompt,
                     request_id=request_id)
             elif isinstance(request, PoolingChatRequest):
+                error_check_ret = self._validate_chat_template(
+                    request_chat_template=request.chat_template,
+                    chat_template_kwargs=request.chat_template_kwargs,
+                    trust_request_chat_template=self.
+                    trust_request_chat_template,
+                )
+                if error_check_ret is not None:
+                    return error_check_ret
                 (
                     _,
                     _,
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 3918d08ebf81..1a39fb123210 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -40,6 +40,7 @@ class OpenAIServingTokenization(OpenAIServing):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
         chat_template_content_format: ChatTemplateContentFormatOption,
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(engine_client=engine_client,
@@ -50,6 +51,7 @@ class OpenAIServingTokenization(OpenAIServing):
 
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
 
     async def create_tokenize(
         self,
@@ -71,6 +73,14 @@ class OpenAIServingTokenization(OpenAIServing):
         if isinstance(request, TokenizeChatRequest):
             tool_dicts = (None if request.tools is None else
                           [tool.model_dump() for tool in request.tools])
+            error_check_ret = self._validate_chat_template(
+                request_chat_template=request.chat_template,
+                chat_template_kwargs=request.chat_template_kwargs,
+                trust_request_chat_template=self.
+                trust_request_chat_template,
+            )
+            if error_check_ret is not None:
+                return error_check_ret
             (
                 _,
                 _,