[V0][V1][Core] Add outlines integration for V1, and update V0 integration. (#15975)

Signed-off-by: Nathan Hoos <thwackyy.y@gmail.com>
Nathan Hoos 2025-07-10 14:30:26 -05:00 committed by GitHub
parent 5e53c89a74
commit d6902ce79f
13 changed files with 804 additions and 461 deletions

View File

@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
outlines_core == 0.2.10
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10

View File

@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [
# Separate the backends that support grammars from those that only
# support regex-based constraints in these tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
@pytest.fixture(scope="module")
def llm():
@ -39,7 +43,7 @@ def llm():
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# A list is not what was intended, but it is still
# valid JSON.
assert isinstance(parsed_json, (dict, list))
class CarType(str, Enum):
@ -395,7 +402,7 @@ class CarDescription(BaseModel):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sample_output_schema = {

View File

@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
whitespace_pattern=None,
reasoner=None)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
tensor = regex_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor)
tensor = json_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(
@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
# allowed tokens at state 0
tensor = regex_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
tensor = json_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning(
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning(
regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend)
assert regex_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning(
# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process</think> Then")
"<think>here is the thinking process</think>")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape

View File

@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
assert isinstance(schema, dict)
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None

View File

@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
NGRAM_SPEC_CONFIG),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
@ -106,13 +110,15 @@ def test_structured_output(
enforce_eager = bool(not current_platform.is_tpu())
# Use a single LLM instance for several scenarios to
# speed up the test suite.
llm = LLM(model=model_name,
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
llm = LLM(
model=model_name,
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
#
# Test 1: Generate JSON output based on a provided schema
@ -146,32 +152,33 @@ def test_structured_output(
#
# Test 2: Generate JSON object without a schema
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))
if guided_decoding_backend != "outlines":
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))
outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old. "
"Make the response as short as possible."),
sampling_params=sampling_params,
use_tqdm=True)
outputs = llm.generate(prompts=(
"Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old. "
"Make the response as short as possible."),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None
# Parse to verify it is a valid JSON object
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# Parse to verify it is a valid JSON object
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
#
# Test 3: test a jsonschema incompatible with xgrammar
@ -210,97 +217,98 @@ def test_structured_output(
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
#
# Test 4: Generate SQL statement using EBNF grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
# Test 5: Generate SQL statement using Lark grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_lark)
parser.parse(generated_text)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
# Test 6: Test invalid grammar input
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
if guided_decoding_backend != "outlines":
#
# Test 4: Generate SQL statement using EBNF grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short "
"as possible."),
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
# Test 5: Generate SQL statement using Lark grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_lark)
parser.parse(generated_text)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
# Test 6: Test invalid grammar input
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
prompts=
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short "
"as possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
#
# Test 7: Generate text based on a regex pattern
#
@ -421,35 +429,36 @@ def test_structured_output(
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
#
# Test 11: Generate structured output using structural_tag format
#
structural_tag_config = {
"type":
"structural_tag",
"structures": [{
"begin": "<function=get_weather>",
"schema": {
"type": "object",
"properties": {
"city": {
"type": "string"
}
if guided_decoding_backend != "outlines":
#
# Test 11: Generate structured output using structural_tag format
#
structural_tag_config = {
"type":
"structural_tag",
"structures": [{
"begin": "<function=get_weather>",
"schema": {
"type": "object",
"properties": {
"city": {
"type": "string"
}
},
"additionalProperties": False
},
"additionalProperties": False
},
"end": "</function>"
}],
"triggers": ["<function="]
}
"end": "</function>"
}],
"triggers": ["<function="]
}
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(
structural_tag=json.dumps(structural_tag_config)))
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(
structural_tag=json.dumps(structural_tag_config)))
prompt = """
prompt = """
You have access to the following function to retrieve the weather in a city:
{
@ -469,7 +478,7 @@ where
start_tag => `<function`
parameters => a JSON dict with the function argument name
as key and function argument value as value.
as key and function argument value as value.
end_tag => `</function>`
Here is an example,
@ -488,37 +497,37 @@ Given the previous instructions, what is the weather in New York City? \
Make the response as short as possible.
"""
# Change this once other backends support structural_tag
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
# Change this once other backends support structural_tag
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
# Search for function call pattern in the response
function_call_pattern = r'<function=get_weather>(.*?)</function>'
matches = re.findall(function_call_pattern, generated_text)
# Search for function call pattern in the response
function_call_pattern = r'<function=get_weather>(.*?)</function>'
matches = re.findall(function_call_pattern, generated_text)
if not matches:
print(f"Warning: No function calls found in response: "
f"{generated_text!r}")
continue
if not matches:
print(f"Warning: No function calls found in response: "
f"{generated_text!r}")
continue
# Take the first function call if multiple are found
json_str = matches[0]
try:
json_content = json.loads(json_str)
assert "city" in json_content
assert isinstance(json_content["city"], str)
print(f"Found valid function call: {generated_text!r}")
except (json.JSONDecodeError, AssertionError) as e:
pytest.fail("Invalid function call format: "
f"{generated_text!r}\nError: {str(e)}")
# Take the first function call if multiple are found
json_str = matches[0]
try:
json_content = json.loads(json_str)
assert "city" in json_content
assert isinstance(json_content["city"], str)
print(f"Found valid function call: {generated_text!r}")
except (json.JSONDecodeError, AssertionError) as e:
pytest.fail("Invalid function call format: "
f"{generated_text!r}\nError: {str(e)}")
@pytest.mark.skip_global_cleanup

View File

@ -3580,7 +3580,8 @@ def get_served_model_name(model: str,
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]

View File

@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False
@ -847,6 +848,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V1_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1",
# Gap between padding buckets for the forward pass. So if this is set
# to 8, we will run the forward pass with buckets [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":

View File

@ -79,20 +79,33 @@ def maybe_backend_fallback(
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")
"grammar failed to convert to GBNF.", "guidance")
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")
"xgrammar module cannot be imported successfully.", "guidance")
if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.", "guidance")
if guided_params.backend == "outlines":
if guided_params.json_object is not None:
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.",
"guidance")
elif guided_params.grammar is not None:
# outlines grammar support has been removed; fall back to guidance
# for Lark-based grammars and to xgrammar otherwise.
if grammar_is_likely_lark(guided_params.grammar):
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"guidance")
else:
# The grammar is likely already GBNF format.
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"xgrammar")
return guided_params
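
To make the new routing concrete, here is a small sketch (not from the diff) that builds a grammar request against the outlines backend; per the branch above, a Lark-style grammar should be redirected to guidance, while a GBNF-style one should land on xgrammar. It assumes maybe_backend_fallback is importable from the module being edited here.

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# A Lark-style grammar (lowercase rule names, %import directives).
lark_sql = """
start: "SELECT" WS NAME
%import common.CNAME -> NAME
%import common.WS
"""

gd = GuidedDecodingParams(grammar=lark_sql, backend="outlines")
gd = maybe_backend_fallback(gd)
print(gd.backend)  # expected to report "guidance" for a Lark-style grammar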
@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor(
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa

View File

@ -12,7 +12,7 @@ from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
GRAMMAR = "grammar"
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# the main difference is that we changed the start: value to
# start: object | array, so we are denying scalar values as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
JSON_GRAMMAR = r"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""
global_thread_pool = None # used for generating logits processor fsm
# It's not yet clear that using more provides a benefit, and it could
@ -60,16 +32,12 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(guided_params)
@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor(
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, guided_params.whitespace_pattern,
@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
@ -130,9 +93,10 @@ def _get_guide_and_mode(
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif guided_params.grammar:
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
elif guided_params.json_object:
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
raise ValueError(
"The `outlines` guided decoding backend no longer supports grammar "
"guided generation. Please use either the `xgrammar` or `guidance` "
"backend")
else:
return None, None
@ -143,13 +107,11 @@ def _get_logits_processor(
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer, reasoner)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer, reasoner)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")

View File

@ -1,168 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright 2024-present the Outlines developers
from __future__ import annotations
# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import hashlib
import importlib.metadata
import json
from collections import defaultdict
from functools import lru_cache
from typing import Callable, Optional, Union
import os
from typing import Optional, Union
import numpy as np
import regex as re
import torch
from outlines import grammars
from outlines.caching import cache, disable_cache
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
RegexGuide, Write)
from outlines.fsm.parsing import PartialLark
from outlines_core.fsm.json_schema import build_regex_from_schema
from cachetools import LRUCache
from diskcache import Cache
from outlines_core import Guide, Index, Vocabulary
from outlines_core.json_schema import build_regex_from_schema
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
allocate_token_bitmask)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from transformers.file_utils import SPIECE_UNDERLINE
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
if envs.VLLM_V0_USE_OUTLINES_CACHE:
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
"cache. It may consume a lot of disk space and should "
"not be used with untrusted clients.")
else:
disable_cache()
CACHE = None
class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
def __init__(self, guide: Guide, eos_token_id: int,
reasoner: Optional[ReasoningParser]) -> None:
self._guide: Guide = guide
self._eos_token_id: int = eos_token_id
self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: defaultdict[int, Union[int,
CFGState]] = defaultdict(int)
def clone(self) -> "BaseLogitsProcessor":
cloned = copy.copy(self)
cloned._guide = self._guide.copy()
cloned._fsm_state = copy.deepcopy(self._fsm_state)
return cloned
self._mask: Optional[torch.Tensor] = None
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
if self._mask is None:
self._mask = allocate_token_bitmask(scores.size(-1))
# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--reasoning-parser` is set.
if self._reasoner is not None:
if not self._reasoner.is_reasoning_end(input_ids):
return scores
else:
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content_ids(input_ids)
if self._reasoner is not None and not self._reasoner.is_reasoning_end(
input_ids):
return scores
seq_id = hash(tuple(input_ids))
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# input_ids sequence to store the FSM state.
input_ids = (self._reasoner.extract_content_ids(input_ids)
if self._reasoner is not None else input_ids)
if len(input_ids) > 0:
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self._fsm_state[seq_id] = self._guide.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token)
else:
# Note: this is a hack.
# Lark pickling does not work properly (silent failure),
# which breaks the RPC (which uses python pickleing).
# We need to find a better solution.
# On the first time this is called, we simply re-create
# the Lark object.
if isinstance(self._guide, CFGGuide):
self._guide.parser = PartialLark(
self._guide.cfg_string,
parser="lalr",
import_paths=[grammars.GRAMMAR_PATH],
)
self._fsm_state[seq_id] = CFGState(
parser_state=self._guide.parser.parse(""), prev_token=None)
# The vLLM V0 engine has a quirk where the EOS token id has to be
# repeated twice for generation to stop, or at least that is what we
# have to do from here in any case. This is a patch until a better
# solution can be pushed to outlines_core.
if input_ids and input_ids[-1] != self._eos_token_id:
self._guide.advance(token_id=input_ids[-1], return_tokens=False)
instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
self._guide.write_mask_into(
data_ptr=self._mask.data_ptr(),
numel=self._mask.numel(),
element_size=self._mask.element_size(),
)
if type(instruction) == Generate: # noqa: E721
allowed_tokens = instruction.tokens
elif type(instruction) == Write: # noqa: E721
# TODO: support fast forward tokens
allowed_tokens = [instruction.tokens[0]]
else:
raise TypeError(
f"Unsupported instruction type {type(instruction)}")
# Any allowed tokens beyond the length of the scores will
# be ignored by the kernel, taking care of the issue with
# models such as Llama 3.2 Vision with an `<|image|>` token
# with id 128256, but scores.shape == torch.Size([128256])
_apply_token_bitmask_inplace_kernel(
logits=scores.unsqueeze(dim=0),
# mask must be on same device
mask=self._mask.to(scores.device, non_blocking=True))
self._mask.to("cpu", non_blocking=True)
mask = torch.full((scores.shape[-1], ),
-torch.inf,
device=scores.device)
# The tokenizer may support more token ids than the model can generate,
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
# but scores.shape == torch.Size([128256])
# Using NumPy is faster for filtering token ids
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)
if current_platform.is_hpu():
# Workaround for HPU bug where add_() raise RuntimeError:
# synNodeCreateWithId failed for node: strided_insert
# with synStatus 1 [Invalid argument], hopefully it will
# be fixed in the future releases of the HPU runtime.
scores = scores.add(mask)
else:
scores.add_(mask)
return scores
def clone(self) -> BaseLogitsProcessor:
guide = copy.deepcopy(self._guide)
guide.reset()
return BaseLogitsProcessor(guide=guide,
eos_token_id=self._eos_token_id,
reasoner=self._reasoner)
class RegexLogitsProcessor(BaseLogitsProcessor):
@classmethod
@cache()
def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return RegexGuide.from_regex(regex_string, tokenizer)
global CACHE
if CACHE is None:
CACHE = get_cache()
vocabulary = get_vocabulary(tokenizer) # type: ignore[arg-type]
cache_key = f"{vocabulary._hash}_{regex_string}"
if CACHE is not None and cache_key in CACHE:
return Guide(CACHE[cache_key])
def __init__(
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.
index = Index(regex_string, vocabulary.inner)
Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer
if CACHE is not None:
CACHE[cache_key] = index
"""
return Guide(index)
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]) -> None:
super().__init__(
RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer),
eos_token_id=tokenizer.eos_token_id, # type: ignore
reasoner=reasoner)
class JSONLogitsProcessor(RegexLogitsProcessor):
@ -170,22 +126,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.
reasoner: Optional[ReasoningParser]) -> None:
Parameters
----------
schema
A JSON schema that encodes the structure we want the model to
generate
tokenizer
The model's tokenizer
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact
string literals)
Example: allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
"""
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, dict):
@ -197,63 +139,42 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
f"Cannot parse schema {schema}. The schema must be either "
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer, reasoner)
class CFGLogitsProcessor(BaseLogitsProcessor):
@classmethod
@cache()
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.
Parameters
----------
cfg
A string that represents a context-free grammar
tokenizer
The model's tokenizer
"""
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
reasoner)
self._guide = self._guide.copy()
def clone(self) -> "CFGLogitsProcessor":
cloned = copy.copy(self)
cloned._fsm_state = copy.deepcopy(self._fsm_state)
cloned._guide = self._guide.copy()
return cloned
@lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
class OutlinesVocabulary:
"""
Wrapper class for `outlines_core.Vocabulary`,
which allows us to store a hash with the vocabulary
"""
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer
tokenizer = copy.deepcopy(tokenizer)
def __init__(self, vocabulary: Vocabulary) -> None:
# Actual vocabulary object
self.inner = vocabulary
# Derive a stable integer hash from the vocabulary's repr via SHA-256,
# since Python's built-in hash() is salted per process and cannot be
# used as a persistent cache key.
hex_str = hashlib.sha256(
vocabulary.__repr__().encode('utf-8')).hexdigest()
hash_int = int(hex_str, 16)
self._hash = hash_int
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^.{0,6}<7D>+.{0,6}$")
def _reduced_vocabulary(tokenizer: AnyTokenizer,
eos_token_id: int) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.
Returns:
A Dict of token string -> equivalent token ids
"""
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
@ -264,21 +185,123 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
return string
def change_decoder(
decoder: Callable[[list[int]],
str]) -> Callable[[list[int]], list[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""
vocabulary: dict[bytes, list[int]] = {}
empty_token_ids: list[int] = []
for token, token_idx in tokenizer.get_vocab().items():
if token in tokenizer.all_special_tokens: # type: ignore
continue
def new_decoder(inp_tokens: list[int]) -> list[str]:
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
and isinstance(inp_tokens[0], list)):
inp_tokens = inp_tokens[0]
return [decoder(inp_tokens)]
token_str = convert_token_to_string(token)
if token_str:
if isinstance(token, (bytes, bytearray)):
# For BPE tokenizers where tokens are stored as bytes.
return new_decoder
# safe to ignore since token_str is of type (bytearray, bytes)
# by this point.
token_bytes = bytes(token_str) # type: ignore[arg-type]
tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
elif "\ufffd" in token_str and not re_replacement_seq.match(
token_str):
# Handle tokens with invalid UTF-8 sequences.
if re_llama_byte_token.match(token):
# Llama-like tokenizers use <0xXX> for incomplete sequences.
token_bytes = bytes([int(token[3:5], 16)])
else:
# GPT2 tokenizers: map each byte back using unicode_to_bytes
byte_vals = [unicode_to_bytes.get(c) for c in token]
if None in byte_vals:
raise RuntimeError(
f"Cannot convert token `{token}`"
f" ({token_idx}) to bytes: {token_str}")
# safe to ignore, since if None in byte_vals,
# an error is thrown.
token_bytes = bytes(byte_vals) # type: ignore[arg-type]
else:
token_bytes = token_str.encode('utf-8')
return tokenizer
if token_idx != eos_token_id:
vocabulary.setdefault(token_bytes, []).append(token_idx)
else:
empty_token_ids.append(token_idx)
return vocabulary
def get_vocabulary(tokenizer: AnyTokenizer) -> Vocabulary:
"""Get the `Vocabulary` object for a given tokenizer.
"""
if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary # type: ignore
try:
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
eos_token_id = tokenizer.eos_token_id
else:
raise ValueError(
f"Error during guided decoding setup: Tokenizer"
f" ({type(tokenizer)}) has no `eos_token_id` property, "
"but `eos_token_id` is required for guided decoding"
" to work properly.")
reduced_vocab = _reduced_vocabulary(
tokenizer,
eos_token_id #type: ignore
)
vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id,
reduced_vocab))
tokenizer._outlines_vocabulary = vocabulary # type: ignore
return vocabulary
except AttributeError as e:
raise ValueError(f"Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method.") from e
def get_cache_path() -> str:
"""Get the context object that contains previously-computed return values"""
outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR")
xdg_cache_home = os.getenv("XDG_CACHE_HOME")
home_dir = os.path.expanduser("~")
if outlines_cache_dir:
# OUTLINES_CACHE_DIR takes precedence
return outlines_cache_dir
elif xdg_cache_home:
return os.path.join(xdg_cache_home, ".cache", "outlines")
# If the home directory is "/", we may be inside a container, and thus
# writing to root would be problematic, so we fall back to using a
# temporary directory. Also validate that the path exists, since
# os.path.expanduser does not guarantee existence.
elif os.path.isdir(home_dir) and home_dir != "/":
# Default Unix fallback: ~/.cache/outlines
return os.path.join(home_dir, ".cache", "outlines")
else:
import tempfile
# home_dir may be / inside a docker container without existing user
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, ".cache", "outlines")
def get_cache():
"""Get the Cache instance to be used for index caching"""
cache_dir = get_cache_path()
if envs.VLLM_V0_USE_OUTLINES_CACHE:
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
"cache. It may consume a lot of disk space and should "
"not be used with untrusted clients.")
cache = Cache(cache_dir, eviction_policy="none", cull_limit=0)
outlines_version = importlib.metadata.version("outlines_core")
cached_version = cache.get('__version__', None)
if cached_version != outlines_version:
cache.clear()
cache.set('__version__', outlines_version)
return cache
else:
return LRUCache(maxsize=128)
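
The rewritten V0 processor no longer builds a dense list of allowed token ids; the Guide writes a packed bitmask (one int32 word per 32 token ids) which the outlines_core Torch kernel applies to the logits in place. The snippet below is a plain-PyTorch illustration of that packing under the assumption of LSB-first bit order; the vocabulary size and token ids are made up. A word of -1 (all bits set) marks all 32 ids as allowed, which is why the V1 backend further down allocates its bitmask filled with -1.

import torch

vocab_size = 70                              # illustrative only
num_words = (vocab_size + 31) // 32          # one int32 word per 32 token ids
bitmask = torch.zeros(num_words, dtype=torch.int32)

# Mark token ids 3 and 40 as allowed (LSB-first packing assumed).
for tok in (3, 40):
    word, bit = divmod(tok, 32)
    bitmask[word] |= 1 << bit

# Expand the packed words into one boolean per token id and mask the logits,
# mirroring what the in-place kernel does directly on `scores`.
bits = (bitmask.unsqueeze(-1) >> torch.arange(32, dtype=torch.int32)) & 1
allowed = bits.flatten().bool()[:vocab_size]
scores = torch.randn(vocab_size)
scores[~allowed] = float("-inf")
print(int(scores.argmax()))                  # prints 3 or 40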

View File

@ -23,6 +23,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
@ -193,6 +195,9 @@ class Processor:
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
elif engine_level_backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.

View File

@ -88,6 +88,15 @@ class StructuredOutputManager:
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
elif backend == "outlines":
from vllm.v1.structured_output.backend_outlines import (
OutlinesBackend)
self.backend = OutlinesBackend(
self.vllm_config,
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
else:
raise ValueError(
f"Unsupported structured output backend: {backend}")

View File

@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import ast
import importlib
import json
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from regex import escape as regex_escape
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
OutlinesVocabulary, get_cache, get_vocabulary)
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
if TYPE_CHECKING:
import outlines_core as oc
import outlines_core.json_schema as json_schema
else:
oc = LazyLoader("oc", globals(), "outlines_core")
json_schema = LazyLoader("json_schema", globals(),
"outlines_core.json_schema")
# In Python 3.11+, sre_parse and sre_constants are deprecated,
# so we must import them through the re module.
if sys.version_info >= (3, 11):
# Hack to get around pre-commit regex module rule
# because going through re is the only way to get sre_parse
# and sre_constants in Python 3.11+
_re = importlib.import_module("re")
sre_parse = _re._parser
sre_constants = _re._constants
else:
import sre_constants
import sre_parse
@dataclass
class OutlinesBackend(StructuredOutputBackend):
def __post_init__(self):
self.vocabulary = get_vocabulary(self.tokenizer)
self.cache = get_cache()
def _compile_index(self, regex_string: str,
vocabulary: OutlinesVocabulary) -> oc.Index:
cache_key = f"{vocabulary._hash}_{regex_string}"
if cache_key in self.cache:
return self.cache[cache_key]
index = oc.Index(regex_string, vocabulary.inner)
self.cache[cache_key] = index
return index
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
if request_type == StructuredOutputOptions.JSON:
regex = json_schema.build_regex_from_schema(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX:
regex = grammar_spec
elif request_type == StructuredOutputOptions.CHOICE:
choices = ast.literal_eval(grammar_spec)
choices = [regex_escape(c) for c in choices]
regex = "(" + "|".join(choices) + ")"
else:
raise ValueError(
f"Invalid request type for Outlines backend ({request_type!s})"
)
index = self._compile_index(regex, self.vocabulary)
max_rollback_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config is not None else 0)
return OutlinesGrammar(vocab_size=self.vocab_size,
guide=oc.Guide(
index, max_rollback=max_rollback_tokens))
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
return torch.full(
(max_num_seqs, (self.vocab_size + 31) // 32),
-1,
dtype=torch.int32,
pin_memory=torch.cuda.is_available(),
)
def destroy(self):
pass
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
vocab_size: int
guide: oc.Guide = field(hash=False)
num_processed_tokens: int = field(default_factory=lambda: 0,
repr=False,
hash=False,
init=False)
# outlines_core signals done on DFA accept; vLLM expects done after EOS.
# We delay the finished flag by one step so EOS can still be emitted.
_prev_finished: bool = field(default=False,
init=False,
repr=False,
hash=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
if self.guide.accepts_tokens(tokens):
# Advance cannot fail because we checked Guide.accepts_tokens()
for t in tokens:
self.guide.advance(t)
self.num_processed_tokens += 1
return True
return False
def rollback(self, num_tokens: int) -> None:
self.guide.rollback_state(num_tokens)
self.num_processed_tokens -= num_tokens
def validate_tokens(self, tokens: list[int]) -> list[int]:
accepted: list[int] = []
for tok in tokens:
accepted.append(tok)
if not self.guide.accepts_tokens(accepted):
accepted.pop()
break
return accepted
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
mask = bitmask[idx]
self.guide.write_mask_into(mask.data_ptr(), mask.numel(),
mask.element_size())
def is_terminated(self) -> bool:
curr = self.guide.is_finished()
prev = self._prev_finished
self._prev_finished = curr
return prev
def reset(self):
self.num_processed_tokens = 0
self._prev_finished = False
self.guide.reset()
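
The _prev_finished field above implements a one-step delay: outlines_core reports the guide as finished as soon as the DFA accepts, but vLLM only treats a request as done after EOS has been emitted. A tiny stub (not outlines_core) of that behaviour:

class _DelayedFinish:
    """Report "terminated" one call after the guide first finishes."""

    def __init__(self) -> None:
        self._prev_finished = False

    def is_terminated(self, guide_finished: bool) -> bool:
        prev = self._prev_finished
        self._prev_finished = guide_finished
        return prev

d = _DelayedFinish()
print(d.is_terminated(False))  # False: guide still running
print(d.is_terminated(True))   # False: guide just accepted, EOS can still go out
print(d.is_terminated(True))   # True: reported terminated one step later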
def validate_structured_output_request_outlines(params: SamplingParams):
if params.guided_decoding is None:
return
gd_params = params.guided_decoding
if gd_params.regex:
validate_regex_is_buildable(gd_params.regex)
elif gd_params.json:
if isinstance(gd_params.json, str):
try:
# make sure schema is valid json
json.loads(gd_params.json)
schema = gd_params.json
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
try:
schema = json.dumps(gd_params.json)
except Exception as e:
raise ValueError(
f"Error serializing guided decoding jsonschema: {e}"
) from e
pattern = json_schema.build_regex_from_schema(schema)
validate_regex_is_buildable(pattern)
elif gd_params.choice:
choices = [regex_escape(str(choice)) for choice in gd_params.choice]
regex = "(" + "|".join(choices) + ")"
validate_regex_is_buildable(regex)
elif gd_params.grammar:
raise ValueError("Outlines guided decoding backend "
"does not support grammar specifications")
def _prefix_needs_context(parsed) -> bool:
"""Return True if there's a look-around/anchor before any consumer."""
def subpattern_consumes(parsed) -> bool:
"""Return True if subpattern can consume at least one character."""
tokens = parsed.data if hasattr(parsed, 'data') else parsed
for ttype, tval in tokens:
# literal, character class, or dot always consumes
if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
return True
# quantified subpattern: check inner pattern
elif ttype == sre_parse.MAX_REPEAT:
_, mx, sub = tval
if mx != 0 and subpattern_consumes(sub):
return True
# alternation: if any branch consumes, the whole does
elif ttype == sre_parse.BRANCH:
_, branches = tval
if any(subpattern_consumes(br) for br in branches):
return True
# grouped subpattern: recurse into its contents
elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(
tval[3]):
return True
# No consumers, return False
return False
tokens = parsed.data if hasattr(parsed, 'data') else parsed
for ttype, tval in tokens:
# Direct anchors or look-around
if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT,
sre_constants.ASSERT_NOT):
return True
# Nested subpattern: check
if ttype == sre_parse.SUBPATTERN:
# tval: (group, add_flags, del_flags, subpattern)
if _prefix_needs_context(tval[3]):
return True
if subpattern_consumes(tval[3]):
return False
# if any branch has a prefix anchor => True,
# else if at least one branch consumes => prefix ends => False
elif ttype == sre_parse.BRANCH:
saw_consumer = False
for br in tval[1]:
if _prefix_needs_context(br):
return True
if subpattern_consumes(br):
saw_consumer = True
if saw_consumer:
return False
# Immediate consumer tokens
elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
return False
# if subpattern has anchor => True, if it can consume => stop
elif ttype == sre_parse.MAX_REPEAT:
if _prefix_needs_context(tval[2]):
return True
if subpattern_consumes(tval[2]):
return False
return False
def _check_unsupported(parsed) -> None:
"""Check for regex features unsupported by regex-automata"""
tokens = parsed.data if hasattr(parsed, 'data') else parsed
for ttype, tval in tokens:
# backreference
if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
raise ValueError("Backreferences are unsupported.")
# look-around assertion
elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
raise ValueError("Look-Around assertion are unsupported.")
# unicode word boundaries
elif ttype == sre_parse.AT:
if tval in (sre_constants.AT_BOUNDARY,
sre_constants.AT_NON_BOUNDARY):
raise ValueError("Unicode word boundaries are unsupported.")
elif ttype == sre_parse.BRANCH:
# tval is (None, branches)
for branch in tval[1]:
_check_unsupported(branch)
# tval is (min, max, subpattern)
elif ttype == sre_parse.MAX_REPEAT:
_check_unsupported(tval[2])
def validate_regex_is_buildable(pattern: str) -> None:
"""
Validates that the input regex is not using unsupported features
of the `regex-automata` crate (outlines_core regex engine) and has a
universal start state.
The definition of universal start state used here can be found at:
https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state
"""
try:
parsed = sre_parse.parse(pattern)
except sre_constants.error as e:
raise ValueError(f"Error parsing regex: {e}") from e
try:
_check_unsupported(parsed)
except ValueError as e:
raise ValueError(
f"Regex uses unsupported feature for guided decoding: {e}. "
"Only basic matching constructs are supported—lookarounds, "
"backreferences, and unicode boundaries are not.") from e
if _prefix_needs_context(parsed):
raise ValueError(
"Regex does not have a anchored universal start state"
"This means that the Regex uses anchors (^) or look-arounds "
"in a way which requires context before any token is matched."
"Guided decoding needs regexes that can match without needing "
"that context. Try rewriting the pattern without using these "
f"constructs. Pattern:\n{pattern}")