From f8acd01ff758fad3a44302f5be6407243cbec193 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sat, 26 Apr 2025 10:06:37 -0400 Subject: [PATCH] [V1] Add `structural_tag` support using xgrammar (#17085) --- ...etion_structured_outputs_structural_tag.py | 85 +++++++++++++++ .../llm/test_struct_output_generate.py | 101 ++++++++++++++++++ vllm/entrypoints/llm.py | 4 +- vllm/entrypoints/openai/protocol.py | 46 ++++++-- .../guided_decoding/guided_fields.py | 11 +- vllm/sampling_params.py | 7 +- vllm/v1/structured_output/backend_guidance.py | 3 + vllm/v1/structured_output/backend_types.py | 1 + vllm/v1/structured_output/backend_xgrammar.py | 25 +++++ vllm/v1/structured_output/request.py | 2 + 10 files changed, 270 insertions(+), 15 deletions(-) create mode 100644 examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py new file mode 100644 index 000000000000..b807bc540526 --- /dev/null +++ b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +# This example demonstrates the `structural_tag` response format. +# It can be used to specify a structured output format that occurs between +# specific tags in the response. This example shows how it could be used +# to enforce the format of a tool call response, but it could be used for +# any structured output within a subset of the response. + + +def main(): + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + messages = [{ + "role": + "user", + "content": + """ +You have access to the following function to retrieve the weather in a city: + + { + "name": "get_weather", + "parameters": { + "city": { + "param_type": "string", + "description": "The city to get the weather for", + "required": True + } + } + } + +If a you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name as key and function + argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Always add your sources when using search results to answer the user query + +You are a helpful assistant. + +Given the previous instructions, what is the weather in New York City, Boston, +and San Francisco? +""" + }] + + response = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=messages, + response_format={ + "type": + "structural_tag", + "structures": [{ + "begin": "", + "schema": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + } + }, + "end": "" + }], + "triggers": ["", + "schema": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + } + }, + "end": "" + }], + "triggers": ["{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name + as key and function argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Always add your sources when using search results to answer the user query + +You are a helpful assistant. + +Given the previous instructions, what is the weather in New York City? +""" + + # Change this once other backends support structural_tag + if guided_decoding_backend.startswith("xgrammar"): + outputs = llm.generate(prompts=prompt, + sampling_params=sampling_params, + use_tqdm=True) + assert outputs is not None + else: + outputs = [] + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + generated_text = output.outputs[0].text + assert generated_text is not None + + # Search for function call pattern in the response + function_call_pattern = r'(.*?)' + matches = re.findall(function_call_pattern, generated_text) + + if not matches: + print(f"Warning: No function calls found in response: " + f"{generated_text!r}") + continue + + # Take the first function call if multiple are found + json_str = matches[0] + try: + json_content = json.loads(json_str) + assert "city" in json_content + assert isinstance(json_content["city"], str) + print(f"Found valid function call: {generated_text!r}") + except (json.JSONDecodeError, AssertionError) as e: + pytest.fail("Invalid function call format: " + f"{generated_text!r}\nError: {str(e)}") + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("model_name, tokenizer_mode", diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 90bd5494c183..653e61a11ebd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1396,7 +1396,9 @@ class LLM: grammar=guided_options.guided_grammar, json_object=guided_options.guided_json_object, backend=guided_options.guided_decoding_backend, - whitespace_pattern=guided_options.guided_whitespace_pattern) + whitespace_pattern=guided_options.guided_whitespace_pattern, + structural_tag=guided_options.structural_tag, + ) return params def _run_engine( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8d2ab29d221e..015943762ab1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -2,6 +2,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import json import re import time from argparse import Namespace @@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): strict: Optional[bool] = None +class StructuralTag(OpenAIBaseModel): + begin: str + # schema is the field, but that causes conflicts with pydantic so + # instead use structural_tag_schema with an alias + structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, + alias="schema") + end: str + + +class StructuralTagResponseFormat(OpenAIBaseModel): + type: Literal["structural_tag"] + structures: list[StructuralTag] + triggers: list[str] + + class ResponseFormat(OpenAIBaseModel): - # type must be "json_schema", "json_object" or "text" + # type must be "json_schema", "json_object", or "text" type: Literal["text", "json_object", "json_schema"] json_schema: Optional[JsonSchemaResponseFormat] = None +AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat] + + class StreamOptions(OpenAIBaseModel): include_usage: Optional[bool] = True continuous_usage_stats: Optional[bool] = False @@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel): max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 - response_format: Optional[ResponseFormat] = None + response_format: Optional[AnyResponseFormat] = None seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, list[str]]] = Field(default_factory=list) stream: Optional[bool] = False @@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( "If specified, the output will follow the context free grammar."), ) + structural_tag: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the structural tag schema."), + ) guided_decoding_backend: Optional[str] = Field( default=None, description=( @@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel): json_schema = self.response_format.json_schema assert json_schema is not None self.guided_json = json_schema.json_schema + elif self.response_format.type == "structural_tag": + structural_tag = self.response_format + assert structural_tag is not None and isinstance( + structural_tag, StructuralTagResponseFormat) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structural_tag = json.dumps(s_tag_obj) guided_decoding = GuidedDecodingParams.from_optional( json=self._get_guided_json_from_tool() or self.guided_json, @@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel): json_object=guided_json_object, backend=self.guided_decoding_backend, whitespace_pattern=self.guided_whitespace_pattern, + structural_tag=self.structural_tag, ) return SamplingParams.from_optional( @@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel): "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt."), ) - response_format: Optional[ResponseFormat] = Field( + response_format: Optional[AnyResponseFormat] = Field( default=None, - description= - ("Similar to chat completion, this parameter specifies the format of " - "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or " - "{'type': 'text' } is supported."), + description=( + "Similar to chat completion, this parameter specifies the format " + "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}" + ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." + ), ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index db4ce26806c1..1593868a164a 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -27,14 +27,15 @@ class GuidedDecodingRequest: guided_decoding_backend: Optional[str] = None guided_whitespace_pattern: Optional[str] = None guided_json_object: Optional[bool] = None + structural_tag: Optional[str] = None def __post_init__(self): """Validate that some fields are mutually exclusive.""" - guide_count = sum([ - self.guided_json is not None, self.guided_regex is not None, - self.guided_choice is not None, self.guided_grammar is not None, - self.guided_json_object is not None - ]) + guide_count = sum(x is not None + for x in (self.guided_json, self.guided_regex, + self.guided_choice, self.guided_grammar, + self.guided_json_object, + self.structural_tag)) if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding but multiple are " diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 707a757ca83a..79c7178af8d7 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -38,6 +38,7 @@ class GuidedDecodingParams: """These are other options that can be set""" backend: Optional[str] = None whitespace_pattern: Optional[str] = None + structural_tag: Optional[str] = None @staticmethod def from_optional( @@ -48,9 +49,10 @@ class GuidedDecodingParams: json_object: Optional[bool] = None, backend: Optional[str] = None, whitespace_pattern: Optional[str] = None, + structural_tag: Optional[str] = None, ) -> Optional["GuidedDecodingParams"]: - if all(arg is None - for arg in (json, regex, choice, grammar, json_object)): + if all(arg is None for arg in (json, regex, choice, grammar, + json_object, structural_tag)): return None # Extract json schemas from pydantic models if isinstance(json, (BaseModel, type(BaseModel))): @@ -63,6 +65,7 @@ class GuidedDecodingParams: json_object=json_object, backend=backend, whitespace_pattern=whitespace_pattern, + structural_tag=structural_tag, ) @property diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 6d2ccd4019d4..1453e284b013 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -194,6 +194,9 @@ def serialize_guidance_grammar( tp = "grammar" elif request_type == StructuredOutputOptions.CHOICE: tp = "choice" + elif request_type == StructuredOutputOptions.STRUCTURAL_TAG: + raise ValueError("Structural tag is not supported " + "for guidance backend yet") else: logger.error("Validation should have already occurred. " "Please file an issue.") diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 306e4aa0196c..6330bcbf20c3 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum): REGEX = enum.auto() GRAMMAR = enum.auto() CHOICE = enum.auto() + STRUCTURAL_TAG = enum.auto() StructuredOutputKey = tuple[StructuredOutputOptions, str] diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index bb7c7edc278d..ecaeb6e4ee80 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -108,6 +108,16 @@ class XgrammarBackend(StructuredOutputBackend): ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: ctx = self.compiler.compile_regex(grammar_spec) + elif request_type == StructuredOutputOptions.STRUCTURAL_TAG: + s_tag = json.loads(grammar_spec) + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) for s in s_tag["structures"] + ] + ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"]) else: logger.error( "Validation should have already occurred. Please file an issue." @@ -272,3 +282,18 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: xgr.Grammar.from_ebnf(gd_params.grammar) except Exception as e: raise ValueError("Invalid grammar specification.") from e + return + + if gd_params.structural_tag: + try: + s_tag = json.loads(gd_params.structural_tag) + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) for s in s_tag["structures"] + ] + xgr.Grammar.from_structural_tag(tags, s_tag["triggers"]) + except Exception as e: + raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 9e54b8bf028d..6ef472eb896c 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -78,5 +78,7 @@ def get_structured_output_key( return (StructuredOutputOptions.CHOICE, json_str) elif params.grammar is not None: return (StructuredOutputOptions.GRAMMAR, params.grammar) + elif params.structural_tag is not None: + return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag) else: raise ValueError("No valid structured output parameter found")