[V0][V1][Core] Add outlines integration for V1, and update V0 integration. (#15975)
Signed-off-by: Nathan Hoos <thwackyy.y@gmail.com>
This commit is contained in:
parent 5e53c89a74
commit d6902ce79f
@@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0  # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
outlines_core == 0.2.10
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10
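For orientation (not part of the diff): the pinned packages above expose the small outlines_core surface this commit builds on. A minimal, hedged sketch using only the constructors that appear later in this diff (`Vocabulary`, `Index`, `Guide`); the toy vocabulary is hypothetical, real code derives it from the model tokenizer:

```python
# Hedged sketch of the outlines_core 0.2 API used by this commit.
# The tiny vocabulary is made up; vLLM builds it from the tokenizer
# (see _reduced_vocabulary / get_vocabulary further down in this diff).
from outlines_core import Guide, Index, Vocabulary

eos_token_id = 3
toy_vocab = {b"0": [0], b"1": [1], b".": [2]}  # token bytes -> token ids
vocabulary = Vocabulary(eos_token_id, toy_vocab)

# Compiling the regex into an Index is the expensive step (and what the
# optional diskcache-backed cache stores); the Guide then walks the FSM.
index = Index(r"[01]\.[01]", vocabulary)
guide = Guide(index)
guide.advance(token_id=0, return_tokens=False)  # accept one generated token
```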
@@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [

# Separate backends which support grammars vs ones
# which only support regex based constraints in tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]

ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)


@pytest.fixture(scope="module")
def llm():
@@ -39,7 +43,7 @@ def llm():

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
@@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))

outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
@@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
@@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,

# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# A list is not what was intended, but is still valid
# json.
assert isinstance(parsed_json, (dict, list))


class CarType(str, Enum):
@@ -395,7 +402,7 @@ class CarDescription(BaseModel):

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
@@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sample_output_schema = {
@@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
whitespace_pattern=None,
reasoner=None)

token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
tensor = regex_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor)
tensor = json_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

@@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

regex_lp = get_local_guided_decoding_logits_processor(
@@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
# allowed tokens at state 0
tensor = regex_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
@@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
tensor = json_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

@@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning(
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

@@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning(
regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend)
assert regex_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)

token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
@@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
@@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning(

# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process</think> Then")
"<think>here is the thinking process</think>")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
@@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning(
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
@@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
assert isinstance(schema, dict)

# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
@@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
NGRAM_SPEC_CONFIG),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
@@ -106,13 +110,15 @@ def test_structured_output(
enforce_eager = bool(not current_platform.is_tpu())
# Use a single LLM instance for several scenarios to
# speed up the test suite.
llm = LLM(model=model_name,
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
llm = LLM(
model=model_name,
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)

#
# Test 1: Generate JSON output based on a provided schema
@@ -146,32 +152,33 @@ def test_structured_output(
#
# Test 2: Generate JSON object without a schema
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))
if guided_decoding_backend != "outlines":
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=4096,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))

outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old. "
"Make the response as short as possible."),
sampling_params=sampling_params,
use_tqdm=True)
outputs = llm.generate(prompts=(
"Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old. "
"Make the response as short as possible."),
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)

for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None

# Parse to verify it is a valid JSON object
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# Parse to verify it is a valid JSON object
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)

#
# Test 3: test a jsonschema incompatible with xgrammar
@@ -210,97 +217,98 @@ def test_structured_output(
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)

#
# Test 4: Generate SQL statement using EBNF grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None

# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")

assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

#
# Test 5: Generate SQL statement using Lark grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None

# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_lark)
parser.parse(generated_text)

# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")

assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

#
# Test 6: Test invalid grammar input
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
if guided_decoding_backend != "outlines":
#
# Test 4: Generate SQL statement using EBNF grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short "
"as possible."),
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None

# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")

assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

#
# Test 5: Generate SQL statement using Lark grammar
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None

# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_lark)
parser.parse(generated_text)

# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")

assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

#
# Test 6: Test invalid grammar input
#
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
prompts=
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short "
"as possible."),
sampling_params=sampling_params,
use_tqdm=True,
)

#
# Test 7: Generate text based on a regex pattern
#
@@ -421,35 +429,36 @@ def test_structured_output(
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)

#
# Test 11: Generate structured output using structural_tag format
#
structural_tag_config = {
"type":
"structural_tag",
"structures": [{
"begin": "<function=get_weather>",
"schema": {
"type": "object",
"properties": {
"city": {
"type": "string"
}
if guided_decoding_backend != "outlines":
#
# Test 11: Generate structured output using structural_tag format
#
structural_tag_config = {
"type":
"structural_tag",
"structures": [{
"begin": "<function=get_weather>",
"schema": {
"type": "object",
"properties": {
"city": {
"type": "string"
}
},
"additionalProperties": False
},
"additionalProperties": False
},
"end": "</function>"
}],
"triggers": ["<function="]
}
"end": "</function>"
}],
"triggers": ["<function="]
}

sampling_params = SamplingParams(
temperature=0.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(
structural_tag=json.dumps(structural_tag_config)))
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(
structural_tag=json.dumps(structural_tag_config)))

prompt = """
prompt = """
You have access to the following function to retrieve the weather in a city:

{
@@ -469,7 +478,7 @@ where

start_tag => `<function`
parameters => a JSON dict with the function argument name
as key and function argument value as value.
as key and function argument value as value.
end_tag => `</function>`

Here is an example,
@@ -488,37 +497,37 @@ Given the previous instructions, what is the weather in New York City? \
Make the response as short as possible.
"""

# Change this once other backends support structural_tag
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
# Change this once other backends support structural_tag
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None

for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None

# Search for function call pattern in the response
function_call_pattern = r'<function=get_weather>(.*?)</function>'
matches = re.findall(function_call_pattern, generated_text)
# Search for function call pattern in the response
function_call_pattern = r'<function=get_weather>(.*?)</function>'
matches = re.findall(function_call_pattern, generated_text)

if not matches:
print(f"Warning: No function calls found in response: "
f"{generated_text!r}")
continue
if not matches:
print(f"Warning: No function calls found in response: "
f"{generated_text!r}")
continue

# Take the first function call if multiple are found
json_str = matches[0]
try:
json_content = json.loads(json_str)
assert "city" in json_content
assert isinstance(json_content["city"], str)
print(f"Found valid function call: {generated_text!r}")
except (json.JSONDecodeError, AssertionError) as e:
pytest.fail("Invalid function call format: "
f"{generated_text!r}\nError: {str(e)}")
# Take the first function call if multiple are found
json_str = matches[0]
try:
json_content = json.loads(json_str)
assert "city" in json_content
assert isinstance(json_content["city"], str)
print(f"Found valid function call: {generated_text!r}")
except (json.JSONDecodeError, AssertionError) as e:
pytest.fail("Invalid function call format: "
f"{generated_text!r}\nError: {str(e)}")


@pytest.mark.skip_global_cleanup
@@ -3580,7 +3580,8 @@ def get_served_model_name(model: str,

GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]
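For illustration only (not part of the diff): with "outlines" added to `GuidedDecodingBackendV1`, a V1 engine can be pointed at the new backend directly. The model name and regex below are arbitrary examples:

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="outlines")
params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(regex=r"\d{1,3}(\.\d{1,3}){3}"))
outputs = llm.generate("Give an example IPv4 address: ", params)
```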
@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False
@@ -847,6 +848,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",

# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V1_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1",

# Gap between padding buckets for the forward pass. So we have
# 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
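Usage sketch (not part of the diff): the flag is read from the process environment, so it should be exported before vLLM starts, and only in deployments where every client is trusted:

```python
# Hypothetical opt-in to the unbounded on-disk outlines cache for V1.
import os
os.environ["VLLM_V1_USE_OUTLINES_CACHE"] = "1"  # set before starting vLLM

import vllm.envs as envs
print(envs.VLLM_V1_USE_OUTLINES_CACHE)  # True
```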
@@ -79,20 +79,33 @@ def maybe_backend_fallback(
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")
"grammar failed to convert to GBNF.", "guidance")

# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")
"xgrammar module cannot be imported successfully.", "guidance")

if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.", "guidance")
if guided_params.backend == "outlines":
if guided_params.json_object is not None:
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.",
"guidance")
elif guided_params.grammar is not None:
# outlines grammar support has been removed, fallback to guidance
# if it is a lark-based grammar and xgrammar otherwise
if grammar_is_likely_lark(guided_params.grammar):
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"guidance")
else:
# The grammar is likely already GBNF format.
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"xgrammar")

return guided_params
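A hedged illustration of the new fallback rules (the request objects below are hypothetical; only `maybe_backend_fallback` and `GuidedDecodingParams` come from this file):

```python
from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# outlines does not support json_object -> rerouted to guidance.
params = maybe_backend_fallback(
    GuidedDecodingParams(json_object=True, backend="outlines"))

# outlines no longer supports grammars: a Lark-looking grammar is rerouted
# to guidance, while a GBNF-looking grammar would go to xgrammar instead.
lark_params = maybe_backend_fallback(
    GuidedDecodingParams(grammar='start: "a"', backend="outlines"))
```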
@@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor(

guided_params = maybe_backend_fallback(guided_params)

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
@@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend)
reasoner = reasoner_class(tokenizer)

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
@@ -12,7 +12,7 @@ from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams

@@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
GRAMMAR = "grammar"


# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# the main difference is that we changed the start: value to
# start: object | array, so we are denying scalar values as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null

array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""

global_thread_pool = None  # used for generating logits processor fsm

# It's not yet clear that using more provides a benefit, and it could
@@ -60,16 +32,12 @@ _MAX_THREADPOOL_WORKERS = 16


async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(guided_params)
@@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor(
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers)
loop = asyncio.get_running_loop()

return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, guided_params.whitespace_pattern,
@@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor(


def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
@@ -130,9 +93,10 @@ def _get_guide_and_mode(
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif guided_params.grammar:
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
elif guided_params.json_object:
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
raise ValueError(
"The `outlines` guided decoding backend no longer supports grammar "
"guided generation. Please use either the `xgrammar` or `guidance` "
"backend")
else:
return None, None

@@ -143,13 +107,11 @@ def _get_logits_processor(
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer, reasoner)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer, reasoner)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
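For context (not from the diff): a choice request still becomes a plain regex before being compiled into a Guide, roughly like this, with `regex_escape` being the `regex.escape` imported above:

```python
from regex import escape as regex_escape

choices = ["Positive", "Negative"]
choices_regex = "(" + "|".join(regex_escape(c) for c in choices) + ")"
print(choices_regex)  # (Positive|Negative)
```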
@@ -1,168 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright 2024-present the Outlines developers
from __future__ import annotations

# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import hashlib
import importlib.metadata
import json
from collections import defaultdict
from functools import lru_cache
from typing import Callable, Optional, Union
import os
from typing import Optional, Union

import numpy as np
import regex as re
import torch
from outlines import grammars
from outlines.caching import cache, disable_cache
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
RegexGuide, Write)
from outlines.fsm.parsing import PartialLark
from outlines_core.fsm.json_schema import build_regex_from_schema
from cachetools import LRUCache
from diskcache import Cache
from outlines_core import Guide, Index, Vocabulary
from outlines_core.json_schema import build_regex_from_schema
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
allocate_token_bitmask)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from transformers.file_utils import SPIECE_UNDERLINE
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)

if envs.VLLM_V0_USE_OUTLINES_CACHE:
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
"cache. It may consume a lot of disk space and should "
"not be used with untrusted clients.")
else:
disable_cache()
CACHE = None


class BaseLogitsProcessor:

def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
def __init__(self, guide: Guide, eos_token_id: int,
reasoner: Optional[ReasoningParser]) -> None:
self._guide: Guide = guide
self._eos_token_id: int = eos_token_id
self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: defaultdict[int, Union[int,
CFGState]] = defaultdict(int)

def clone(self) -> "BaseLogitsProcessor":
cloned = copy.copy(self)
cloned._guide = self._guide.copy()
cloned._fsm_state = copy.deepcopy(self._fsm_state)
return cloned
self._mask: Optional[torch.Tensor] = None

def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
if self._mask is None:
self._mask = allocate_token_bitmask(scores.size(-1))

# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--reasoning-parser` is set.
if self._reasoner is not None:
if not self._reasoner.is_reasoning_end(input_ids):
return scores
else:
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content_ids(input_ids)
if self._reasoner is not None and not self._reasoner.is_reasoning_end(
input_ids):
return scores

seq_id = hash(tuple(input_ids))
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# input_ids sequence to store the FSM state.
input_ids = (self._reasoner.extract_content_ids(input_ids)
if self._reasoner is not None else input_ids)

if len(input_ids) > 0:
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self._fsm_state[seq_id] = self._guide.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token)
else:
# Note: this is a hack.
# Lark pickling does not work properly (silent failure),
# which breaks the RPC (which uses Python pickling).
# We need to find a better solution.
# On the first time this is called, we simply re-create
# the Lark object.
if isinstance(self._guide, CFGGuide):
self._guide.parser = PartialLark(
self._guide.cfg_string,
parser="lalr",
import_paths=[grammars.GRAMMAR_PATH],
)
self._fsm_state[seq_id] = CFGState(
parser_state=self._guide.parser.parse(""), prev_token=None)
# The vLLM V0 engine has a weird bug where we have to repeat
# the eos token id twice for generation to stop, or at least
# that is what we have to do from here in any case.
# This is a patch until a better solution can be pushed
# to outlines_core
if input_ids and input_ids[-1] != self._eos_token_id:
self._guide.advance(token_id=input_ids[-1], return_tokens=False)

instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
self._guide.write_mask_into(
data_ptr=self._mask.data_ptr(),
numel=self._mask.numel(),
element_size=self._mask.element_size(),
)

if type(instruction) == Generate:  # noqa: E721
allowed_tokens = instruction.tokens
elif type(instruction) == Write:  # noqa: E721
# TODO: support fast forward tokens
allowed_tokens = [instruction.tokens[0]]
else:
raise TypeError(
f"Unsupported instruction type {type(instruction)}")
# Any allowed tokens beyond the length of the scores will
# be ignored by the kernel, taking care of the issue with
# models such as Llama 3.2 Vision with an `<|image|>` token
# with id 128256, but scores.shape == torch.Size([128256])
_apply_token_bitmask_inplace_kernel(
logits=scores.unsqueeze(dim=0),
# mask must be on same device
mask=self._mask.to(scores.device, non_blocking=True))
self._mask.to("cpu", non_blocking=True)

mask = torch.full((scores.shape[-1], ),
-torch.inf,
device=scores.device)
# The tokenizer may support more token ids than the model can generate,
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
# but scores.shape == torch.Size([128256])
# Using NumPy is faster for filtering token ids
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)
if current_platform.is_hpu():
# Workaround for HPU bug where add_() raise RuntimeError:
# synNodeCreateWithId failed for node: strided_insert
# with synStatus 1 [Invalid argument], hopefully it will
# be fixed in the future releases of the HPU runtime.
scores = scores.add(mask)
else:
scores.add_(mask)
return scores

def clone(self) -> BaseLogitsProcessor:
guide = copy.deepcopy(self._guide)
guide.reset()
return BaseLogitsProcessor(guide=guide,
eos_token_id=self._eos_token_id,
reasoner=self._reasoner)
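A hedged sketch of the new masking path in `__call__` above, using only the outlines_core helpers imported in this file; the vocabulary size and the all-ones mask are placeholders for what `guide.write_mask_into(...)` would produce from the real FSM state:

```python
import torch
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
                                         allocate_token_bitmask)

vocab_size = 32000
scores = torch.rand(vocab_size)            # one row of next-token logits
mask = allocate_token_bitmask(vocab_size)  # packed bitmask, one bit per token

# Stand-in for guide.write_mask_into(...): set every bit, i.e. allow all tokens.
mask.fill_(-1)

# Disallowed tokens (bit == 0) are driven to -inf in place; bits beyond the
# logits length are ignored by the kernel.
_apply_token_bitmask_inplace_kernel(
    logits=scores.unsqueeze(dim=0),
    mask=mask.to(scores.device, non_blocking=True))
```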
class RegexLogitsProcessor(BaseLogitsProcessor):

@classmethod
@cache()
def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return RegexGuide.from_regex(regex_string, tokenizer)
global CACHE
if CACHE is None:
CACHE = get_cache()
vocabulary = get_vocabulary(tokenizer)  # type: ignore[arg-type]
cache_key = f"{vocabulary._hash}_{regex_string}"
if CACHE is not None and cache_key in CACHE:
return Guide(CACHE[cache_key])

def __init__(
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.
index = Index(regex_string, vocabulary.inner)

Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer
if CACHE is not None:
CACHE[cache_key] = index

"""
return Guide(index)

def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]) -> None:
super().__init__(
RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer),
eos_token_id=tokenizer.eos_token_id,  # type: ignore
reasoner=reasoner)


class JSONLogitsProcessor(RegexLogitsProcessor):
@@ -170,22 +126,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.
reasoner: Optional[ReasoningParser]) -> None:

Parameters
----------
schema
A JSON schema that encodes the structure we want the model to
generate
tokenizer
The model's tokenizer
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact
string literals)
Example: allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
"""
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, dict):
@@ -197,63 +139,42 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
f"Cannot parse schema {schema}. The schema must be either "
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")

regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer, reasoner)


class CFGLogitsProcessor(BaseLogitsProcessor):

@classmethod
@cache()
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer)

def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.

Parameters
----------
cfg
A string that represents a context-free grammar
tokenizer
The model's tokenizer

"""
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
reasoner)
self._guide = self._guide.copy()

def clone(self) -> "CFGLogitsProcessor":
cloned = copy.copy(self)
cloned._fsm_state = copy.deepcopy(self._fsm_state)
cloned._guide = self._guide.copy()
return cloned
@lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.

The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.

class OutlinesVocabulary:
"""
Wrapper class for `outlines_core.Vocabulary`,
which allows us to store a hash with the vocabulary
"""
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer

tokenizer = copy.deepcopy(tokenizer)
def __init__(self, vocabulary: Vocabulary) -> None:
# Actual vocabulary object
self.inner = vocabulary
# Have to do abs(hash()) because python hashes can
# be negative, and we are using hash as a cache key.
hex_str = hashlib.sha256(
vocabulary.__repr__().encode('utf-8')).hexdigest()
hash_int = int(hex_str, 16)
self._hash = hash_int

tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^.{0,6}\ufffd+.{0,6}$")


def _reduced_vocabulary(tokenizer: AnyTokenizer,
eos_token_id: int) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.

Returns:
A Dict of token string -> equivalent token ids
"""
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

@@ -264,21 +185,123 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):

return string

def change_decoder(
decoder: Callable[[list[int]],
str]) -> Callable[[list[int]], list[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""
vocabulary: dict[bytes, list[int]] = {}
empty_token_ids: list[int] = []
for token, token_idx in tokenizer.get_vocab().items():
if token in tokenizer.all_special_tokens:  # type: ignore
continue

def new_decoder(inp_tokens: list[int]) -> list[str]:
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
and isinstance(inp_tokens[0], list)):
inp_tokens = inp_tokens[0]
return [decoder(inp_tokens)]
token_str = convert_token_to_string(token)
if token_str:
if isinstance(token, (bytes, bytearray)):
# For BPE tokenizers where tokens are stored as bytes.

return new_decoder
# safe to ignore since token_str is of type (bytearray, bytes)
# by this point.
token_bytes = bytes(token_str)  # type: ignore[arg-type]

tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
elif "\ufffd" in token_str and not re_replacement_seq.match(
token_str):
# Handle tokens with invalid UTF-8 sequences.
if re_llama_byte_token.match(token):
# Llama-like tokenizers use <0xXX> for incomplete sequences.
token_bytes = bytes([int(token[3:5], 16)])
else:
# GPT2 tokenizers: map each byte back using unicode_to_bytes
byte_vals = [unicode_to_bytes.get(c) for c in token]
if None in byte_vals:
raise RuntimeError(
f"Cannot convert token `{token}`"
f" ({token_idx}) to bytes: {token_str}")
# safe to ignore, since if None in byte_vals,
# an error is thrown.
token_bytes = bytes(byte_vals)  # type: ignore[arg-type]
else:
token_bytes = token_str.encode('utf-8')

return tokenizer
if token_idx != eos_token_id:
vocabulary.setdefault(token_bytes, []).append(token_idx)
else:
empty_token_ids.append(token_idx)

return vocabulary
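A small illustration (not part of the diff) of the two byte-recovery paths `_reduced_vocabulary` takes for tokens whose string form contains the U+FFFD replacement character:

```python
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}

# Llama-style byte token: the byte value is encoded in the token text itself.
token = "<0x0A>"
print(bytes([int(token[3:5], 16)]))                # b'\n'

# GPT-2-style token: map each character back to its raw byte.
token = "Ġhello"                                   # 'Ġ' encodes a leading space
print(bytes(unicode_to_bytes[c] for c in token))   # b' hello'
```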
def get_vocabulary(tokenizer: AnyTokenizer) -> Vocabulary:
"""Get the `Vocabulary` object for a given tokenizer.
"""
if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary  # type: ignore

try:
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
eos_token_id = tokenizer.eos_token_id
else:
raise ValueError(
f"Error during guided decoding setup: Tokenizer"
f" ({type(tokenizer)}) has no `eos_token_id` property, "
"but `eos_token_id` is required for guided decoding"
" to work properly.")

reduced_vocab = _reduced_vocabulary(
tokenizer,
eos_token_id  #type: ignore
)
vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id,
reduced_vocab))
tokenizer._outlines_vocabulary = vocabulary  # type: ignore

return vocabulary
except AttributeError as e:
raise ValueError(f"Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method.") from e


def get_cache_path() -> str:
"""Get the directory to use for the on-disk outlines cache"""
outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR")
xdg_cache_home = os.getenv("XDG_CACHE_HOME")
home_dir = os.path.expanduser("~")

if outlines_cache_dir:
# OUTLINES_CACHE_DIR takes precedence
return outlines_cache_dir
elif xdg_cache_home:
return os.path.join(xdg_cache_home, ".cache", "outlines")
# If homedir is "/", we may be inside a container, and thus writing to
# root would be problematic, so we fallback to using a tempfile.
# Also validate the path exists, since os.path.expanduser does
# not guarantee existence.
elif os.path.isdir(home_dir) and home_dir != "/":
# Default Unix fallback: ~/.cache/outlines
return os.path.join(home_dir, ".cache", "outlines")
else:
import tempfile

# home_dir may be / inside a docker container without existing user
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, ".cache", "outlines")


def get_cache():
"""Get the Cache instance to be used for index caching"""

cache_dir = get_cache_path()
if envs.VLLM_V0_USE_OUTLINES_CACHE:
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
"cache. It may consume a lot of disk space and should "
"not be used with untrusted clients.")
cache = Cache(cache_dir, eviction_policy="none", cull_limit=0)
outlines_version = importlib.metadata.version("outlines_core")

cached_version = cache.get('__version__', None)
if cached_version != outlines_version:
cache.clear()
cache.set('__version__', outlines_version)
return cache
else:
return LRUCache(maxsize=128)
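Behavioural sketch (illustrative, not from the diff): compiled indexes are keyed by the vocabulary hash plus the regex, and the disk cache is dropped whenever the installed outlines_core version changes; the directory and hash below are placeholders:

```python
import importlib.metadata
from diskcache import Cache

cache = Cache("/tmp/outlines-demo", eviction_policy="none", cull_limit=0)
version = importlib.metadata.version("outlines_core")
if cache.get("__version__") != version:
    cache.clear()
    cache.set("__version__", version)

vocab_hash = 123456789          # stands in for OutlinesVocabulary._hash
cache_key = f"{vocab_hash}_[0-9]+"
print(cache_key in cache)       # False until an Index is stored under this key
```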
@@ -23,6 +23,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)

@@ -193,6 +195,9 @@ class Processor:
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
elif engine_level_backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.
@ -88,6 +88,15 @@ class StructuredOutputManager:
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "outlines":
                from vllm.v1.structured_output.backend_outlines import (
                    OutlinesBackend)

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")

vllm/v1/structured_output/backend_outlines.py (new file, 319 lines)
@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import ast
import importlib
import json
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import torch
from regex import escape as regex_escape

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    OutlinesVocabulary, get_cache, get_vocabulary)
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)

if TYPE_CHECKING:
    import outlines_core as oc
    import outlines_core.json_schema as json_schema
else:
    oc = LazyLoader("oc", globals(), "outlines_core")
    json_schema = LazyLoader("json_schema", globals(),
                             "outlines_core.json_schema")

# In Python 3.11+, sre_parse and sre_constants are deprecated,
# so we must import them from re
if sys.version_info >= (3, 11):
    # Hack to get around the pre-commit regex module rule,
    # because going through re is the only way to get sre_parse
    # and sre_constants in Python 3.11+
    _re = importlib.import_module("re")
    sre_parse = _re._parser
    sre_constants = _re._constants
else:
    import sre_constants
    import sre_parse

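For readers unfamiliar with these private modules, a purely illustrative sketch of what the parser produces (the pattern is invented; the validators later in this file walk these tuples recursively):

tokens = sre_parse.parse(r"(a|b)\d+")
for ttype, tval in tokens:
    # Each entry is an (opcode, argument) pair, e.g. a SUBPATTERN holding the
    # alternation followed by a MAX_REPEAT wrapping the digit class.
    print(ttype, tval)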
@dataclass
class OutlinesBackend(StructuredOutputBackend):

    def __post_init__(self):
        self.vocabulary = get_vocabulary(self.tokenizer)
        self.cache = get_cache()

    def _compile_index(self, regex_string: str,
                       vocabulary: OutlinesVocabulary) -> oc.Index:
        cache_key = f"{vocabulary._hash}_{regex_string}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        index = oc.Index(regex_string, vocabulary.inner)
        self.cache[cache_key] = index

        return index

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        if request_type == StructuredOutputOptions.JSON:
            regex = json_schema.build_regex_from_schema(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            regex = grammar_spec
        elif request_type == StructuredOutputOptions.CHOICE:
            choices = ast.literal_eval(grammar_spec)
            choices = [regex_escape(c) for c in choices]
            regex = "(" + "|".join(choices) + ")"
        else:
            raise ValueError(
                f"Invalid request type for Outlines backend ({request_type!s})"
            )
        index = self._compile_index(regex, self.vocabulary)
        max_rollback_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config is not None else 0)
        return OutlinesGrammar(vocab_size=self.vocab_size,
                               guide=oc.Guide(
                                   index, max_rollback=max_rollback_tokens))

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        return torch.full(
            (max_num_seqs, (self.vocab_size + 31) // 32),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        pass

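The CHOICE branch above only builds an alternation regex from the string-serialized choice list. A minimal sketch of that transformation on its own; the choice values are invented:

import ast
from regex import escape as regex_escape

grammar_spec = "['yes', 'no', 'maybe?']"  # invented spec, as compile_grammar would receive it
choices = [regex_escape(c) for c in ast.literal_eval(grammar_spec)]
print("(" + "|".join(choices) + ")")  # -> (yes|no|maybe\?)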
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):

    vocab_size: int
    guide: oc.Guide = field(hash=False)
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    # outlines_core signals done on DFA accept; vLLM expects done after EOS.
    # We delay the finished flag by one step so EOS can still be emitted.
    _prev_finished: bool = field(default=False,
                                 init=False,
                                 repr=False,
                                 hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.
        """
        if self.guide.accepts_tokens(tokens):
            # Advance cannot fail because we checked Guide.accepts_tokens()
            for t in tokens:
                self.guide.advance(t)
                self.num_processed_tokens += 1
            return True
        return False

    def rollback(self, num_tokens: int) -> None:
        self.guide.rollback_state(num_tokens)
        self.num_processed_tokens -= num_tokens

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        accepted: list[int] = []
        for tok in tokens:
            accepted.append(tok)
            if not self.guide.accepts_tokens(accepted):
                accepted.pop()
                break
        return accepted

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        mask = bitmask[idx]
        self.guide.write_mask_into(mask.data_ptr(), mask.numel(),
                                   mask.element_size())

    def is_terminated(self) -> bool:
        curr = self.guide.is_finished()
        prev = self._prev_finished
        self._prev_finished = curr
        return prev

    def reset(self):
        self.num_processed_tokens = 0
        self._prev_finished = False
        self.guide.reset()

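The one-step delay behind `_prev_finished` can be illustrated without outlines_core at all. FakeGuide below is an invented stand-in, not part of the diff; DelayedDone just mirrors the is_terminated() logic above:

class FakeGuide:
    """Invented stand-in for oc.Guide: finishes as soon as it is advanced."""

    def __init__(self):
        self._finished = False

    def advance(self, _tok):
        self._finished = True

    def is_finished(self):
        return self._finished


class DelayedDone:
    """Mirrors OutlinesGrammar.is_terminated(): report done one step late."""

    def __init__(self, guide):
        self.guide = guide
        self._prev_finished = False

    def is_terminated(self):
        curr = self.guide.is_finished()
        prev = self._prev_finished
        self._prev_finished = curr
        return prev


g = DelayedDone(FakeGuide())
g.guide.advance(0)        # the DFA reaches an accept state here...
print(g.is_terminated())  # False: vLLM still gets a step to emit EOS
print(g.is_terminated())  # True: reported finished on the next step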
def validate_structured_output_request_outlines(params: SamplingParams):
    if params.guided_decoding is None:
        return

    gd_params = params.guided_decoding

    if gd_params.regex:
        validate_regex_is_buildable(gd_params.regex)
    elif gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                # make sure schema is valid json
                json.loads(gd_params.json)
                schema = gd_params.json
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            try:
                schema = json.dumps(gd_params.json)
            except Exception as e:
                raise ValueError(
                    f"Error serializing guided decoding jsonschema: {e}"
                ) from e
        pattern = json_schema.build_regex_from_schema(schema)
        validate_regex_is_buildable(pattern)
    elif gd_params.choice:
        choices = [regex_escape(str(choice)) for choice in gd_params.choice]
        regex = "(" + "|".join(choices) + ")"
        validate_regex_is_buildable(regex)
    elif gd_params.grammar:
        raise ValueError("Outlines guided decoding backend "
                         "does not support grammar specifications")

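Hedged usage sketch: the validator rejects requests the Outlines backend cannot serve before they reach the engine. The grammar string below is made up; only the error message comes from the code above:

from vllm.sampling_params import GuidedDecodingParams, SamplingParams

params = SamplingParams(guided_decoding=GuidedDecodingParams(
    grammar="root ::= 'a' | 'b'",  # invented grammar; any non-empty value hits the check
    backend="outlines"))
try:
    validate_structured_output_request_outlines(params)
except ValueError as e:
    print(e)  # Outlines guided decoding backend does not support grammar specifications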
def _prefix_needs_context(parsed) -> bool:
    """Return True if there's a look-around/anchor before any consumer."""

    def subpattern_consumes(parsed) -> bool:
        """Return True if subpattern can consume at least one character."""
        tokens = parsed.data if hasattr(parsed, 'data') else parsed
        for ttype, tval in tokens:
            # literal, character class, or dot always consumes
            if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
                return True
            # quantified subpattern: check inner pattern
            elif ttype == sre_parse.MAX_REPEAT:
                _, mx, sub = tval
                if mx != 0 and subpattern_consumes(sub):
                    return True
            # alternation: if any branch consumes, the whole does
            elif ttype == sre_parse.BRANCH:
                _, branches = tval
                if any(subpattern_consumes(br) for br in branches):
                    return True
            # grouped subpattern: recurse into its contents
            elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(
                    tval[3]):
                return True
        # No consumers, return False
        return False

    tokens = parsed.data if hasattr(parsed, 'data') else parsed
    for ttype, tval in tokens:
        # Direct anchors or look-around
        if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT,
                                              sre_constants.ASSERT_NOT):
            return True

        # Nested subpattern: check
        if ttype == sre_parse.SUBPATTERN:
            # tval: (group, add_flags, del_flags, subpattern)
            if _prefix_needs_context(tval[3]):
                return True
            if subpattern_consumes(tval[3]):
                return False

        # if any branch has a prefix anchor => True,
        # else if at least one branch consumes => prefix ends => False
        elif ttype == sre_parse.BRANCH:
            saw_consumer = False
            for br in tval[1]:
                if _prefix_needs_context(br):
                    return True
                if subpattern_consumes(br):
                    saw_consumer = True
            if saw_consumer:
                return False

        # Immediate consumer tokens
        elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
            return False

        # if subpattern has anchor => True, if it can consume => stop
        elif ttype == sre_parse.MAX_REPEAT:
            if _prefix_needs_context(tval[2]):
                return True
            if subpattern_consumes(tval[2]):
                return False

    return False

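To see the anchor heuristic in action (illustrative only, assuming this module's sre_parse and _prefix_needs_context are in scope; the patterns are invented):

print(_prefix_needs_context(sre_parse.parse(r"^v\d+")))  # True: leading anchor needs context
print(_prefix_needs_context(sre_parse.parse(r"v\d+")))   # False: pattern starts by consuming 'v'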
def _check_unsupported(parsed) -> None:
    """Check for regex features unsupported by regex-automata"""
    tokens = parsed.data if hasattr(parsed, 'data') else parsed
    for ttype, tval in tokens:

        # backreference
        if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
            raise ValueError("Backreferences are unsupported.")

        # look-around assertion
        elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
            raise ValueError("Look-around assertions are unsupported.")

        # unicode word boundaries
        elif ttype == sre_parse.AT:
            if tval in (sre_constants.AT_BOUNDARY,
                        sre_constants.AT_NON_BOUNDARY):
                raise ValueError("Unicode word boundaries are unsupported.")

        elif ttype == sre_parse.BRANCH:
            # tval is (None, branches)
            for branch in tval[1]:
                _check_unsupported(branch)

        # tval is (min, max, subpattern)
        elif ttype == sre_parse.MAX_REPEAT:
            _check_unsupported(tval[2])

def validate_regex_is_buildable(pattern: str) -> None:
    """
    Validates that the input regex is not using unsupported features
    of the `regex-automata` crate (outlines_core regex engine) and has a
    universal start state.
    The definition of the universal start state used here can be found at:
    https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state
    """
    try:
        parsed = sre_parse.parse(pattern)

    except sre_constants.error as e:
        raise ValueError(f"Error parsing regex: {e}") from e

    try:
        _check_unsupported(parsed)
    except ValueError as e:
        raise ValueError(
            f"Regex uses unsupported feature for guided decoding: {e}. "
            "Only basic matching constructs are supported: lookarounds, "
            "backreferences, and unicode boundaries are not.") from e

    if _prefix_needs_context(parsed):
        raise ValueError(
            "Regex does not have an anchored universal start state. "
            "This means that the regex uses anchors (^) or look-arounds "
            "in a way which requires context before any token is matched. "
            "Guided decoding needs regexes that can match without needing "
            "that context. Try rewriting the pattern without using these "
            f"constructs. Pattern:\n{pattern}")
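A quick, hedged sanity check of the full validator; both patterns are invented for illustration:

validate_regex_is_buildable(r"\d{1,3}(\.\d{1,3}){3}")  # passes: plain consuming pattern
try:
    validate_regex_is_buildable(r"(?<=foo)bar")          # look-behind -> rejected
except ValueError as e:
    print(e)  # Regex uses unsupported feature for guided decoding: Look-around assertions ...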