From 7329ff5468eceaf17f4b193ae3ef0b43c7bf38d6 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Fri, 28 Mar 2025 11:46:45 -0400
Subject: [PATCH] [V1] Support disable_any_whitespace for guidance backend
 (#15584)

Signed-off-by: Russell Bryant
---
 tests/entrypoints/llm/test_guided_generate.py | 62 +++----------------
 .../llm/test_struct_output_generate.py        | 54 +++-------------
 vllm/engine/arg_utils.py                      |  3 +-
 .../guided_decoding/guidance_decoding.py      | 12 +++-
 vllm/v1/engine/processor.py                   | 11 ++--
 vllm/v1/structured_output/backend_guidance.py | 19 +++---
 6 files changed, 44 insertions(+), 117 deletions(-)

diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py
index 5f1a91cb2b19..3f275e0b2ec7 100644
--- a/tests/entrypoints/llm/test_guided_generate.py
+++ b/tests/entrypoints/llm/test_guided_generate.py
@@ -6,7 +6,6 @@ import weakref
 
 import jsonschema
 import pytest
-from pydantic import BaseModel
 
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
@@ -15,7 +14,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 GUIDED_DECODING_BACKENDS = [
-    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+    "outlines",
+    "lm-format-enforcer",
+    "xgrammar:disable-any-whitespace",
+    "guidance:disable-any-whitespace",
 ]
 
 
@@ -322,59 +324,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
             print(generated_text)
             assert generated_text is not None
 
+            if 'disable-any-whitespace' in guided_decoding_backend:
+                assert "\n" not in generated_text
+
             # Parse to verify it is valid JSON
             parsed_json = json.loads(generated_text)
             assert isinstance(parsed_json, dict)
-
-
-@pytest.mark.skip_global_cleanup
-def test_json_with_any_whitespace_disabled(llm):
-
-    class ResponseSchema(BaseModel):
-        clarifying_question: str
-        cost_per_serving: str
-        calories: str
-        type_dish_ids: str
-        type_meal_ids: str
-        product_ids: list[str]
-        exclude_product_ids: list[str]
-        allergen_ids: list[str]
-        total_cooking_time: str
-        kitchen_ids: str
-        holiday_ids: str
-
-    # Note: Without this setting, the response is sometimes full of `\n`
-    # for some models. This option prevents that.
-    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
-
-    schema = ResponseSchema.model_json_schema()
-    guided_params = GuidedDecodingParams(json=schema,
-                                         backend=\
-                                            guided_decoding_backend)
-    sampling_params = SamplingParams(max_tokens=2000,
-                                     frequency_penalty=0,
-                                     presence_penalty=-1.1,
-                                     repetition_penalty=1.3,
-                                     guided_decoding=guided_params)
-
-    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
-              "are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
-              "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
-    outputs = llm.generate(prompts=prompt,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        assert "\n" not in generated_text
-
-        # Parse to verify it is valid JSON
-        parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
-        jsonschema.validate(instance=parsed_json, schema=schema)
diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index 00fa47575b6a..c9fa03a1ae1f 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -15,7 +15,9 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
+GUIDED_DECODING_BACKENDS_V1 = [
+    "xgrammar:disable-any-whitespace", "guidance:disable-any-whitespace"
+]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -55,50 +57,8 @@ def test_guided_json_completion(
 
         generated_text = output.outputs[0].text
         assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=sample_json_schema)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion_disable_any_whitespace(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_json_schema: dict[str, Any],
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    if guided_decoding_backend != "xgrammar":
-        pytest.skip("disable-any-whitespace is only supported for xgrammar.")
-    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
-
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        assert "\n" not in generated_text
+        if 'disable-any-whitespace' in guided_decoding_backend:
+            assert "\n" not in generated_text
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -142,7 +102,7 @@ def test_guided_json_object(
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
         allowed_types: tuple[type, ...] = (dict, )
-        if guided_decoding_backend == "xgrammar":
+        if guided_decoding_backend.startswith("xgrammar"):
             # TODO - we are currently too permissive with xgrammar and
             # allow # any valid json (typically comes back as a list or
             # object). We can fix this by specifying a jsonschema of
@@ -170,7 +130,7 @@ def test_guided_json_unsupported_schema(
         temperature=1.0,
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
-    if guided_decoding_backend == "xgrammar":
+    if guided_decoding_backend.startswith("xgrammar"):
         with pytest.raises(ValueError,
                            match="The provided JSON schema contains features "
                            "not supported by xgrammar."):
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a416fa8aa08e..6f498af36a40 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1561,7 +1561,8 @@ class EngineArgs:
 
         # Xgrammar and Guidance are supported.
         SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
+            "guidance:disable-any-whitespace", "auto"
         ]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",
diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py
index d8675a14030d..f19ebcbe420e 100644
--- a/vllm/model_executor/guided_decoding/guidance_decoding.py
+++ b/vllm/model_executor/guided_decoding/guidance_decoding.py
@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
     and get the necessary logits processor for the given guide.
     """
     grm = ""
+    any_whitespace = 'disable-any-whitespace' not in \
+        guided_params.backend_options()
     if guided_params.json:
         grm = llguidance.LLMatcher.grammar_from_json_schema(
             guided_params.json,
-            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern},
+            defaults={
+                "whitespace_flexible": any_whitespace,
+            })
     elif guided_params.json_object:
         grm = llguidance.LLMatcher.grammar_from_json_schema(
             '{"type": "object"}',
-            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern},
+            defaults={
+                "whitespace_flexible": any_whitespace,
+            })
     elif guided_params.regex:
         grm = llguidance.grammar_from("regex", guided_params.regex)
     elif guided_params.choice:
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 24762d214c34..dbaf0abaea18 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -121,7 +121,8 @@ class Processor:
             return
 
         supported_backends = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
+            "guidance:disable-any-whitespace", "auto"
         ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
@@ -140,11 +141,10 @@ class Processor:
             raise ValueError("Structured output is not supported on TPU.")
 
         # Request content validation
-
-        if engine_level_backend == "xgrammar":
+        if engine_level_backend.startswith("xgrammar"):
             # xgrammar with no fallback
             validate_structured_output_request_xgrammar(params)
-            params.guided_decoding.backend = "xgrammar"
+            params.guided_decoding.backend = engine_level_backend
         elif engine_level_backend == "auto":
             # "auto" is an opt-in to opinionated behavior where we try to
             # choose a backend based on request contents. This is not the
@@ -158,12 +158,13 @@ class Processor:
             # are not supported in xgrammar. Fall back to guidance.
             params.guided_decoding.backend = "guidance"
 
-        if params.guided_decoding.backend == "guidance":
+        if engine_level_backend.startswith("guidance"):
             # TODO ideally we would have the LLTokenizer here as Lark syntax
             # allows <|special_token|> and similar, see
             # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
             # Without tokenizer these are disallowed in grammars.
             validate_guidance_grammar(params, tokenizer=None)
+            params.guided_decoding.backend = engine_level_backend
 
     def process_inputs(
         self,
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
index 1e274ad0ae62..a7ba71016949 100644
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend):
         tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
+        self.disable_any_whitespace = (
+            "disable-any-whitespace"
+            in vllm_config.decoding_config.guided_decoding_backend)
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend):
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
-            request_type, grammar_spec)
+            request_type, grammar_spec, self.disable_any_whitespace)
 
         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar):
 
 
 def serialize_guidance_grammar(request_type: StructuredOutputOptions,
-                               grammar_spec: str) -> str:
+                               grammar_spec: str,
+                               disable_any_whitespace: bool = False) -> str:
     if request_type == StructuredOutputOptions.JSON:
-        # TODO: make whitespace_flexible configurable
         return llguidance.LLMatcher.grammar_from_json_schema(
-            grammar_spec, defaults={
-                "whitespace_flexible": True,
+            grammar_spec,
+            defaults={
+                "whitespace_flexible": not disable_any_whitespace,
             })
     elif request_type == StructuredOutputOptions.JSON_OBJECT:
         return llguidance.LLMatcher.grammar_from_json_schema(
-            '{"type": "object"}', defaults={
-                "whitespace_flexible": True,
+            '{"type": "object"}',
+            defaults={
+                "whitespace_flexible": not disable_any_whitespace,
             })
     else:
         if request_type == StructuredOutputOptions.REGEX:
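
Editor's usage sketch (not part of the patch): after this change, the
guidance backend accepts "guidance:disable-any-whitespace", which compiles
JSON schemas with whitespace_flexible=False so the model cannot emit
newlines between JSON tokens. A minimal driver under the assumption that
this patch is applied, using the same public API the tests above exercise
(model name, schema, and prompt are illustrative):

    import json

    from vllm.entrypoints.llm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    # Select the guidance backend with flexible whitespace disabled.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              guided_decoding_backend="guidance:disable-any-whitespace")

    # Any JSON schema works the same way; this one is illustrative.
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    }
    sampling_params = SamplingParams(
        max_tokens=200,
        guided_decoding=GuidedDecodingParams(json=schema))

    outputs = llm.generate(
        prompts=["Give a JSON object for John Smith, age 31."],
        sampling_params=sampling_params)

    text = outputs[0].outputs[0].text
    assert "\n" not in text  # compact JSON: no inter-token whitespace
    print(json.loads(text))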