mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 06:09:10 +08:00
[V0][V1][Core] Add outlines integration for V1, and update V0 integration. (#15975)
Signed-off-by: Nathan Hoos <thwackyy.y@gmail.com>
This commit is contained in:
parent
5e53c89a74
commit
d6902ce79f
@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
|
|||||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||||
lm-format-enforcer >= 0.10.11, < 0.11
|
lm-format-enforcer >= 0.10.11, < 0.11
|
||||||
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
||||||
outlines == 0.1.11
|
outlines_core == 0.2.10
|
||||||
|
# required for outlines backend disk cache
|
||||||
|
diskcache == 5.6.3
|
||||||
lark == 1.2.2
|
lark == 1.2.2
|
||||||
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
|
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
|
||||||
typing_extensions >= 4.10
|
typing_extensions >= 4.10
|
||||||
|
|||||||
@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput
|
|||||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||||
|
|
||||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
GUIDED_DECODING_BACKENDS = [
|
|
||||||
|
# Separate backends which support grammars vs ones
|
||||||
|
# which only support regex based constraints in tests.
|
||||||
|
GRAMMAR_DECODING_BACKENDS = [
|
||||||
# (backend, disable_any_whitespace),
|
# (backend, disable_any_whitespace),
|
||||||
("outlines", False),
|
|
||||||
("lm-format-enforcer", False),
|
("lm-format-enforcer", False),
|
||||||
("xgrammar", True),
|
("xgrammar", True),
|
||||||
("guidance", True),
|
("guidance", True),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def llm():
|
def llm():
|
||||||
@ -39,7 +43,7 @@ def llm():
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
|||||||
regex=sample_regex,
|
regex=sample_regex,
|
||||||
backend=guided_decoding_backend,
|
backend=guided_decoding_backend,
|
||||||
disable_any_whitespace=disable_any_whitespace))
|
disable_any_whitespace=disable_any_whitespace))
|
||||||
|
|
||||||
outputs = llm.generate(prompts=[
|
outputs = llm.generate(prompts=[
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||||
] * 2,
|
] * 2,
|
||||||
@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_json_completion(sample_json_schema, llm,
|
def test_guided_json_completion(sample_json_schema, llm,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_choice_completion(sample_guided_choice, llm,
|
def test_guided_choice_completion(sample_guided_choice, llm,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
GRAMMAR_DECODING_BACKENDS)
|
||||||
def test_guided_grammar(sample_sql_statements, llm,
|
def test_guided_grammar(sample_sql_statements, llm,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
GRAMMAR_DECODING_BACKENDS)
|
||||||
def test_guided_json_object(llm, guided_decoding_backend: str,
|
def test_guided_json_object(llm, guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
|
|||||||
|
|
||||||
# Parse to verify it is valid JSON
|
# Parse to verify it is valid JSON
|
||||||
parsed_json = json.loads(generated_text)
|
parsed_json = json.loads(generated_text)
|
||||||
assert isinstance(parsed_json, dict)
|
# A list is not what was intended, but is still valid
|
||||||
|
# json.
|
||||||
|
assert isinstance(parsed_json, (dict, list))
|
||||||
|
|
||||||
|
|
||||||
class CarType(str, Enum):
|
class CarType(str, Enum):
|
||||||
@ -395,7 +402,7 @@ class CarDescription(BaseModel):
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
json_schema = CarDescription.model_json_schema()
|
json_schema = CarDescription.model_json_schema()
|
||||||
@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||||
GUIDED_DECODING_BACKENDS)
|
ALL_DECODING_BACKENDS)
|
||||||
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
|
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
|
||||||
disable_any_whitespace: bool):
|
disable_any_whitespace: bool):
|
||||||
sample_output_schema = {
|
sample_output_schema = {
|
||||||
|
|||||||
@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
|
|||||||
whitespace_pattern=None,
|
whitespace_pattern=None,
|
||||||
reasoner=None)
|
reasoner=None)
|
||||||
|
|
||||||
token_ids = zephyr_7B_tokenzer.encode(
|
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
regex_LP(token_ids, tensor)
|
tensor = regex_LP([], tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert not torch.allclose(tensor, original_tensor)
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
token_ids = zephyr_7B_tokenzer.encode(
|
|
||||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
|
||||||
)
|
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
json_LP(token_ids, tensor)
|
tensor = json_LP([], tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert not torch.allclose(tensor, original_tensor)
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
|||||||
seed=0,
|
seed=0,
|
||||||
dtype="bfloat16",
|
dtype="bfloat16",
|
||||||
)
|
)
|
||||||
token_ids = zephyr_7B_tokenzer.encode(
|
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
|
||||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||||
|
|
||||||
regex_lp = get_local_guided_decoding_logits_processor(
|
regex_lp = get_local_guided_decoding_logits_processor(
|
||||||
@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
|||||||
assert regex_lp is not None
|
assert regex_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
tensor = regex_lp(token_ids, tensor)
|
# allowed tokens at state 0
|
||||||
|
tensor = regex_lp([], tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert not torch.allclose(tensor, original_tensor)
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
token_ids = zephyr_7B_tokenzer.encode(
|
|
||||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
|
||||||
)
|
|
||||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
backend=backend)
|
backend=backend)
|
||||||
json_lp = await get_guided_decoding_logits_processor(
|
json_lp = await get_guided_decoding_logits_processor(
|
||||||
@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
|||||||
assert json_lp is not None
|
assert json_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
tensor = json_lp(token_ids, tensor)
|
tensor = json_lp([], tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert not torch.allclose(tensor, original_tensor)
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning(
|
|||||||
dtype="bfloat16",
|
dtype="bfloat16",
|
||||||
)
|
)
|
||||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}."
|
|
||||||
"<think>here is the thinking process")
|
"<think>here is the thinking process")
|
||||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||||
|
|
||||||
@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning(
|
|||||||
regex_request, deepseek_r1_qwen_tokenizer, config,
|
regex_request, deepseek_r1_qwen_tokenizer, config,
|
||||||
reasoning_backend)
|
reasoning_backend)
|
||||||
assert regex_lp is not None
|
assert regex_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(151664)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
tensor = regex_lp(token_ids, tensor)
|
tensor = regex_lp(token_ids, tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert torch.allclose(tensor, original_tensor)
|
assert torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||||
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
|
||||||
"<think>here is the thinking process")
|
"<think>here is the thinking process")
|
||||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
backend=backend)
|
backend=backend)
|
||||||
@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning(
|
|||||||
await get_guided_decoding_logits_processor(
|
await get_guided_decoding_logits_processor(
|
||||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||||
assert json_lp is not None
|
assert json_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(151664)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
tensor = json_lp(token_ids, tensor)
|
tensor = json_lp(token_ids, tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning(
|
|||||||
|
|
||||||
# Thinking is over, so the tensor should change.
|
# Thinking is over, so the tensor should change.
|
||||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||||
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
"<think>here is the thinking process</think>")
|
||||||
"<think>here is the thinking process</think> Then")
|
|
||||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
backend=backend)
|
backend=backend)
|
||||||
json_lp = get_local_guided_decoding_logits_processor(
|
json_lp = get_local_guided_decoding_logits_processor(
|
||||||
@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning(
|
|||||||
await get_guided_decoding_logits_processor(
|
await get_guided_decoding_logits_processor(
|
||||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||||
assert json_lp is not None
|
assert json_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(151664)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
tensor = json_lp(token_ids, tensor)
|
tensor = json_lp(token_ids, tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
|
|||||||
@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
|
|||||||
assert isinstance(schema, dict)
|
assert isinstance(schema, dict)
|
||||||
|
|
||||||
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
||||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
from outlines_core.json_schema import build_regex_from_schema
|
||||||
regex = build_regex_from_schema(json.dumps(schema))
|
regex = build_regex_from_schema(json.dumps(schema))
|
||||||
compiled = re.compile(regex)
|
compiled = re.compile(regex)
|
||||||
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
|
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
|
||||||
|
|||||||
@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
|||||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
||||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
||||||
|
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
|
||||||
|
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
|
||||||
|
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
|
||||||
|
NGRAM_SPEC_CONFIG),
|
||||||
#FIXME: This test is flaky on CI thus disabled
|
#FIXME: This test is flaky on CI thus disabled
|
||||||
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
|
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
|
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
|
||||||
@ -106,13 +110,15 @@ def test_structured_output(
|
|||||||
enforce_eager = bool(not current_platform.is_tpu())
|
enforce_eager = bool(not current_platform.is_tpu())
|
||||||
# Use a single LLM instance for several scenarios to
|
# Use a single LLM instance for several scenarios to
|
||||||
# speed up the test suite.
|
# speed up the test suite.
|
||||||
llm = LLM(model=model_name,
|
llm = LLM(
|
||||||
enforce_eager=enforce_eager,
|
model=model_name,
|
||||||
max_model_len=1024,
|
enforce_eager=enforce_eager,
|
||||||
guided_decoding_backend=guided_decoding_backend,
|
max_model_len=1024,
|
||||||
guided_decoding_disable_any_whitespace=True,
|
guided_decoding_backend=guided_decoding_backend,
|
||||||
tokenizer_mode=tokenizer_mode,
|
guided_decoding_disable_any_whitespace=(guided_decoding_backend
|
||||||
speculative_config=speculative_config)
|
in {"xgrammar", "guidance"}),
|
||||||
|
tokenizer_mode=tokenizer_mode,
|
||||||
|
speculative_config=speculative_config)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Test 1: Generate JSON output based on a provided schema
|
# Test 1: Generate JSON output based on a provided schema
|
||||||
@ -146,32 +152,33 @@ def test_structured_output(
|
|||||||
#
|
#
|
||||||
# Test 2: Generate JSON object without a schema
|
# Test 2: Generate JSON object without a schema
|
||||||
#
|
#
|
||||||
sampling_params = SamplingParams(
|
if guided_decoding_backend != "outlines":
|
||||||
temperature=1.0,
|
sampling_params = SamplingParams(
|
||||||
max_tokens=4096,
|
temperature=1.0,
|
||||||
n=2,
|
max_tokens=4096,
|
||||||
guided_decoding=GuidedDecodingParams(json_object=True))
|
n=2,
|
||||||
|
guided_decoding=GuidedDecodingParams(json_object=True))
|
||||||
|
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(prompts=(
|
||||||
prompts=("Generate a JSON object with curly braces for a person with "
|
"Generate a JSON object with curly braces for a person with "
|
||||||
"name and age fields for John Smith who is 31 years old. "
|
"name and age fields for John Smith who is 31 years old. "
|
||||||
"Make the response as short as possible."),
|
"Make the response as short as possible."),
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True)
|
use_tqdm=True)
|
||||||
|
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert isinstance(output, RequestOutput)
|
assert isinstance(output, RequestOutput)
|
||||||
|
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
generated_text = output.outputs[i].text
|
generated_text = output.outputs[i].text
|
||||||
print(generated_text)
|
print(generated_text)
|
||||||
assert generated_text is not None
|
assert generated_text is not None
|
||||||
|
|
||||||
# Parse to verify it is a valid JSON object
|
# Parse to verify it is a valid JSON object
|
||||||
parsed_json = json.loads(generated_text)
|
parsed_json = json.loads(generated_text)
|
||||||
assert isinstance(parsed_json, dict)
|
assert isinstance(parsed_json, dict)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Test 3: test a jsonschema incompatible with xgrammar
|
# Test 3: test a jsonschema incompatible with xgrammar
|
||||||
@ -210,97 +217,98 @@ def test_structured_output(
|
|||||||
parsed_json = json.loads(generated_text)
|
parsed_json = json.loads(generated_text)
|
||||||
assert isinstance(parsed_json, dict)
|
assert isinstance(parsed_json, dict)
|
||||||
|
|
||||||
#
|
if guided_decoding_backend != "outlines":
|
||||||
# Test 4: Generate SQL statement using EBNF grammar
|
#
|
||||||
#
|
# Test 4: Generate SQL statement using EBNF grammar
|
||||||
sampling_params = SamplingParams(
|
#
|
||||||
temperature=0.8,
|
sampling_params = SamplingParams(
|
||||||
top_p=0.95,
|
temperature=0.8,
|
||||||
max_tokens=1000,
|
top_p=0.95,
|
||||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
|
max_tokens=1000,
|
||||||
outputs = llm.generate(
|
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
|
||||||
prompts=(
|
outputs = llm.generate(
|
||||||
"Generate a sql statement that selects col_1 from "
|
|
||||||
"table_1 where it is equal to 1. Make the response as short as "
|
|
||||||
"possible."),
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
use_tqdm=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert outputs is not None
|
|
||||||
for output in outputs:
|
|
||||||
assert output is not None
|
|
||||||
assert isinstance(output, RequestOutput)
|
|
||||||
prompt = output.prompt
|
|
||||||
|
|
||||||
generated_text = output.outputs[0].text
|
|
||||||
assert generated_text is not None
|
|
||||||
|
|
||||||
# remove spaces for comparison b/c we removed them in the grammar
|
|
||||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
|
||||||
" ", "")
|
|
||||||
|
|
||||||
assert generated_text.strip() == ground_truth
|
|
||||||
|
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
|
||||||
|
|
||||||
#
|
|
||||||
# Test 5: Generate SQL statement using Lark grammar
|
|
||||||
#
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
max_tokens=1000,
|
|
||||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
|
|
||||||
outputs = llm.generate(
|
|
||||||
prompts=(
|
|
||||||
"Generate a sql statement that selects col_1 from "
|
|
||||||
"table_1 where it is equal to 1. Make the response as short as "
|
|
||||||
"possible."),
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
use_tqdm=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert outputs is not None
|
|
||||||
for output in outputs:
|
|
||||||
assert output is not None
|
|
||||||
assert isinstance(output, RequestOutput)
|
|
||||||
prompt = output.prompt
|
|
||||||
|
|
||||||
generated_text = output.outputs[0].text
|
|
||||||
assert generated_text is not None
|
|
||||||
|
|
||||||
# use Lark to parse the output, and make sure it's a valid parse tree
|
|
||||||
from lark import Lark
|
|
||||||
parser = Lark(sample_sql_lark)
|
|
||||||
parser.parse(generated_text)
|
|
||||||
|
|
||||||
# remove spaces for comparison b/c we removed them in the grammar
|
|
||||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
|
||||||
" ", "")
|
|
||||||
|
|
||||||
assert generated_text.strip() == ground_truth
|
|
||||||
|
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
|
||||||
|
|
||||||
#
|
|
||||||
# Test 6: Test invalid grammar input
|
|
||||||
#
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
max_tokens=1000,
|
|
||||||
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
|
|
||||||
with pytest.raises(ValueError, match="Failed to convert the grammar "):
|
|
||||||
llm.generate(
|
|
||||||
prompts=(
|
prompts=(
|
||||||
"Generate a sql statement that selects col_1 from "
|
"Generate a sql statement that selects col_1 from "
|
||||||
"table_1 where it is equal to 1. Make the response as short "
|
"table_1 where it is equal to 1. Make the response as short as "
|
||||||
"as possible."),
|
"possible."),
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True,
|
use_tqdm=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
|
||||||
|
# remove spaces for comparison b/c we removed them in the grammar
|
||||||
|
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||||
|
" ", "")
|
||||||
|
|
||||||
|
assert generated_text.strip() == ground_truth
|
||||||
|
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
#
|
||||||
|
# Test 5: Generate SQL statement using Lark grammar
|
||||||
|
#
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
|
top_p=0.95,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts=(
|
||||||
|
"Generate a sql statement that selects col_1 from "
|
||||||
|
"table_1 where it is equal to 1. Make the response as short as "
|
||||||
|
"possible."),
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
|
||||||
|
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||||
|
from lark import Lark
|
||||||
|
parser = Lark(sample_sql_lark)
|
||||||
|
parser.parse(generated_text)
|
||||||
|
|
||||||
|
# remove spaces for comparison b/c we removed them in the grammar
|
||||||
|
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||||
|
" ", "")
|
||||||
|
|
||||||
|
assert generated_text.strip() == ground_truth
|
||||||
|
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
#
|
||||||
|
# Test 6: Test invalid grammar input
|
||||||
|
#
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
|
top_p=0.95,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
|
||||||
|
with pytest.raises(ValueError, match="Failed to convert the grammar "):
|
||||||
|
llm.generate(
|
||||||
|
prompts=
|
||||||
|
("Generate a sql statement that selects col_1 from "
|
||||||
|
"table_1 where it is equal to 1. Make the response as short "
|
||||||
|
"as possible."),
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Test 7: Generate text based on a regex pattern
|
# Test 7: Generate text based on a regex pattern
|
||||||
#
|
#
|
||||||
@ -421,35 +429,36 @@ def test_structured_output(
|
|||||||
output_json = json.loads(generated_text)
|
output_json = json.loads(generated_text)
|
||||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||||
|
|
||||||
#
|
if guided_decoding_backend != "outlines":
|
||||||
# Test 11: Generate structured output using structural_tag format
|
#
|
||||||
#
|
# Test 11: Generate structured output using structural_tag format
|
||||||
structural_tag_config = {
|
#
|
||||||
"type":
|
structural_tag_config = {
|
||||||
"structural_tag",
|
"type":
|
||||||
"structures": [{
|
"structural_tag",
|
||||||
"begin": "<function=get_weather>",
|
"structures": [{
|
||||||
"schema": {
|
"begin": "<function=get_weather>",
|
||||||
"type": "object",
|
"schema": {
|
||||||
"properties": {
|
"type": "object",
|
||||||
"city": {
|
"properties": {
|
||||||
"type": "string"
|
"city": {
|
||||||
}
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": False
|
||||||
},
|
},
|
||||||
"additionalProperties": False
|
"end": "</function>"
|
||||||
},
|
}],
|
||||||
"end": "</function>"
|
"triggers": ["<function="]
|
||||||
}],
|
}
|
||||||
"triggers": ["<function="]
|
|
||||||
}
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(
|
||||||
structural_tag=json.dumps(structural_tag_config)))
|
structural_tag=json.dumps(structural_tag_config)))
|
||||||
|
|
||||||
prompt = """
|
prompt = """
|
||||||
You have access to the following function to retrieve the weather in a city:
|
You have access to the following function to retrieve the weather in a city:
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -469,7 +478,7 @@ where
|
|||||||
|
|
||||||
start_tag => `<function`
|
start_tag => `<function`
|
||||||
parameters => a JSON dict with the function argument name
|
parameters => a JSON dict with the function argument name
|
||||||
as key and function argument value as value.
|
as key and function argument value as value.
|
||||||
end_tag => `</function>`
|
end_tag => `</function>`
|
||||||
|
|
||||||
Here is an example,
|
Here is an example,
|
||||||
@ -488,37 +497,37 @@ Given the previous instructions, what is the weather in New York City? \
|
|||||||
Make the response as short as possible.
|
Make the response as short as possible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Change this once other backends support structural_tag
|
# Change this once other backends support structural_tag
|
||||||
outputs = llm.generate(prompts=prompt,
|
outputs = llm.generate(prompts=prompt,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True)
|
use_tqdm=True)
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
|
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert isinstance(output, RequestOutput)
|
assert isinstance(output, RequestOutput)
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
assert generated_text is not None
|
assert generated_text is not None
|
||||||
|
|
||||||
# Search for function call pattern in the response
|
# Search for function call pattern in the response
|
||||||
function_call_pattern = r'<function=get_weather>(.*?)</function>'
|
function_call_pattern = r'<function=get_weather>(.*?)</function>'
|
||||||
matches = re.findall(function_call_pattern, generated_text)
|
matches = re.findall(function_call_pattern, generated_text)
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
print(f"Warning: No function calls found in response: "
|
print(f"Warning: No function calls found in response: "
|
||||||
f"{generated_text!r}")
|
f"{generated_text!r}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Take the first function call if multiple are found
|
# Take the first function call if multiple are found
|
||||||
json_str = matches[0]
|
json_str = matches[0]
|
||||||
try:
|
try:
|
||||||
json_content = json.loads(json_str)
|
json_content = json.loads(json_str)
|
||||||
assert "city" in json_content
|
assert "city" in json_content
|
||||||
assert isinstance(json_content["city"], str)
|
assert isinstance(json_content["city"], str)
|
||||||
print(f"Found valid function call: {generated_text!r}")
|
print(f"Found valid function call: {generated_text!r}")
|
||||||
except (json.JSONDecodeError, AssertionError) as e:
|
except (json.JSONDecodeError, AssertionError) as e:
|
||||||
pytest.fail("Invalid function call format: "
|
pytest.fail("Invalid function call format: "
|
||||||
f"{generated_text!r}\nError: {str(e)}")
|
f"{generated_text!r}\nError: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
|
|||||||
@ -3580,7 +3580,8 @@ def get_served_model_name(model: str,
|
|||||||
|
|
||||||
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
|
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
|
||||||
"xgrammar", "guidance"]
|
"xgrammar", "guidance"]
|
||||||
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
|
|
||||||
|
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
|
||||||
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
|
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
|
||||||
GuidedDecodingBackendV1]
|
GuidedDecodingBackendV1]
|
||||||
|
|
||||||
|
|||||||
@ -117,6 +117,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
|
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
|
||||||
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
||||||
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
||||||
|
VLLM_V1_USE_OUTLINES_CACHE: bool = False
|
||||||
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
||||||
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
|
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
|
||||||
VLLM_USE_DEEP_GEMM: bool = False
|
VLLM_USE_DEEP_GEMM: bool = False
|
||||||
@ -847,6 +848,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_V0_USE_OUTLINES_CACHE":
|
"VLLM_V0_USE_OUTLINES_CACHE":
|
||||||
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
|
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
|
||||||
|
|
||||||
|
# Whether to turn on the outlines cache for V1
|
||||||
|
# This cache is unbounded and on disk, so it's not safe to use in
|
||||||
|
# an environment with potentially malicious users.
|
||||||
|
"VLLM_V1_USE_OUTLINES_CACHE":
|
||||||
|
lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1",
|
||||||
|
|
||||||
# Gap between padding buckets for the forward pass. So we have
|
# Gap between padding buckets for the forward pass. So we have
|
||||||
# 8, we will run forward pass with [16, 24, 32, ...].
|
# 8, we will run forward pass with [16, 24, 32, ...].
|
||||||
"VLLM_TPU_BUCKET_PADDING_GAP":
|
"VLLM_TPU_BUCKET_PADDING_GAP":
|
||||||
|
|||||||
@ -79,20 +79,33 @@ def maybe_backend_fallback(
|
|||||||
fallback_or_error(
|
fallback_or_error(
|
||||||
guided_params,
|
guided_params,
|
||||||
"xgrammar does not support Lark grammars and the "
|
"xgrammar does not support Lark grammars and the "
|
||||||
"grammar failed to convert to GBNF.", "outlines")
|
"grammar failed to convert to GBNF.", "guidance")
|
||||||
|
|
||||||
# If the xgrammar module cannot be imported successfully,
|
# If the xgrammar module cannot be imported successfully,
|
||||||
# we should still allow users to use guided decoding with a fallback.
|
# we should still allow users to use guided decoding with a fallback.
|
||||||
elif not xgr_installed:
|
elif not xgr_installed:
|
||||||
fallback_or_error(
|
fallback_or_error(
|
||||||
guided_params,
|
guided_params,
|
||||||
"xgrammar module cannot be imported successfully.", "outlines")
|
"xgrammar module cannot be imported successfully.", "guidance")
|
||||||
|
|
||||||
if (guided_params.backend == "outlines"
|
if guided_params.backend == "outlines":
|
||||||
and guided_params.json_object is not None):
|
if guided_params.json_object is not None:
|
||||||
# outlines doesn't support json_object, fallback to guidance
|
# outlines doesn't support json_object, fallback to guidance
|
||||||
fallback_or_error(guided_params,
|
fallback_or_error(guided_params,
|
||||||
"outlines does not support json_object.", "guidance")
|
"outlines does not support json_object.",
|
||||||
|
"guidance")
|
||||||
|
elif guided_params.grammar is not None:
|
||||||
|
# outlines grammar support has been removed, fallback to guidance
|
||||||
|
# if it is a lark-based grammar and xgrammar otherwise
|
||||||
|
if grammar_is_likely_lark(guided_params.grammar):
|
||||||
|
fallback_or_error(guided_params,
|
||||||
|
"outlines no longer supports grammars.",
|
||||||
|
"guidance")
|
||||||
|
else:
|
||||||
|
# The grammar is likely already GBNF format.
|
||||||
|
fallback_or_error(guided_params,
|
||||||
|
"outlines no longer supports grammars.",
|
||||||
|
"xgrammar")
|
||||||
|
|
||||||
return guided_params
|
return guided_params
|
||||||
|
|
||||||
@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor(
|
|||||||
|
|
||||||
guided_params = maybe_backend_fallback(guided_params)
|
guided_params = maybe_backend_fallback(guided_params)
|
||||||
|
|
||||||
# CFG grammar not supported by LMFE, so we use outlines instead
|
|
||||||
if guided_params.backend == 'outlines':
|
if guided_params.backend == 'outlines':
|
||||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor(
|
|||||||
reasoning_backend)
|
reasoning_backend)
|
||||||
reasoner = reasoner_class(tokenizer)
|
reasoner = reasoner_class(tokenizer)
|
||||||
|
|
||||||
# CFG grammar not supported by LMFE, so we use outlines instead
|
|
||||||
if guided_params.backend == 'outlines':
|
if guided_params.backend == 'outlines':
|
||||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from regex import escape as regex_escape
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||||
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
JSONLogitsProcessor, RegexLogitsProcessor)
|
||||||
from vllm.reasoning import ReasoningParser
|
from vllm.reasoning import ReasoningParser
|
||||||
from vllm.sampling_params import GuidedDecodingParams
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum):
|
|||||||
JSON = "json"
|
JSON = "json"
|
||||||
REGEX = "regex"
|
REGEX = "regex"
|
||||||
CHOICE = "choice"
|
CHOICE = "choice"
|
||||||
GRAMMAR = "grammar"
|
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
|
|
||||||
# the main difference is that we changed the start: value to
|
|
||||||
# start: object | array, so we are denying scalar values as the root of the
|
|
||||||
# JSON. Starting with scalars as the root seems to cause llama to generate
|
|
||||||
# without stop.
|
|
||||||
JSON_GRAMMAR = r"""
|
|
||||||
?start: object | array
|
|
||||||
|
|
||||||
?value: object
|
|
||||||
| array
|
|
||||||
| UNESCAPED_STRING
|
|
||||||
| SIGNED_NUMBER -> number
|
|
||||||
| "true" -> true
|
|
||||||
| "false" -> false
|
|
||||||
| "null" -> null
|
|
||||||
|
|
||||||
array : "[" [value ("," value)*] "]"
|
|
||||||
object : "{" [pair ("," pair)*] "}"
|
|
||||||
pair : UNESCAPED_STRING ":" value
|
|
||||||
|
|
||||||
%import common.UNESCAPED_STRING
|
|
||||||
%import common.SIGNED_NUMBER
|
|
||||||
%import common.WS
|
|
||||||
|
|
||||||
%ignore WS
|
|
||||||
"""
|
|
||||||
|
|
||||||
global_thread_pool = None # used for generating logits processor fsm
|
global_thread_pool = None # used for generating logits processor fsm
|
||||||
|
|
||||||
# It's not yet clear that using more provides a benefit, and it could
|
# It's not yet clear that using more provides a benefit, and it could
|
||||||
@ -60,16 +32,12 @@ _MAX_THREADPOOL_WORKERS = 16
|
|||||||
|
|
||||||
|
|
||||||
async def get_outlines_guided_decoding_logits_processor(
|
async def get_outlines_guided_decoding_logits_processor(
|
||||||
guided_params: GuidedDecodingParams,
|
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
reasoner: Optional[ReasoningParser]
|
||||||
reasoner: Optional[ReasoningParser],
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
|
||||||
None]:
|
|
||||||
"""
|
"""
|
||||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||||
and get the necessary logits processor for the given guide.
|
and get the necessary logits processor for the given guide.
|
||||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
|
||||||
we make a shallow copy to reuse the same underlying FSM.
|
|
||||||
"""
|
"""
|
||||||
global global_thread_pool
|
global global_thread_pool
|
||||||
guide, mode = _get_guide_and_mode(guided_params)
|
guide, mode = _get_guide_and_mode(guided_params)
|
||||||
@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor(
|
|||||||
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||||
max_workers=max_workers)
|
max_workers=max_workers)
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
return await loop.run_in_executor(global_thread_pool,
|
return await loop.run_in_executor(global_thread_pool,
|
||||||
_get_logits_processor, guide, tokenizer,
|
_get_logits_processor, guide, tokenizer,
|
||||||
mode, guided_params.whitespace_pattern,
|
mode, guided_params.whitespace_pattern,
|
||||||
@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor(
|
|||||||
|
|
||||||
|
|
||||||
def get_local_outlines_guided_decoding_logits_processor(
|
def get_local_outlines_guided_decoding_logits_processor(
|
||||||
guided_params: GuidedDecodingParams,
|
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
reasoner: Optional[ReasoningParser]
|
||||||
reasoner: Optional[ReasoningParser],
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
|
||||||
None]:
|
|
||||||
"""
|
"""
|
||||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||||
and get the necessary logits processor for the given guide.
|
and get the necessary logits processor for the given guide.
|
||||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
|
||||||
we make a shallow copy to reuse the same underlying FSM.
|
|
||||||
"""
|
"""
|
||||||
guide, mode = _get_guide_and_mode(guided_params)
|
guide, mode = _get_guide_and_mode(guided_params)
|
||||||
if not guide or not mode:
|
if not guide or not mode:
|
||||||
@ -130,9 +93,10 @@ def _get_guide_and_mode(
|
|||||||
choices_regex = "(" + "|".join(choices) + ")"
|
choices_regex = "(" + "|".join(choices) + ")"
|
||||||
return choices_regex, GuidedDecodingMode.CHOICE
|
return choices_regex, GuidedDecodingMode.CHOICE
|
||||||
elif guided_params.grammar:
|
elif guided_params.grammar:
|
||||||
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
|
raise ValueError(
|
||||||
elif guided_params.json_object:
|
"The `outlines` guided decoding backend no longer supports grammar "
|
||||||
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
|
"guided generation. Please use either the `xgrammar` or `guidance` "
|
||||||
|
"backend")
|
||||||
else:
|
else:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
@ -143,13 +107,11 @@ def _get_logits_processor(
|
|||||||
mode: GuidedDecodingMode,
|
mode: GuidedDecodingMode,
|
||||||
whitespace_pattern: Union[str, None],
|
whitespace_pattern: Union[str, None],
|
||||||
reasoner: Optional[ReasoningParser],
|
reasoner: Optional[ReasoningParser],
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
|
||||||
if mode == GuidedDecodingMode.JSON:
|
if mode == GuidedDecodingMode.JSON:
|
||||||
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
|
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
|
||||||
reasoner)
|
reasoner)
|
||||||
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
||||||
return RegexLogitsProcessor(guide, tokenizer, reasoner)
|
return RegexLogitsProcessor(guide, tokenizer, reasoner)
|
||||||
elif mode == GuidedDecodingMode.GRAMMAR:
|
|
||||||
return CFGLogitsProcessor(guide, tokenizer, reasoner)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown guided decoding mode {mode}")
|
raise ValueError(f"Unknown guided decoding mode {mode}")
|
||||||
|
|||||||
@ -1,168 +1,124 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# SPDX-FileCopyrightText: Copyright 2024-present the Outlines developers
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
# Copyright 2024- the Outlines developers
|
|
||||||
# This file is adapted from
|
|
||||||
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import copy
|
import copy
|
||||||
|
import hashlib
|
||||||
|
import importlib.metadata
|
||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
import os
|
||||||
from functools import lru_cache
|
from typing import Optional, Union
|
||||||
from typing import Callable, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
from outlines import grammars
|
from cachetools import LRUCache
|
||||||
from outlines.caching import cache, disable_cache
|
from diskcache import Cache
|
||||||
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
|
from outlines_core import Guide, Index, Vocabulary
|
||||||
RegexGuide, Write)
|
from outlines_core.json_schema import build_regex_from_schema
|
||||||
from outlines.fsm.parsing import PartialLark
|
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
|
||||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
allocate_token_bitmask)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
from transformers.file_utils import SPIECE_UNDERLINE
|
||||||
|
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.reasoning import ReasoningParser
|
from vllm.reasoning import ReasoningParser
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
if envs.VLLM_V0_USE_OUTLINES_CACHE:
|
CACHE = None
|
||||||
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
|
|
||||||
"cache. It may consume a lot of disk space and should "
|
|
||||||
"not be used with untrusted clients.")
|
|
||||||
else:
|
|
||||||
disable_cache()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLogitsProcessor:
|
class BaseLogitsProcessor:
|
||||||
|
|
||||||
def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
|
def __init__(self, guide: Guide, eos_token_id: int,
|
||||||
|
reasoner: Optional[ReasoningParser]) -> None:
|
||||||
self._guide: Guide = guide
|
self._guide: Guide = guide
|
||||||
|
self._eos_token_id: int = eos_token_id
|
||||||
self._reasoner: Optional[ReasoningParser] = reasoner
|
self._reasoner: Optional[ReasoningParser] = reasoner
|
||||||
# CFGState is used for the FSM state for CFGGuide
|
self._mask: Optional[torch.Tensor] = None
|
||||||
self._fsm_state: defaultdict[int, Union[int,
|
|
||||||
CFGState]] = defaultdict(int)
|
|
||||||
|
|
||||||
def clone(self) -> "BaseLogitsProcessor":
|
|
||||||
cloned = copy.copy(self)
|
|
||||||
cloned._guide = self._guide.copy()
|
|
||||||
cloned._fsm_state = copy.deepcopy(self._fsm_state)
|
|
||||||
return cloned
|
|
||||||
|
|
||||||
def __call__(self, input_ids: list[int],
|
def __call__(self, input_ids: list[int],
|
||||||
scores: torch.Tensor) -> torch.Tensor:
|
scores: torch.Tensor) -> torch.Tensor:
|
||||||
"""Use the FSM to bias the logits before sampling the next token."""
|
if self._mask is None:
|
||||||
|
self._mask = allocate_token_bitmask(scores.size(-1))
|
||||||
|
|
||||||
# Skip the structured logits processing if reasoning is not finished.
|
# Skip the structured logits processing if reasoning is not finished.
|
||||||
# reasoner is not None only when `--reasoning-parser` is set.
|
# reasoner is not None only when `--reasoning-parser` is set.
|
||||||
if self._reasoner is not None:
|
if self._reasoner is not None and not self._reasoner.is_reasoning_end(
|
||||||
if not self._reasoner.is_reasoning_end(input_ids):
|
input_ids):
|
||||||
return scores
|
return scores
|
||||||
else:
|
|
||||||
# Remove the reasoning tokens from the input_ids
|
|
||||||
# We need this because our implementation relies on the
|
|
||||||
# hash of the input_ids to store the FSM state.
|
|
||||||
input_ids = self._reasoner.extract_content_ids(input_ids)
|
|
||||||
|
|
||||||
seq_id = hash(tuple(input_ids))
|
# Remove the reasoning tokens from the input_ids
|
||||||
|
# We need this because our implementation relies on the
|
||||||
|
# input_ids sequence to store the FSM state.
|
||||||
|
input_ids = (self._reasoner.extract_content_ids(input_ids)
|
||||||
|
if self._reasoner is not None else input_ids)
|
||||||
|
|
||||||
if len(input_ids) > 0:
|
# Vllm V0 engine has a weird bug where we have to repeat
|
||||||
last_token = input_ids[-1]
|
# the eos token id twice for generation to stop, or at least
|
||||||
last_seq_id = hash(tuple(input_ids[:-1]))
|
# that is what we have to do from here in any case.
|
||||||
self._fsm_state[seq_id] = self._guide.get_next_state(
|
# This is a patch until a better solution can be pushed
|
||||||
state=self._fsm_state[last_seq_id], token_id=last_token)
|
# to outlines_core
|
||||||
else:
|
if input_ids and input_ids[-1] != self._eos_token_id:
|
||||||
# Note: this is a hack.
|
self._guide.advance(token_id=input_ids[-1], return_tokens=False)
|
||||||
# Lark pickling does not work properly (silent failure),
|
|
||||||
# which breaks the RPC (which uses python pickleing).
|
|
||||||
# We need to find a better solution.
|
|
||||||
# On the first time this is called, we simply re-create
|
|
||||||
# the Lark object.
|
|
||||||
if isinstance(self._guide, CFGGuide):
|
|
||||||
self._guide.parser = PartialLark(
|
|
||||||
self._guide.cfg_string,
|
|
||||||
parser="lalr",
|
|
||||||
import_paths=[grammars.GRAMMAR_PATH],
|
|
||||||
)
|
|
||||||
self._fsm_state[seq_id] = CFGState(
|
|
||||||
parser_state=self._guide.parser.parse(""), prev_token=None)
|
|
||||||
|
|
||||||
instruction = self._guide.get_next_instruction(
|
self._guide.write_mask_into(
|
||||||
state=self._fsm_state[seq_id])
|
data_ptr=self._mask.data_ptr(),
|
||||||
|
numel=self._mask.numel(),
|
||||||
|
element_size=self._mask.element_size(),
|
||||||
|
)
|
||||||
|
|
||||||
if type(instruction) == Generate: # noqa: E721
|
# Any allowed tokens beyond the length of the scores will
|
||||||
allowed_tokens = instruction.tokens
|
# be ignored by the kernel, taking care of the issue with
|
||||||
elif type(instruction) == Write: # noqa: E721
|
# models such as Llama 3.2 Vision with an `<|image|>` token
|
||||||
# TODO: support fast forward tokens
|
# with id 128256, but scores.shape == torch.Size([128256])
|
||||||
allowed_tokens = [instruction.tokens[0]]
|
_apply_token_bitmask_inplace_kernel(
|
||||||
else:
|
logits=scores.unsqueeze(dim=0),
|
||||||
raise TypeError(
|
# mask must be on same device
|
||||||
f"Unsupported instruction type {type(instruction)}")
|
mask=self._mask.to(scores.device, non_blocking=True))
|
||||||
|
self._mask.to("cpu", non_blocking=True)
|
||||||
|
|
||||||
mask = torch.full((scores.shape[-1], ),
|
|
||||||
-torch.inf,
|
|
||||||
device=scores.device)
|
|
||||||
# The tokenizer may support more token ids than the model can generate,
|
|
||||||
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
|
|
||||||
# but scores.shape == torch.Size([128256])
|
|
||||||
# Using NumPy is faster for filtering token ids
|
|
||||||
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
|
|
||||||
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
|
|
||||||
allowed_tokens = allowed_tokens.masked_select(
|
|
||||||
allowed_tokens < scores.shape[-1])
|
|
||||||
mask.index_fill_(0, allowed_tokens, 0)
|
|
||||||
if current_platform.is_hpu():
|
|
||||||
# Workaround for HPU bug where add_() raise RuntimeError:
|
|
||||||
# synNodeCreateWithId failed for node: strided_insert
|
|
||||||
# with synStatus 1 [Invalid argument], hopefully it will
|
|
||||||
# be fixed in the future releases of the HPU runtime.
|
|
||||||
scores = scores.add(mask)
|
|
||||||
else:
|
|
||||||
scores.add_(mask)
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
def clone(self) -> BaseLogitsProcessor:
|
||||||
|
guide = copy.deepcopy(self._guide)
|
||||||
|
guide.reset()
|
||||||
|
return BaseLogitsProcessor(guide=guide,
|
||||||
|
eos_token_id=self._eos_token_id,
|
||||||
|
reasoner=self._reasoner)
|
||||||
|
|
||||||
|
|
||||||
class RegexLogitsProcessor(BaseLogitsProcessor):
|
class RegexLogitsProcessor(BaseLogitsProcessor):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@cache()
|
|
||||||
def _get_guide(cls, regex_string: str,
|
def _get_guide(cls, regex_string: str,
|
||||||
tokenizer: PreTrainedTokenizerBase) -> Guide:
|
tokenizer: PreTrainedTokenizerBase) -> Guide:
|
||||||
tokenizer = _adapt_tokenizer(tokenizer)
|
global CACHE
|
||||||
return RegexGuide.from_regex(regex_string, tokenizer)
|
if CACHE is None:
|
||||||
|
CACHE = get_cache()
|
||||||
|
vocabulary = get_vocabulary(tokenizer) # type: ignore[arg-type]
|
||||||
|
cache_key = f"{vocabulary._hash}_{regex_string}"
|
||||||
|
if CACHE is not None and cache_key in CACHE:
|
||||||
|
return Guide(CACHE[cache_key])
|
||||||
|
|
||||||
def __init__(
|
index = Index(regex_string, vocabulary.inner)
|
||||||
self,
|
|
||||||
regex_string: str,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
reasoner: Optional[ReasoningParser],
|
|
||||||
):
|
|
||||||
"""Compile the FSM that drives the regex-structured generation.
|
|
||||||
|
|
||||||
Parameters
|
if CACHE is not None:
|
||||||
----------
|
CACHE[cache_key] = index
|
||||||
regex_string
|
|
||||||
A string that represents a regular expression
|
|
||||||
tokenizer
|
|
||||||
The model's tokenizer
|
|
||||||
|
|
||||||
"""
|
return Guide(index)
|
||||||
|
|
||||||
|
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase,
|
||||||
|
reasoner: Optional[ReasoningParser]) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
|
guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer),
|
||||||
|
eos_token_id=tokenizer.eos_token_id, # type: ignore
|
||||||
|
reasoner=reasoner)
|
||||||
|
|
||||||
|
|
||||||
class JSONLogitsProcessor(RegexLogitsProcessor):
|
class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||||
@ -170,22 +126,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
|||||||
def __init__(self, schema: Union[str, dict, BaseModel],
|
def __init__(self, schema: Union[str, dict, BaseModel],
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
whitespace_pattern: Union[str, None],
|
whitespace_pattern: Union[str, None],
|
||||||
reasoner: Optional[ReasoningParser]):
|
reasoner: Optional[ReasoningParser]) -> None:
|
||||||
"""Compile the FSM that drives the JSON-guided generation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
schema
|
|
||||||
A JSON schema that encodes the structure we want the model to
|
|
||||||
generate
|
|
||||||
tokenizer
|
|
||||||
The model's tokenizer
|
|
||||||
whitespace_pattern
|
|
||||||
Pattern to use for JSON syntactic whitespace (doesn't impact
|
|
||||||
string literals)
|
|
||||||
Example: allow only a single space or newline with
|
|
||||||
`whitespace_pattern=r"[\n ]?"`
|
|
||||||
"""
|
|
||||||
if isinstance(schema, type(BaseModel)):
|
if isinstance(schema, type(BaseModel)):
|
||||||
schema_str = json.dumps(schema.model_json_schema())
|
schema_str = json.dumps(schema.model_json_schema())
|
||||||
elif isinstance(schema, dict):
|
elif isinstance(schema, dict):
|
||||||
@ -197,63 +139,42 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
|||||||
f"Cannot parse schema {schema}. The schema must be either "
|
f"Cannot parse schema {schema}. The schema must be either "
|
||||||
f"a Pydantic object, a dictionary or a string that contains "
|
f"a Pydantic object, a dictionary or a string that contains "
|
||||||
f"the JSON Schema specification")
|
f"the JSON Schema specification")
|
||||||
|
|
||||||
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
||||||
super().__init__(regex_string, tokenizer, reasoner)
|
super().__init__(regex_string, tokenizer, reasoner)
|
||||||
|
|
||||||
|
|
||||||
class CFGLogitsProcessor(BaseLogitsProcessor):
|
class OutlinesVocabulary:
|
||||||
|
"""
|
||||||
@classmethod
|
Wrapper class for `outlines_core.Vocabulary`,
|
||||||
@cache()
|
which allows us to store a hash with the vocabulary
|
||||||
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
|
|
||||||
tokenizer = _adapt_tokenizer(tokenizer)
|
|
||||||
return CFGGuide(cfg, tokenizer)
|
|
||||||
|
|
||||||
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
|
|
||||||
reasoner: Optional[ReasoningParser]):
|
|
||||||
"""Compile the FSM that drives the context free grammar generation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
cfg
|
|
||||||
A string that represents a context-free grammar
|
|
||||||
tokenizer
|
|
||||||
The model's tokenizer
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
|
|
||||||
reasoner)
|
|
||||||
self._guide = self._guide.copy()
|
|
||||||
|
|
||||||
def clone(self) -> "CFGLogitsProcessor":
|
|
||||||
cloned = copy.copy(self)
|
|
||||||
cloned._fsm_state = copy.deepcopy(self._fsm_state)
|
|
||||||
cloned._guide = self._guide.copy()
|
|
||||||
return cloned
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=32)
|
|
||||||
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
|
||||||
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
|
||||||
|
|
||||||
The API of Outlines tokenizers is slightly different to that of
|
|
||||||
`transformers`. The decoder of outlines, returns a list whereas
|
|
||||||
the decode of vLLM returns an str. To sync the vLLM decoder with
|
|
||||||
outlines internal api, the decoder should be adapted. In addition
|
|
||||||
we need to handle the missing spaces to Llama's tokenizer to be
|
|
||||||
able to compile FSMs for this model.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if getattr(tokenizer, "_outlines_adapted", False):
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
tokenizer = copy.deepcopy(tokenizer)
|
def __init__(self, vocabulary: Vocabulary) -> None:
|
||||||
|
# Actual vocabulary object
|
||||||
|
self.inner = vocabulary
|
||||||
|
# Have to do abs(hash()) because python hashes can
|
||||||
|
# be negative, and we are using hash as a cache key.
|
||||||
|
hex_str = hashlib.sha256(
|
||||||
|
vocabulary.__repr__().encode('utf-8')).hexdigest()
|
||||||
|
hash_int = int(hex_str, 16)
|
||||||
|
self._hash = hash_int
|
||||||
|
|
||||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
|
||||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
|
||||||
|
re_replacement_seq = re.compile(r"^.{0,6}<7D>+.{0,6}$")
|
||||||
|
|
||||||
|
|
||||||
|
def _reduced_vocabulary(tokenizer: AnyTokenizer,
|
||||||
|
eos_token_id: int) -> dict[bytes, list[int]]:
|
||||||
|
"""Create a map from vocabulary tokens to lists of equivalent token ids.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Dict of token string -> equivalent token ids
|
||||||
|
"""
|
||||||
|
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
|
||||||
|
|
||||||
def convert_token_to_string(token: str) -> str:
|
def convert_token_to_string(token: str) -> str:
|
||||||
from transformers.file_utils import SPIECE_UNDERLINE
|
|
||||||
|
|
||||||
string = tokenizer.convert_tokens_to_string([token])
|
string = tokenizer.convert_tokens_to_string([token])
|
||||||
|
|
||||||
@ -264,21 +185,123 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
|||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
def change_decoder(
|
vocabulary: dict[bytes, list[int]] = {}
|
||||||
decoder: Callable[[list[int]],
|
empty_token_ids: list[int] = []
|
||||||
str]) -> Callable[[list[int]], list[str]]:
|
for token, token_idx in tokenizer.get_vocab().items():
|
||||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
if token in tokenizer.all_special_tokens: # type: ignore
|
||||||
|
continue
|
||||||
|
|
||||||
def new_decoder(inp_tokens: list[int]) -> list[str]:
|
token_str = convert_token_to_string(token)
|
||||||
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
|
if token_str:
|
||||||
and isinstance(inp_tokens[0], list)):
|
if isinstance(token, (bytes, bytearray)):
|
||||||
inp_tokens = inp_tokens[0]
|
# For BPE tokenizers where tokens are stored as bytes.
|
||||||
return [decoder(inp_tokens)]
|
|
||||||
|
|
||||||
return new_decoder
|
# safe to ignore since token_str is of type (bytearray, bytes)
|
||||||
|
# by this point.
|
||||||
|
token_bytes = bytes(token_str) # type: ignore[arg-type]
|
||||||
|
|
||||||
tokenizer.convert_token_to_string = convert_token_to_string
|
elif "\ufffd" in token_str and not re_replacement_seq.match(
|
||||||
tokenizer.decode = change_decoder(tokenizer.decode)
|
token_str):
|
||||||
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
|
# Handle tokens with invalid UTF-8 sequences.
|
||||||
|
if re_llama_byte_token.match(token):
|
||||||
|
# Llama-like tokenizers use <0xXX> for incomplete sequences.
|
||||||
|
token_bytes = bytes([int(token[3:5], 16)])
|
||||||
|
else:
|
||||||
|
# GPT2 tokenizers: map each byte back using unicode_to_bytes
|
||||||
|
byte_vals = [unicode_to_bytes.get(c) for c in token]
|
||||||
|
if None in byte_vals:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot convert token `{token}`"
|
||||||
|
f" ({token_idx}) to bytes: {token_str}")
|
||||||
|
# safe to ignore, since if None in byte_vals,
|
||||||
|
# an error is thrown.
|
||||||
|
token_bytes = bytes(byte_vals) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
token_bytes = token_str.encode('utf-8')
|
||||||
|
|
||||||
return tokenizer
|
if token_idx != eos_token_id:
|
||||||
|
vocabulary.setdefault(token_bytes, []).append(token_idx)
|
||||||
|
else:
|
||||||
|
empty_token_ids.append(token_idx)
|
||||||
|
|
||||||
|
return vocabulary
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocabulary(tokenizer: AnyTokenizer) -> OutlinesVocabulary:
    """Return the hashed vocabulary wrapper for ``tokenizer``, building it once.

    The wrapper is memoised on the tokenizer object itself under the
    ``_outlines_vocabulary`` attribute, so a second call with the same
    tokenizer returns the stored object without rebuilding anything.

    Args:
        tokenizer: Tokenizer to derive the vocabulary from. It must expose a
            non-``None`` ``eos_token_id``.

    Returns:
        The ``OutlinesVocabulary`` wrapping the ``Vocabulary`` built from the
        tokenizer's EOS id and its reduced vocabulary.

    Raises:
        ValueError: If the tokenizer has no usable ``eos_token_id``, or if an
            ``AttributeError`` is raised while the vocabulary is being built.
    """
    if hasattr(tokenizer, "_outlines_vocabulary"):
        return tokenizer._outlines_vocabulary  # type: ignore

    try:
        # A missing attribute and an explicit ``None`` are treated alike.
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
        if eos_token_id is None:
            raise ValueError(
                "Error during guided decoding setup: Tokenizer"
                f" ({type(tokenizer)}) has no `eos_token_id` property, "
                "but `eos_token_id` is required for guided decoding"
                " to work properly.")

        reduced_vocab = _reduced_vocabulary(
            tokenizer,
            eos_token_id  # type: ignore
        )
        vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id,
                                                   reduced_vocab))
        # Memoise on the tokenizer so later calls take the fast path above.
        tokenizer._outlines_vocabulary = vocabulary  # type: ignore

        return vocabulary
    except AttributeError as e:
        raise ValueError("Cannot get the vocabulary of the tokenizer "
                         f"({type(tokenizer)}). The tokenizer should have a "
                         "get_vocab method.") from e
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_path() -> str:
    """Resolve the directory used for the on-disk outlines cache.

    Candidates are tried in this order; the first one that applies wins:

    1. ``$OUTLINES_CACHE_DIR`` (used verbatim when set and non-empty).
    2. ``$XDG_CACHE_HOME/.cache/outlines``.
    3. ``~/.cache/outlines`` when the expanded home is an existing
       directory other than ``/``.
    4. ``<system temp dir>/.cache/outlines``.

    Returns:
        The chosen cache directory path. The directory is not created here.
    """
    override = os.getenv("OUTLINES_CACHE_DIR")
    if override:
        # OUTLINES_CACHE_DIR takes precedence over everything else.
        return override

    suffix = (".cache", "outlines")

    xdg_root = os.getenv("XDG_CACHE_HOME")
    if xdg_root:
        return os.path.join(xdg_root, *suffix)

    # A home of "/" suggests a container without a real user, where writing
    # under the root would be problematic. expanduser() also does not
    # guarantee that the directory exists, hence the isdir() check.
    user_home = os.path.expanduser("~")
    if user_home != "/" and os.path.isdir(user_home):
        return os.path.join(user_home, *suffix)

    import tempfile

    return os.path.join(tempfile.gettempdir(), *suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache():
    """Return the cache object used to store compiled indexes.

    When ``envs.VLLM_V0_USE_OUTLINES_CACHE`` is falsy, a fresh in-memory
    ``LRUCache`` holding at most 128 entries is returned. Otherwise a disk
    ``Cache`` rooted at ``get_cache_path()`` is opened with eviction
    disabled; it is wiped whenever the ``__version__`` entry stored in it
    differs from the installed ``outlines_core`` version, and that entry is
    then (re)written.
    """
    cache_dir = get_cache_path()

    if not envs.VLLM_V0_USE_OUTLINES_CACHE:
        return LRUCache(maxsize=128)

    logger.warning("Enabling outlines cache. This is an unbounded on-disk "
                   "cache. It may consume a lot of disk space and should "
                   "not be used with untrusted clients.")
    disk_cache = Cache(cache_dir, eviction_policy="none", cull_limit=0)
    installed = importlib.metadata.version("outlines_core")

    # Entries written under a different outlines_core version are discarded.
    if disk_cache.get('__version__', None) != installed:
        disk_cache.clear()
    disk_cache.set('__version__', installed)
    return disk_cache
|
||||||
|
|||||||
@ -23,6 +23,8 @@ from vllm.v1.engine import EngineCoreRequest
|
|||||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||||
from vllm.v1.structured_output.backend_guidance import (
|
from vllm.v1.structured_output.backend_guidance import (
|
||||||
validate_guidance_grammar)
|
validate_guidance_grammar)
|
||||||
|
from vllm.v1.structured_output.backend_outlines import (
|
||||||
|
validate_structured_output_request_outlines)
|
||||||
from vllm.v1.structured_output.backend_xgrammar import (
|
from vllm.v1.structured_output.backend_xgrammar import (
|
||||||
validate_xgrammar_grammar)
|
validate_xgrammar_grammar)
|
||||||
|
|
||||||
@ -193,6 +195,9 @@ class Processor:
|
|||||||
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
|
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
|
||||||
# Without tokenizer these are disallowed in grammars.
|
# Without tokenizer these are disallowed in grammars.
|
||||||
validate_guidance_grammar(params, tokenizer=None)
|
validate_guidance_grammar(params, tokenizer=None)
|
||||||
|
elif engine_level_backend == "outlines":
|
||||||
|
# outlines backend
|
||||||
|
validate_structured_output_request_outlines(params)
|
||||||
else:
|
else:
|
||||||
# NOTE: engine_level_backend must be "auto" here, because we have
|
# NOTE: engine_level_backend must be "auto" here, because we have
|
||||||
# checked supported_backends above.
|
# checked supported_backends above.
|
||||||
|
|||||||
@ -88,6 +88,15 @@ class StructuredOutputManager:
|
|||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
)
|
)
|
||||||
|
elif backend == "outlines":
|
||||||
|
from vllm.v1.structured_output.backend_outlines import (
|
||||||
|
OutlinesBackend)
|
||||||
|
|
||||||
|
self.backend = OutlinesBackend(
|
||||||
|
self.vllm_config,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported structured output backend: {backend}")
|
f"Unsupported structured output backend: {backend}")
|
||||||
|
|||||||
319
vllm/v1/structured_output/backend_outlines.py
Normal file
319
vllm/v1/structured_output/backend_outlines.py
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from regex import escape as regex_escape
|
||||||
|
|
||||||
|
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||||
|
OutlinesVocabulary, get_cache, get_vocabulary)
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import LazyLoader
|
||||||
|
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||||
|
StructuredOutputGrammar,
|
||||||
|
StructuredOutputOptions)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import outlines_core as oc
|
||||||
|
import outlines_core.json_schema as json_schema
|
||||||
|
else:
|
||||||
|
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||||
|
json_schema = LazyLoader("json_schema", globals(),
|
||||||
|
"outlines_core.json_schema")
|
||||||
|
|
||||||
|
# Python 3.11+ sre_parse and sre_constants
|
||||||
|
# are deprecated, so we must import them from re
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
# Hack to get around pre-commit regex module rule
|
||||||
|
# because going through re is the only way to get sre_parse
|
||||||
|
# and sre_constants in Python 3.11+
|
||||||
|
_re = importlib.import_module("re")
|
||||||
|
sre_parse = _re._parser
|
||||||
|
sre_constants = _re._constants
|
||||||
|
else:
|
||||||
|
import sre_constants
|
||||||
|
import sre_parse
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OutlinesBackend(StructuredOutputBackend):
    """Structured-output backend that compiles regex-based constraints
    with ``outlines_core``.

    JSON schemas and choice lists are first turned into a regex; the regex
    is compiled into an ``oc.Index`` that is cached per
    (vocabulary hash, regex) pair.
    """

    def __post_init__(self):
        # Both are resolved once per backend instance.
        self.vocabulary = get_vocabulary(self.tokenizer)
        self.cache = get_cache()

    def _compile_index(self, regex_string: str,
                       vocabulary: OutlinesVocabulary) -> oc.Index:
        """Return the index for ``regex_string``, compiling it on a miss.

        Args:
            regex_string: Regex to compile.
            vocabulary: Vocabulary wrapper; its ``_hash`` is part of the
                cache key and its ``inner`` object is handed to ``oc.Index``.

        Returns:
            The cached or freshly compiled ``oc.Index``.
        """
        cache_key = f"{vocabulary._hash}_{regex_string}"
        # Single EAFP lookup: a separate membership test followed by a read
        # costs two cache round-trips and fails with KeyError if the entry
        # disappears between the two.
        try:
            return self.cache[cache_key]
        except KeyError:
            pass

        index = oc.Index(regex_string, vocabulary.inner)
        self.cache[cache_key] = index

        return index

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Build an ``OutlinesGrammar`` for one request.

        Args:
            request_type: Kind of constraint; JSON, REGEX and CHOICE are
                handled.
            grammar_spec: The constraint itself: a JSON schema string, a
                regex, or the literal text of a list of choices.

        Returns:
            A grammar whose guide may roll back as many tokens as the
            configured number of speculative tokens (0 when speculative
            decoding is not configured).

        Raises:
            ValueError: If ``request_type`` is any other option.
        """
        if request_type == StructuredOutputOptions.JSON:
            regex = json_schema.build_regex_from_schema(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            regex = grammar_spec
        elif request_type == StructuredOutputOptions.CHOICE:
            # literal_eval only evaluates Python literals, never code.
            choices = ast.literal_eval(grammar_spec)
            choices = [regex_escape(c) for c in choices]
            regex = "(" + "|".join(choices) + ")"
        else:
            raise ValueError(
                f"Invalid request type for Outlines backend ({request_type!s})"
            )
        index = self._compile_index(regex, self.vocabulary)
        max_rollback_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config is not None else 0)
        return OutlinesGrammar(vocab_size=self.vocab_size,
                               guide=oc.Guide(
                                   index, max_rollback=max_rollback_tokens))

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """Allocate the per-sequence token bitmask.

        Returns an int32 tensor with ``max_num_seqs`` rows and one 32-bit
        word per 32 vocabulary entries (rounded up), filled with -1. The
        tensor is pinned when CUDA is available.
        """
        return torch.full(
            (max_num_seqs, (self.vocab_size + 31) // 32),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        """No-op; this backend does nothing on destroy."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
    """Per-request grammar state driven by an ``outlines_core`` guide."""

    vocab_size: int
    guide: oc.Guide = field(hash=False)
    # Count of tokens fed to the guide so far (decremented on rollback).
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    # outlines_core signals done on DFA accept; vLLM expects done after EOS.
    # We delay the finished flag by one step so EOS can still be emitted.
    _prev_finished: bool = field(default=False,
                                 init=False,
                                 repr=False,
                                 hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Feed ``tokens`` to the guide, all or nothing.

        Returns True when the guide accepted the whole list and was
        advanced over it, False (with the guide untouched) otherwise.
        ``request_id`` is not used.
        """
        if not self.guide.accepts_tokens(tokens):
            return False

        # The sequence was vetted above, so each single advance succeeds.
        for token_id in tokens:
            self.guide.advance(token_id)
            self.num_processed_tokens += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        """Rewind the guide and the token counter by ``num_tokens``."""
        self.guide.rollback_state(num_tokens)
        self.num_processed_tokens -= num_tokens

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest prefix of ``tokens`` the guide would accept.

        The guide is only queried, never advanced.
        """
        for end in range(1, len(tokens) + 1):
            if not self.guide.accepts_tokens(tokens[:end]):
                return tokens[:end - 1]
        return list(tokens)

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Have the guide write its current mask into row ``idx``."""
        row = bitmask[idx]
        self.guide.write_mask_into(row.data_ptr(), row.numel(),
                                   row.element_size())

    def is_terminated(self) -> bool:
        """Report the guide's finished flag as of the previous call.

        Each call stores the guide's current flag and returns the one
        stored by the call before it (initially False), so this method
        mutates state.
        """
        finished_now = self.guide.is_finished()
        finished_before = self._prev_finished
        self._prev_finished = finished_now
        return finished_before

    def reset(self):
        """Clear the counter and delayed flag, then reset the guide."""
        self.num_processed_tokens = 0
        self._prev_finished = False
        self.guide.reset()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_structured_output_request_outlines(params: SamplingParams):
    """Check that a request's guided-decoding spec can be served by outlines.

    The first truthy field among ``regex``, ``json``, ``choice`` and
    ``grammar`` (in that order) decides what is validated. Requests without
    guided decoding, or with none of those fields set, pass silently.

    Raises:
        ValueError: If a JSON schema string is not valid JSON, a schema
            object cannot be serialised, the resulting regex is rejected by
            ``validate_regex_is_buildable``, or a grammar is supplied
            (grammars are not supported by this backend).
    """
    spec = params.guided_decoding
    if spec is None:
        return

    if spec.regex:
        validate_regex_is_buildable(spec.regex)
        return

    if spec.json:
        raw_schema = spec.json
        if isinstance(raw_schema, str):
            try:
                # make sure schema is valid json
                json.loads(raw_schema)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
            schema_text = raw_schema
        else:
            try:
                schema_text = json.dumps(raw_schema)
            except Exception as e:
                raise ValueError(
                    f"Error serializing guided decoding jsonschema: {e}"
                ) from e
        validate_regex_is_buildable(
            json_schema.build_regex_from_schema(schema_text))
        return

    if spec.choice:
        # Every choice is matched literally, so escape regex metacharacters.
        alternatives = "|".join(
            regex_escape(str(option)) for option in spec.choice)
        validate_regex_is_buildable("(" + alternatives + ")")
        return

    if spec.grammar:
        raise ValueError("Outlines guided decoding backend "
                         "does not support grammar specifications")
|
||||||
|
|
||||||
|
|
||||||
|
def _prefix_needs_context(parsed) -> bool:
    """Return True if there's a look-around/anchor before any consumer.

    Walks the parsed pattern from the left and stops at the first construct
    that can consume a character. If an anchor (``AT``) or look-around
    (``ASSERT`` / ``ASSERT_NOT``) is met before that point, the pattern's
    start depends on surrounding context and True is returned.

    Args:
        parsed: An ``sre_parse`` ``SubPattern`` or a plain sequence of
            ``(opcode, argument)`` pairs.
    """
    # Opcodes that always consume exactly one character. NOT_LITERAL is what
    # the parser emits for a single-character negated class such as [^a].
    consuming_ops = (sre_parse.LITERAL, sre_parse.NOT_LITERAL, sre_parse.IN,
                     sre_parse.ANY)
    # Greedy and lazy quantifiers share the (min, max, subpattern) layout.
    repeat_ops = (sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT)

    def subpattern_consumes(parsed) -> bool:
        """Return True if subpattern can consume at least one character."""
        tokens = parsed.data if hasattr(parsed, 'data') else parsed
        for ttype, tval in tokens:
            # literal, character class, or dot always consumes
            if ttype in consuming_ops:
                return True
            # quantified subpattern: check inner pattern
            elif ttype in repeat_ops:
                _, mx, sub = tval
                if mx != 0 and subpattern_consumes(sub):
                    return True
            # alternation: if any branch consumes, the whole does
            elif ttype == sre_parse.BRANCH:
                _, branches = tval
                if any(subpattern_consumes(br) for br in branches):
                    return True
            # grouped subpattern: recurse into its contents
            elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(
                    tval[3]):
                return True
        # No consumers, return False
        return False

    tokens = parsed.data if hasattr(parsed, 'data') else parsed
    for ttype, tval in tokens:
        # Direct anchors or look-around
        if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT,
                                              sre_constants.ASSERT_NOT):
            return True

        # Nested subpattern: check
        if ttype == sre_parse.SUBPATTERN:
            # tval: (group, add_flags, del_flags, subpattern)
            if _prefix_needs_context(tval[3]):
                return True
            if subpattern_consumes(tval[3]):
                return False

        # if any branch has a prefix anchor => True,
        # else if at least one branch consumes => prefix ends => False
        elif ttype == sre_parse.BRANCH:
            saw_consumer = False
            for br in tval[1]:
                if _prefix_needs_context(br):
                    return True
                if subpattern_consumes(br):
                    saw_consumer = True
            if saw_consumer:
                return False

        # Immediate consumer tokens
        elif ttype in consuming_ops:
            return False

        # if subpattern has anchor => True, if it can consume => stop
        elif ttype in repeat_ops:
            if _prefix_needs_context(tval[2]):
                return True
            if subpattern_consumes(tval[2]):
                return False

    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _check_unsupported(parsed) -> None:
    """Check for regex features unsupported by regex-automata

    Recursively walks the parsed pattern, including alternation branches,
    groups, and both greedy and lazy quantifiers.

    Args:
        parsed: An ``sre_parse`` ``SubPattern`` or a plain sequence of
            ``(opcode, argument)`` pairs.

    Raises:
        ValueError: If a backreference, a look-around assertion, or a
            ``\\b`` / ``\\B`` word boundary appears anywhere in the pattern.
    """
    tokens = parsed.data if hasattr(parsed, 'data') else parsed
    for ttype, tval in tokens:

        # backreference
        if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
            raise ValueError("Backreferences are unsupported.")

        # look-around assertion
        elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
            raise ValueError("Look-Around assertion are unsupported.")

        # unicode word boundaries
        elif ttype == sre_parse.AT:
            if tval in (sre_constants.AT_BOUNDARY,
                        sre_constants.AT_NON_BOUNDARY):
                raise ValueError("Unicode word boundaries are unsupported.")

        elif ttype == sre_parse.BRANCH:
            # tval is (None, branches)
            for branch in tval[1]:
                _check_unsupported(branch)

        # tval is (min, max, subpattern); lazy quantifiers (MIN_REPEAT) use
        # the same layout and must be inspected too.
        elif ttype in (sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT):
            _check_unsupported(tval[2])

        # tval is (group, add_flags, del_flags, subpattern); without this
        # recursion anything inside parentheses would escape the check.
        elif ttype == sre_parse.SUBPATTERN:
            _check_unsupported(tval[3])
|
||||||
|
|
||||||
|
|
||||||
|
def validate_regex_is_buildable(pattern: str) -> None:
    """
    Validates that the input regex is not using unsupported features
    of the `regex-automata` crate (outlines_core regex engine) and has a
    universal start state.
    definition of universal start state used can be found at:
    https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state

    Args:
        pattern: The regex source to validate.

    Raises:
        ValueError: If the pattern cannot be parsed, uses an unsupported
            feature, or starts with an anchor/look-around that requires
            context before any character is consumed.
    """
    try:
        parsed = sre_parse.parse(pattern)

    except sre_constants.error as e:
        raise ValueError(f"Error parsing regex: {e}") from e

    try:
        _check_unsupported(parsed)
    except ValueError as e:
        raise ValueError(
            f"Regex uses unsupported feature for guided decoding: {e}. "
            "Only basic matching constructs are supported—lookarounds, "
            "backreferences, and unicode boundaries are not.") from e

    if _prefix_needs_context(parsed):
        # Adjacent literals are concatenated verbatim, so every fragment
        # carries its own trailing separator.
        raise ValueError(
            "Regex does not have an anchored universal start state. "
            "This means that the Regex uses anchors (^) or look-arounds "
            "in a way which requires context before any token is matched. "
            "Guided decoding needs regexes that can match without needing "
            "that context. Try rewriting the pattern without using these "
            f"constructs. Pattern:\n{pattern}")
|
||||||
Loading…
x
Reference in New Issue
Block a user