[V1] Add structural_tag support using xgrammar (#17085)
parent c48334d405
commit f8acd01ff7
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

from openai import OpenAI

# This example demonstrates the `structural_tag` response format.
# It can be used to specify a structured output format that occurs between
# specific tags in the response. This example shows how it could be used
# to enforce the format of a tool call response, but it could be used for
# any structured output within a subset of the response.


def main():
    client = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="-",
    )

    messages = [{
        "role": "user",
        "content": """
You have access to the following function to retrieve the weather in a city:

{
    "name": "get_weather",
    "parameters": {
        "city": {
            "param_type": "string",
            "description": "The city to get the weather for",
            "required": True
        }
    }
}

If you choose to call a function, ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where

start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function
argument value as value.
end_tag => `</function>`

Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco?
"""
    }]

    response = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=messages,
        response_format={
            "type": "structural_tag",
            "structures": [{
                "begin": "<function=get_weather>",
                "schema": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string"
                        }
                    }
                },
                "end": "</function>"
            }],
            "triggers": ["<function="]
        })
    print(response)


if __name__ == "__main__":
    main()
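With this structural tag in place, anything the model emits between
<function=get_weather> and </function> is constrained to the JSON schema
above, while the surrounding text stays free-form. A minimal client-side
sketch for pulling those calls back out (the regex mirrors the tags
configured above; the helper is illustrative, not part of the commit):

import json
import re

def extract_weather_calls(text: str) -> list[dict]:
    # Collect every constrained tool-call body emitted between the tags.
    return [
        json.loads(body)
        for body in re.findall(r"<function=get_weather>(.*?)</function>", text)
    ]

# e.g. extract_weather_calls(response.choices[0].message.content) might yield
# [{"city": "New York City"}, {"city": "Boston"}, {"city": "San Francisco"}]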
@@ -350,6 +350,7 @@ def test_structured_output(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=json_schema))

    outputs = llm.generate(
        prompts="Generate a description of a frog using 50 characters.",
        sampling_params=sampling_params,
@@ -368,6 +369,106 @@ def test_structured_output(
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)

    #
    # Test 11: Generate structured output using structural_tag format
    #
    structural_tag_config = {
        "type": "structural_tag",
        "structures": [{
            "begin": "<function=get_weather>",
            "schema": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string"
                    }
                }
            },
            "end": "</function>"
        }],
        "triggers": ["<function="]
    }

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=100,
        guided_decoding=GuidedDecodingParams(
            structural_tag=json.dumps(structural_tag_config)))

    prompt = """
You have access to the following function to retrieve the weather in a city:

{
    "name": "get_weather",
    "parameters": {
        "city": {
            "param_type": "string",
            "description": "The city to get the weather for",
            "required": True
        }
    }
}

If you choose to call a function, ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where

start_tag => `<function`
parameters => a JSON dict with the function argument name
as key and function argument value as value.
end_tag => `</function>`

Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City?
"""

    # Change this once other backends support structural_tag
    if guided_decoding_backend.startswith("xgrammar"):
        outputs = llm.generate(prompts=prompt,
                               sampling_params=sampling_params,
                               use_tqdm=True)
        assert outputs is not None
    else:
        outputs = []

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None

        # Search for function call pattern in the response
        function_call_pattern = r'<function=get_weather>(.*?)</function>'
        matches = re.findall(function_call_pattern, generated_text)

        if not matches:
            print(f"Warning: No function calls found in response: "
                  f"{generated_text!r}")
            continue

        # Take the first function call if multiple are found
        json_str = matches[0]
        try:
            json_content = json.loads(json_str)
            assert "city" in json_content
            assert isinstance(json_content["city"], str)
            print(f"Found valid function call: {generated_text!r}")
        except (json.JSONDecodeError, AssertionError) as e:
            pytest.fail("Invalid function call format: "
                        f"{generated_text!r}\nError: {str(e)}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, tokenizer_mode",
@@ -1396,7 +1396,9 @@ class LLM:
             grammar=guided_options.guided_grammar,
             json_object=guided_options.guided_json_object,
             backend=guided_options.guided_decoding_backend,
-            whitespace_pattern=guided_options.guided_whitespace_pattern)
+            whitespace_pattern=guided_options.guided_whitespace_pattern,
+            structural_tag=guided_options.structural_tag,
+        )
         return params

     def _run_engine(
@@ -2,6 +2,7 @@

 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import json
 import re
 import time
 from argparse import Namespace

@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
     strict: Optional[bool] = None


+class StructuralTag(OpenAIBaseModel):
+    begin: str
+    # schema is the field, but that causes conflicts with pydantic so
+    # instead use structural_tag_schema with an alias
+    structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
+                                                            alias="schema")
+    end: str
+
+
+class StructuralTagResponseFormat(OpenAIBaseModel):
+    type: Literal["structural_tag"]
+    structures: list[StructuralTag]
+    triggers: list[str]
+
+
 class ResponseFormat(OpenAIBaseModel):
-    # type must be "json_schema", "json_object" or "text"
+    # type must be "json_schema", "json_object", or "text"
     type: Literal["text", "json_object", "json_schema"]
     json_schema: Optional[JsonSchemaResponseFormat] = None


+AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]
+
+
 class StreamOptions(OpenAIBaseModel):
     include_usage: Optional[bool] = True
     continuous_usage_stats: Optional[bool] = False
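Because "schema" collides with a reserved pydantic name, the field is stored
as structural_tag_schema behind an alias. A quick illustrative round-trip of
that alias (a sketch, not part of the commit):

tag = StructuralTag.model_validate({
    "begin": "<function=get_weather>",
    "schema": {"type": "object"},
    "end": "</function>",
})
assert tag.structural_tag_schema == {"type": "object"}
# Dumping with by_alias=True, as the request handler below does, restores the
# wire-format "schema" key:
assert "schema" in tag.model_dump(by_alias=True)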
@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     max_completion_tokens: Optional[int] = None
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
-    response_format: Optional[ResponseFormat] = None
+    response_format: Optional[AnyResponseFormat] = None
     seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
     stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False

@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, the output will follow the context free grammar."),
     )
+    structural_tag: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the structural tag schema."),
+    )
     guided_decoding_backend: Optional[str] = Field(
         default=None,
         description=(
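Since structural_tag is a vLLM extension to the OpenAI request schema, a
client on the stock openai SDK can also set this field directly through
extra_body instead of response_format. A hedged sketch (client, messages, and
the structural_tag_config dict are assumed to be defined as in the example
script above):

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=messages,
    extra_body={"structural_tag": json.dumps(structural_tag_config)},
)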
@@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
             json_schema = self.response_format.json_schema
             assert json_schema is not None
             self.guided_json = json_schema.json_schema
+        elif self.response_format.type == "structural_tag":
+            structural_tag = self.response_format
+            assert structural_tag is not None and isinstance(
+                structural_tag, StructuralTagResponseFormat)
+            s_tag_obj = structural_tag.model_dump(by_alias=True)
+            self.structural_tag = json.dumps(s_tag_obj)

         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,
@@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             json_object=guided_json_object,
             backend=self.guided_decoding_backend,
             whitespace_pattern=self.guided_whitespace_pattern,
+            structural_tag=self.structural_tag,
         )

         return SamplingParams.from_optional(
@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
         "If true (the default), special tokens (e.g. BOS) will be added to "
         "the prompt."),
     )
-    response_format: Optional[ResponseFormat] = Field(
+    response_format: Optional[AnyResponseFormat] = Field(
         default=None,
-        description=
-        ("Similar to chat completion, this parameter specifies the format of "
-         "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
-         "{'type': 'text' } is supported."),
+        description=(
+            "Similar to chat completion, this parameter specifies the format "
+            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
+            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
+        ),
     )
     guided_json: Optional[Union[str, dict, BaseModel]] = Field(
         default=None,
@@ -27,14 +27,15 @@ class GuidedDecodingRequest:
     guided_decoding_backend: Optional[str] = None
     guided_whitespace_pattern: Optional[str] = None
     guided_json_object: Optional[bool] = None
+    structural_tag: Optional[str] = None

     def __post_init__(self):
         """Validate that some fields are mutually exclusive."""
-        guide_count = sum([
-            self.guided_json is not None, self.guided_regex is not None,
-            self.guided_choice is not None, self.guided_grammar is not None,
-            self.guided_json_object is not None
-        ])
+        guide_count = sum(x is not None
+                          for x in (self.guided_json, self.guided_regex,
+                                    self.guided_choice, self.guided_grammar,
+                                    self.guided_json_object,
+                                    self.structural_tag))
         if guide_count > 1:
             raise ValueError(
                 "You can only use one kind of guided decoding but multiple are "
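The rewritten counter treats structural_tag as one more mutually exclusive
guide, so combining it with any other kind now fails fast. An illustrative
check (the field values are placeholders):

import pytest

with pytest.raises(ValueError, match="one kind of guided decoding"):
    GuidedDecodingRequest(guided_json={"type": "object"},
                          structural_tag='{"structures": [], "triggers": []}')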
@@ -38,6 +38,7 @@ class GuidedDecodingParams:
     """These are other options that can be set"""
     backend: Optional[str] = None
     whitespace_pattern: Optional[str] = None
+    structural_tag: Optional[str] = None

     @staticmethod
     def from_optional(
@@ -48,9 +49,10 @@ class GuidedDecodingParams:
         json_object: Optional[bool] = None,
         backend: Optional[str] = None,
         whitespace_pattern: Optional[str] = None,
+        structural_tag: Optional[str] = None,
     ) -> Optional["GuidedDecodingParams"]:
-        if all(arg is None
-               for arg in (json, regex, choice, grammar, json_object)):
+        if all(arg is None for arg in (json, regex, choice, grammar,
+                                       json_object, structural_tag)):
             return None
         # Extract json schemas from pydantic models
         if isinstance(json, (BaseModel, type(BaseModel))):
@@ -63,6 +65,7 @@ class GuidedDecodingParams:
             json_object=json_object,
             backend=backend,
             whitespace_pattern=whitespace_pattern,
+            structural_tag=structural_tag,
         )

     @property
@@ -194,6 +194,9 @@ def serialize_guidance_grammar(
         tp = "grammar"
     elif request_type == StructuredOutputOptions.CHOICE:
         tp = "choice"
+    elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
+        raise ValueError("Structural tag is not supported "
+                         "for guidance backend yet")
     else:
         logger.error("Validation should have already occurred. "
                      "Please file an issue.")
@@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum):
     REGEX = enum.auto()
     GRAMMAR = enum.auto()
     CHOICE = enum.auto()
+    STRUCTURAL_TAG = enum.auto()


 StructuredOutputKey = tuple[StructuredOutputOptions, str]
@@ -108,6 +108,16 @@ class XgrammarBackend(StructuredOutputBackend):
             ctx = self.compiler.compile_grammar(grammar_spec)
         elif request_type == StructuredOutputOptions.REGEX:
             ctx = self.compiler.compile_regex(grammar_spec)
+        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
+            s_tag = json.loads(grammar_spec)
+            tags = [
+                xgr.StructuralTagItem(
+                    begin=s["begin"],
+                    schema=json.dumps(s["schema"]),
+                    end=s["end"],
+                ) for s in s_tag["structures"]
+            ]
+            ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
         else:
             logger.error(
                 "Validation should have already occurred. Please file an issue."
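For reference, the grammar_spec string reaching this new branch is the
JSON-serialized config carried in GuidedDecodingParams.structural_tag, e.g.
(mirroring the config used elsewhere in this change):

grammar_spec = json.dumps({
    "structures": [{
        "begin": "<function=get_weather>",
        "schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}}
        },
        "end": "</function>",
    }],
    "triggers": ["<function="],
})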
@@ -272,3 +282,18 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
             xgr.Grammar.from_ebnf(gd_params.grammar)
         except Exception as e:
             raise ValueError("Invalid grammar specification.") from e
         return

+    if gd_params.structural_tag:
+        try:
+            s_tag = json.loads(gd_params.structural_tag)
+            tags = [
+                xgr.StructuralTagItem(
+                    begin=s["begin"],
+                    schema=json.dumps(s["schema"]),
+                    end=s["end"],
+                ) for s in s_tag["structures"]
+            ]
+            xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
+        except Exception as e:
+            raise ValueError("Invalid structural tag specification.") from e
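This validation runs up front, so a malformed spec is rejected before any
grammar compilation. An illustrative failure (a sketch; the truncated JSON is
deliberate):

params = SamplingParams(guided_decoding=GuidedDecodingParams(
    structural_tag='{"structures": ['))  # deliberately not valid JSON
try:
    validate_xgrammar_grammar(params)
except ValueError as e:
    print(e)  # Invalid structural tag specification.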
@@ -78,5 +78,7 @@ def get_structured_output_key(
         return (StructuredOutputOptions.CHOICE, json_str)
     elif params.grammar is not None:
         return (StructuredOutputOptions.GRAMMAR, params.grammar)
+    elif params.structural_tag is not None:
+        return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
     else:
         raise ValueError("No valid structured output parameter found")