From a6977dbd1531378456725e5cdb151c88a33df52a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 29 Apr 2025 20:02:23 +0100 Subject: [PATCH] Simplify (and fix) passing of guided decoding backend options (#17008) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- ...enai_chat_completion_structured_outputs.py | 7 +- tests/entrypoints/llm/test_guided_generate.py | 206 +++++++++++------- .../model_executor/test_guided_processors.py | 15 +- .../llm/test_struct_output_generate.py | 31 +-- tests/v1/test_oracle.py | 3 +- vllm/config.py | 70 +++++- vllm/engine/arg_utils.py | 36 ++- vllm/engine/llm_engine.py | 2 +- vllm/engine/multiprocessing/client.py | 4 +- .../guided_decoding/__init__.py | 24 +- .../guided_decoding/guidance_decoding.py | 5 +- .../guided_decoding/xgrammar_decoding.py | 12 +- vllm/sampling_params.py | 56 +++-- vllm/v1/engine/processor.py | 8 +- vllm/v1/structured_output/__init__.py | 8 +- vllm/v1/structured_output/backend_guidance.py | 25 +-- vllm/v1/structured_output/backend_xgrammar.py | 14 +- 17 files changed, 309 insertions(+), 217 deletions(-) diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index f71162e36efd2..9c57af1c158c1 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -112,8 +112,8 @@ def extra_backend_options_completion(client: OpenAI, model: str): "alan.turing@enigma.com\n") try: - # The no-fallback option forces vLLM to use xgrammar, so when it fails - # you get a 400 with the reason why + # The guided_decoding_disable_fallback option forces vLLM to use + # xgrammar, so when it fails you get a 400 with the reason why completion = client.chat.completions.create( model=model, messages=[{ @@ -123,7 +123,8 @@ def extra_backend_options_completion(client: OpenAI, model: str): extra_body={ "guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"], - "guided_decoding_backend": "xgrammar:no-fallback" + "guided_decoding_backend": "xgrammar", + "guided_decoding_disable_fallback": True, }, ) return completion.choices[0].message.content diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index ad726fa8ce518..fdbdccd4654c1 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -16,10 +16,11 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" GUIDED_DECODING_BACKENDS = [ - "outlines", - "lm-format-enforcer", - "xgrammar:disable-any-whitespace", - "guidance:disable-any-whitespace", + # (backend, disable_any_whitespace), + ("outlines", False), + ("lm-format-enforcer", False), + ("xgrammar", True), + ("guidance", True), ] @@ -36,13 +37,17 @@ def llm(): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_regex(sample_regex, llm, guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams( - regex=sample_regex, - backend=guided_decoding_backend)) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + 
temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + regex=sample_regex, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2, @@ -62,14 +67,18 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_json_completion(sample_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example JSON for an employee profile " f"that fits this schema: {sample_json_schema}" @@ -92,14 +101,18 @@ def test_guided_json_completion(sample_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_complex_json_completion(sample_complex_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_complex_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_complex_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example JSON for an assignment grade " f"that fits this schema: {sample_complex_json_schema}" @@ -123,14 +136,18 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_definition_json_completion(sample_definition_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_definition_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_definition_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example JSON for solving 8x + 7 = -23 " f"that fits this schema: {sample_definition_json_schema}" @@ -154,14 +171,18 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", 
GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_enum_json_completion(sample_enum_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_enum_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_enum_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ "Create a bug report JSON that fits this schema: " f"{sample_enum_json_schema}. Make it for a high priority critical bug." @@ -195,14 +216,18 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_choice_completion(sample_guided_choice, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams( - choice=sample_guided_choice, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + choice=sample_guided_choice, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, @@ -221,15 +246,19 @@ def test_guided_choice_completion(sample_guided_choice, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_grammar(sample_sql_statements, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - grammar=sample_sql_statements, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_statements, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts=("Generate a sql state that select col_1 from " "table_1 where it is equals to 1"), @@ -300,7 +329,8 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): top_p=0.95, guided_decoding=GuidedDecodingParams( json=unsupported_json, - backend="xgrammar:no-fallback")) + backend="xgrammar", + disable_fallback=True)) with pytest.raises( ValueError, @@ -312,14 +342,18 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_json_object(llm, guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=100, - n=2, - guided_decoding=GuidedDecodingParams( - json_object=True, - 
backend=guided_decoding_backend)) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_json_object(llm, guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=100, + n=2, + guided_decoding=GuidedDecodingParams( + json_object=True, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts=("Generate a JSON object with curly braces for a person with " @@ -337,7 +371,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str): print(generated_text) assert generated_text is not None - if 'disable-any-whitespace' in guided_decoding_backend: + if disable_any_whitespace: assert "\n" not in generated_text # Parse to verify it is valid JSON @@ -359,14 +393,18 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str): +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, + disable_any_whitespace: bool): json_schema = CarDescription.model_json_schema() - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=json_schema, - backend=guided_decoding_backend)) + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts="Generate a JSON with the brand, model and car_type of" "the most iconic car from the 90's", @@ -387,9 +425,10 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_number_range_json_completion(llm, - guided_decoding_backend: str): +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, + disable_any_whitespace: bool): sample_output_schema = { "type": "object", "properties": { @@ -413,8 +452,10 @@ def test_guided_number_range_json_completion(llm, sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=sample_output_schema, - backend=guided_decoding_backend), + guided_decoding=GuidedDecodingParams( + json=sample_output_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace), ) outputs = llm.generate( prompts=[ @@ -466,8 +507,12 @@ def test_guidance_no_additional_properties(llm): "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" "<|im_end|>\n<|im_start|>assistant\n") - def generate_with_backend(backend): - guided_params = GuidedDecodingParams(json=schema, backend=backend) + def generate_with_backend(backend, disable_additional_properties): + guided_params = GuidedDecodingParams( + json=schema, + backend=backend, + disable_any_whitespace=True, + disable_additional_properties=disable_additional_properties) sampling_params = SamplingParams(temperature=0, max_tokens=256, guided_decoding=guided_params) @@ -481,7 +526,7 @@ def 
test_guidance_no_additional_properties(llm): jsonschema.validate(instance=parsed_json, schema=schema) return parsed_json - base_generated = generate_with_backend('guidance:disable-any-whitespace') + base_generated = generate_with_backend("guidance", False) assert "a1" in base_generated assert "a2" in base_generated assert "a3" in base_generated @@ -490,8 +535,7 @@ def test_guidance_no_additional_properties(llm): assert "a5" in base_generated assert "a6" in base_generated - generated = generate_with_backend( - 'guidance:no-additional-properties,disable-any-whitespace') + generated = generate_with_backend("guidance", True) assert "a1" in generated assert "a2" in generated assert "a3" in generated diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 59da575e37b18..6cd966f84802b 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -202,12 +202,15 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): def test_guided_decoding_backend_options(): """Test backend-specific options""" - params = GuidedDecodingParams( - backend="xgrammar:option-1,option-2,option-3") - assert params.backend_options() == ["option-1", "option-2", "option-3"] - - no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback") - assert no_fallback.no_fallback() + with pytest.warns(DeprecationWarning): + guided_decoding_params = GuidedDecodingParams( + backend= + "xgrammar:no-fallback,disable-any-whitespace,no-additional-properties" + ) + assert guided_decoding_params.backend == "xgrammar" + assert guided_decoding_params.disable_fallback + assert guided_decoding_params.disable_any_whitespace + assert guided_decoding_params.disable_additional_properties def test_pickle_xgrammar_tokenizer_data(): diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 3de4fec9c9019..29ec6088ee8b9 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -17,15 +17,12 @@ from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace", - "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace", - "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace", - "mistral"), - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"), + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"), + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"), + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"), #FIXME: This test is flaky on CI thus disabled - #("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"), + #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ] PARAMS_MODELS_TOKENIZER_MODE = [ @@ -73,6 +70,7 @@ def test_structured_output( enforce_eager=enforce_eager, max_model_len=1024, guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=True, tokenizer_mode=tokenizer_mode) # @@ -98,8 +96,7 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - if 'disable-any-whitespace' in guided_decoding_backend: - assert "\n" not 
in generated_text + assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=sample_json_schema) @@ -520,10 +517,11 @@ def test_structured_output_auto_mode( def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_V1", "1") - backend = 'guidance:no-additional-properties,disable-any-whitespace' llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024, - guided_decoding_backend=backend) + guided_decoding_backend="guidance", + guided_decoding_disable_any_whitespace=True, + guided_decoding_disable_additional_properties=True) schema = { 'type': 'object', @@ -548,7 +546,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): "<|im_end|>\n<|im_start|>assistant\n") def generate_with_backend(backend): - guided_params = GuidedDecodingParams(json=schema, backend=backend) + guided_params = GuidedDecodingParams( + json=schema, + backend=backend, + disable_any_whitespace=True, + disable_additional_properties=True) sampling_params = SamplingParams(temperature=0, max_tokens=256, guided_decoding=guided_params) @@ -562,8 +564,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): jsonschema.validate(instance=parsed_json, schema=schema) return parsed_json - generated = generate_with_backend( - 'guidance:no-additional-properties,disable-any-whitespace') + generated = generate_with_backend("guidance") assert "a1" in generated assert "a2" in generated assert "a3" in generated diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 1448641f6a570..94c8ad7c94f64 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -57,7 +57,8 @@ def test_unsupported_configs(monkeypatch): with pytest.raises(NotImplementedError): AsyncEngineArgs( model=MODEL, - guided_decoding_backend="lm-format-enforcer:no-fallback", + guided_decoding_backend="lm-format-enforcer", + guided_decoding_disable_fallback=True, ).create_engine_config() with pytest.raises(NotImplementedError): diff --git a/vllm/config.py b/vllm/config.py index 8f927835d2d45..abe59734e2d6b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,12 +17,14 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass, from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, - Optional, Protocol, TypeVar, Union, get_args, get_origin) + Optional, Protocol, TypeVar, Union, cast, get_args, + get_origin) import torch from pydantic import BaseModel, Field, PrivateAttr from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig +from typing_extensions import deprecated import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass @@ -32,7 +34,6 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import CpuArchEnum, current_platform -from vllm.sampling_params import GuidedDecodingParams from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -344,7 +345,7 @@ class ModelConfig: def __init__( self, model: str, - task: Union[TaskOption, Literal["draft"]], + task: Literal[TaskOption, Literal["draft"]], 
tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, @@ -701,7 +702,7 @@ class ModelConfig: def _resolve_task( self, - task_option: Union[TaskOption, Literal["draft"]], + task_option: Literal[TaskOption, Literal["draft"]], ) -> tuple[set[_ResolvedTask], _ResolvedTask]: if task_option == "draft": return {"draft"}, "draft" @@ -3185,13 +3186,36 @@ GuidedDecodingBackend = Literal[GuidedDecodingBackendV0, class DecodingConfig: """Dataclass which contains the decoding strategy of the engine.""" - guided_decoding_backend: GuidedDecodingBackend = \ - "auto" if envs.VLLM_USE_V1 else "xgrammar" + @property + @deprecated( + "`guided_decoding_backend` is deprecated and has been renamed to " + "`backend`. This will be removed in v0.10.0. Please use the " + "`backend` argument instead.") + def guided_decoding_backend(self) -> GuidedDecodingBackend: + return self.backend + + @guided_decoding_backend.setter + def guided_decoding_backend(self, value: GuidedDecodingBackend): + self.backend = value + + backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar" """Which engine will be used for guided decoding (JSON schema / regex etc) by default. With "auto", we will make opinionated choices based on request contents and what the backend libraries currently support, so the behavior is subject to change in each release.""" + disable_fallback: bool = False + """If `True`, vLLM will not fallback to a different backend on error.""" + + disable_any_whitespace: bool = False + """If `True`, the model will not generate any whitespace during guided + decoding. This is only supported for xgrammar and guidance backends.""" + + disable_additional_properties: bool = False + """If `True`, the `guidance` backend will not use `additionalProperties` + in the JSON schema. This is only supported for the `guidance` backend and + is used to better align its behaviour with `outlines` and `xgrammar`.""" + reasoning_backend: Optional[str] = None """Select the reasoning parser depending on the model that you're using. This is used to parse the reasoning content into OpenAI API format. @@ -3217,15 +3241,41 @@ class DecodingConfig: return hash_str def __post_init__(self): - backend = GuidedDecodingParams( - backend=self.guided_decoding_backend).backend_name + if ":" in self.backend: + self._extract_backend_options() + if envs.VLLM_USE_V1: valid_guided_backends = get_args(GuidedDecodingBackendV1) else: valid_guided_backends = get_args(GuidedDecodingBackendV0) - if backend not in valid_guided_backends: - raise ValueError(f"Invalid guided_decoding_backend '{backend}'," + if self.backend not in valid_guided_backends: + raise ValueError(f"Invalid backend '{self.backend}'," f" must be one of {valid_guided_backends}") + if (self.disable_any_whitespace + and self.backend not in ("xgrammar", "guidance")): + raise ValueError("disable_any_whitespace is only supported for " + "xgrammar and guidance backends.") + if (self.disable_additional_properties and self.backend != "guidance"): + raise ValueError("disable_additional_properties is only supported " + "for the guidance backend.") + + @deprecated( + "Passing guided decoding backend options inside backend in the format " + "'backend:...' is deprecated. This will be removed in v0.10.0. 
Please " + "use the dedicated arguments '--disable-fallback', " + "'--disable-any-whitespace' and '--disable-additional-properties' " + "instead.") + def _extract_backend_options(self): + """Extract backend options from the backend string.""" + backend, options = self.backend.split(":") + self.backend = cast(GuidedDecodingBackend, backend) + options_set = set(options.strip().split(",")) + if "no-fallback" in options_set: + self.disable_fallback = True + if "disable-any-whitespace" in options_set: + self.disable_any_whitespace = True + if "no-additional-properties" in options_set: + self.disable_additional_properties = True @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fe688025f9b1d..be0cd4d3a20da 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -18,9 +18,9 @@ from vllm import version from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, DecodingConfig, Device, DeviceConfig, DistributedExecutorBackend, - GuidedDecodingBackendV1, HfOverrides, - KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ModelImpl, MultiModalConfig, + GuidedDecodingBackend, GuidedDecodingBackendV1, + HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, PromptAdapterConfig, SchedulerConfig, SchedulerPolicy, SpeculativeConfig, @@ -317,7 +317,12 @@ class EngineArgs: bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input - guided_decoding_backend: str = DecodingConfig.guided_decoding_backend + guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend + guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback + guided_decoding_disable_any_whitespace: bool = \ + DecodingConfig.disable_any_whitespace + guided_decoding_disable_additional_properties: bool = \ + DecodingConfig.disable_additional_properties logits_processor_pattern: Optional[str] = None speculative_config: Optional[Dict[str, Any]] = None @@ -498,9 +503,17 @@ class EngineArgs: title="DecodingConfig", description=DecodingConfig.__doc__, ) + guided_decoding_group.add_argument("--guided-decoding-backend", + **guided_decoding_kwargs["backend"]) guided_decoding_group.add_argument( - '--guided-decoding-backend', - **guided_decoding_kwargs["guided_decoding_backend"]) + "--guided-decoding-disable-fallback", + **guided_decoding_kwargs["disable_fallback"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-any-whitespace", + **guided_decoding_kwargs["disable_any_whitespace"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-additional-properties", + **guided_decoding_kwargs["disable_additional_properties"]) guided_decoding_group.add_argument( "--reasoning-parser", # This choices is a special case because it's not static @@ -1244,7 +1257,11 @@ class EngineArgs: if self.enable_prompt_adapter else None decoding_config = DecodingConfig( - guided_decoding_backend=self.guided_decoding_backend, + backend=self.guided_decoding_backend, + disable_fallback=self.guided_decoding_disable_fallback, + disable_any_whitespace=self.guided_decoding_disable_any_whitespace, + disable_additional_properties=\ + self.guided_decoding_disable_additional_properties, reasoning_backend=self.reasoning_parser if self.enable_reasoning else None, ) @@ -1335,9 +1352,8 @@ class EngineArgs: recommend_to_remove=True) return 
False - # remove backend options when doing this check - if self.guided_decoding_backend.split(':')[0] \ - not in get_args(GuidedDecodingBackendV1): + if self.guided_decoding_backend not in get_args( + GuidedDecodingBackendV1): _raise_or_fallback( feature_name= f"--guided-decoding-backend={self.guided_decoding_backend}", diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c235309906116..38f13d859e589 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2091,7 +2091,7 @@ class LLMEngine: tokenizer = self.get_tokenizer(lora_request=lora_request) guided_decoding.backend = guided_decoding.backend or \ - self.decoding_config.guided_decoding_backend + self.decoding_config.backend if self.decoding_config.reasoning_backend is not None: logger.debug("Building with reasoning backend %s", diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index eb3ae89394ecc..d23a4c6ed598e 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -615,9 +615,9 @@ class MQLLMEngineClient(EngineClient): build_guided_decoding_logits_processor_async( sampling_params=params, tokenizer=await self.get_tokenizer(lora_request), - default_guided_backend=(self.decoding_config.guided_decoding_backend + default_guided_backend=(self.decoding_config.backend if self.decoding_config - else DecodingConfig.guided_decoding_backend), + else DecodingConfig.backend), model_config=self.model_config, reasoning_backend=self.decoding_config.reasoning_backend, ) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 8fdcdcafa9806..4e4d697f49a95 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -26,8 +26,8 @@ def maybe_backend_fallback( def fallback_or_error(guided_params: GuidedDecodingParams, message: str, fallback: str) -> None: """Change the backend to the specified fallback with a warning log, - or raise a ValueError if the `no-fallback` option is specified.""" - if guided_params.no_fallback(): + or raise a ValueError if the `disable_fallback` option is specified.""" + if guided_params.disable_fallback: raise ValueError(message) logger.warning("%s Falling back to use %s instead.", message, fallback) @@ -40,7 +40,7 @@ def maybe_backend_fallback( guided_params.backend = "xgrammar" # lm-format-enforce doesn't support grammar, fallback to xgrammar - if guided_params.backend_name == "lm-format-enforcer": + if guided_params.backend == "lm-format-enforcer": if guided_params.grammar is not None: fallback_or_error( guided_params, @@ -55,7 +55,7 @@ def maybe_backend_fallback( "lm-format-enforcer does not support advanced JSON schema " "features like patterns or numeric ranges.", "outlines") - if guided_params.backend_name == "xgrammar": + if guided_params.backend == "xgrammar": from vllm.model_executor.guided_decoding.xgrammar_decoding import ( xgr_installed) @@ -87,7 +87,7 @@ def maybe_backend_fallback( guided_params, "xgrammar module cannot be imported successfully.", "outlines") - if (guided_params.backend_name == "outlines" + if (guided_params.backend == "outlines" and guided_params.json_object is not None): # outlines doesn't support json_object, fallback to guidance fallback_or_error(guided_params, @@ -111,7 +111,7 @@ async def get_guided_decoding_logits_processor( guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if 
guided_params.backend_name == 'outlines': + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) @@ -122,12 +122,12 @@ async def get_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) - if guided_params.backend_name == 'xgrammar': + if guided_params.backend == 'xgrammar': from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( guided_params, tokenizer, model_config, reasoner) - if guided_params.backend_name == 'guidance': + if guided_params.backend == 'guidance': from vllm.model_executor.guided_decoding.guidance_decoding import ( get_local_guidance_guided_decoding_logits_processor) return get_local_guidance_guided_decoding_logits_processor( @@ -152,23 +152,23 @@ def get_local_guided_decoding_logits_processor( reasoner = reasoner_class(tokenizer) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend_name == 'outlines': + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_params, tokenizer, reasoner) - if guided_params.backend_name == 'lm-format-enforcer': + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) - if guided_params.backend_name == 'xgrammar': + if guided_params.backend == 'xgrammar': from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( guided_params, tokenizer, model_config, reasoner) - if guided_params.backend_name == 'guidance': + if guided_params.backend == 'guidance': from vllm.model_executor.guided_decoding.guidance_decoding import ( get_local_guidance_guided_decoding_logits_processor) return get_local_guidance_guided_decoding_logits_processor( diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py index 95b7c71107aab..0b1f4762bc730 100644 --- a/vllm/model_executor/guided_decoding/guidance_decoding.py +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -21,13 +21,12 @@ def get_local_guidance_guided_decoding_logits_processor( """ grm = "" - any_whitespace = 'disable-any-whitespace' not in \ - guided_params.backend_options() + any_whitespace = not guided_params.disable_any_whitespace if (guide_json := guided_params.json) is not None: # Optionally set additionalProperties to False at the top-level # By default, other backends do not allow additional top-level # properties, so this makes guidance more similar to other backends - if 'no-additional-properties' in guided_params.backend_options(): + if guided_params.disable_additional_properties: if not 
isinstance(guide_json, str): guide_json = json.dumps(guide_json) guide_json = process_for_additional_properties(guide_json) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index ff223c3c9b83e..40f722410ab03 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -175,8 +175,7 @@ class GrammarConfig: else: json_str = guided_params.json - any_whitespace = 'disable-any-whitespace' not in \ - guided_params.backend_options() + any_whitespace = not guided_params.disable_any_whitespace # Check and log if model with xgrammar and whitespace have history # of runaway generation of whitespaces. @@ -191,11 +190,10 @@ class GrammarConfig: model_with_warn = 'Qwen' if model_with_warn is not None and any_whitespace: - msg = (f"{model_with_warn} " - f"model detected, consider set " - f"`guided_backend=xgrammar:disable-any-whitespace` " - f"to prevent runaway generation of whitespaces.") - logger.info_once(msg) + logger.info_once( + "%s model detected, consider setting " + "`disable_any_whitespace` to prevent runaway generation " + "of whitespaces.", model_with_warn) # Validate the schema and raise ValueError here if it is invalid. # This is to avoid exceptions in model execution, which will crash # the engine worker process. diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c430b74a9db9a..511571d05b7ad 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -8,6 +8,7 @@ from typing import Annotated, Any, Optional, Union import msgspec from pydantic import BaseModel +from typing_extensions import deprecated from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor @@ -37,6 +38,10 @@ class GuidedDecodingParams: json_object: Optional[bool] = None """These are other options that can be set""" backend: Optional[str] = None + backend_was_auto: bool = False + disable_fallback: bool = False + disable_any_whitespace: bool = False + disable_additional_properties: bool = False whitespace_pattern: Optional[str] = None structural_tag: Optional[str] = None @@ -68,36 +73,6 @@ class GuidedDecodingParams: structural_tag=structural_tag, ) - @property - def backend_name(self) -> str: - """Return the backend name without any options. 
- - For example if the backend is "xgrammar:no-fallback", returns "xgrammar" - """ - return (self.backend or "").split(":")[0] - - def backend_options(self) -> list[str]: - """Return the backend options as a list of strings.""" - if not self.backend or ":" not in self.backend: - return [] - return self.backend.split(":")[1].split(",") - - def add_option(self, opt_name: str) -> None: - """Adds an option to the backend options.""" - if not self.backend: - self.backend = f":{opt_name}" - elif ":" not in self.backend: - self.backend += f":{opt_name}" - else: - options = set(self.backend_options()) - options.add(opt_name) - self.backend = f"{self.backend_name}:{','.join(sorted(options))}" - - def no_fallback(self) -> bool: - """Returns True if the "no-fallback" option is supplied for the guided - decoding backend""" - return "no-fallback" in self.backend_options() - def __post_init__(self): """Validate that some fields are mutually exclusive.""" guide_count = sum([ @@ -109,6 +84,27 @@ class GuidedDecodingParams: "You can only use one kind of guided decoding but multiple are " f"specified: {self.__dict__}") + if self.backend is not None and ":" in self.backend: + self._extract_backend_options() + + @deprecated( + "Passing guided decoding backend options inside backend in the format " + "'backend:...' is deprecated. This will be removed in v0.10.0. Please " + "use the dedicated arguments '--disable-fallback', " + "'--disable-any-whitespace' and '--disable-additional-properties' " + "instead.") + def _extract_backend_options(self): + """Extract backend options from the backend string.""" + assert isinstance(self.backend, str) + self.backend, options = self.backend.split(":") + options_set = set(options.strip().split(",")) + if "no-fallback" in options_set: + self.disable_fallback = True + if "disable-any-whitespace" in options_set: + self.disable_any_whitespace = True + if "no-additional-properties" in options_set: + self.disable_additional_properties = True + class RequestOutputKind(Enum): # Return entire output so far in every RequestOutput diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 5c15e8baef2bf..8ae5d01574c29 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -144,7 +144,7 @@ class Processor: if not params.guided_decoding or not self.decoding_config: return - engine_level_backend = self.decoding_config.guided_decoding_backend + engine_level_backend = self.decoding_config.backend if params.guided_decoding.backend: # Request-level backend selection is not supported in V1. # The values may differ if `params` is reused and was set @@ -152,8 +152,8 @@ class Processor: # request. We remember that it was set as a result of `auto` # using the `_auto` option set on the backend in the params. if (params.guided_decoding.backend != engine_level_backend - and not (engine_level_backend == "auto" and "_auto" - in params.guided_decoding.backend_options())): + and not (engine_level_backend == "auto" + and params.guided_decoding.backend_was_auto)): raise ValueError( "Request-level structured output backend selection is no " "longer supported. The request specified " @@ -189,7 +189,7 @@ class Processor: # are not supported in xgrammar. Fall back to guidance. 
params.guided_decoding.backend = "guidance" # Remember that this backend was set automatically - params.guided_decoding.add_option("_auto") + params.guided_decoding.backend_was_auto = True def process_inputs( self, diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 0fd66c0729602..47ae4c4f03ee9 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -45,17 +45,17 @@ class StructuredOutputManager: # NOTE: We only support a single backend. We do NOT support different # backends on a per-request basis in V1 (for now, anyway...). if self.backend is None: - backend_name = request.sampling_params.guided_decoding.backend_name - if backend_name == "xgrammar": + backend = request.sampling_params.guided_decoding.backend + if backend == "xgrammar": from vllm.v1.structured_output.backend_xgrammar import ( XgrammarBackend) self.backend = XgrammarBackend(self.vllm_config) - elif backend_name == "guidance": + elif backend == "guidance": self.backend = GuidanceBackend(self.vllm_config) else: raise ValueError( - f"Unsupported structured output backend: {backend_name}") + f"Unsupported structured output backend: {backend}") grammar = self.executor.submit(self._async_create_grammar, request) request.structured_output_request.grammar = grammar # type: ignore[assignment] diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index d4dc5e681e45c..8fb3e56bcb956 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -10,7 +10,7 @@ import torch from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -65,19 +65,10 @@ class GuidanceBackend(StructuredOutputBackend): self.vllm_config = vllm_config self.vocab_size = vllm_config.model_config.get_vocab_size() - self.disable_any_whitespace = False - self.no_additional_properties = False - backend_options = GuidedDecodingParams( - backend=vllm_config.decoding_config.guided_decoding_backend - ).backend_options() - for option in backend_options: - if option == "disable-any-whitespace": - self.disable_any_whitespace = True - elif option == "no-additional-properties": - self.no_additional_properties = True - else: - raise ValueError( - f"Unsupported option for the guidance backend: {option}") + self.disable_any_whitespace = \ + vllm_config.decoding_config.disable_any_whitespace + self.disable_additional_properties = \ + vllm_config.decoding_config.disable_additional_properties tokenizer = tokenizer_group.get_lora_tokenizer(None) self.ll_tokenizer = llguidance_hf.from_tokenizer( @@ -87,7 +78,7 @@ class GuidanceBackend(StructuredOutputBackend): grammar_spec: str) -> StructuredOutputGrammar: self.serialized_grammar = serialize_guidance_grammar( request_type, grammar_spec, self.disable_any_whitespace, - self.no_additional_properties) + self.disable_additional_properties) ll_matcher = llguidance.LLMatcher( self.ll_tokenizer, @@ -171,11 +162,11 @@ def serialize_guidance_grammar( request_type: StructuredOutputOptions, grammar_spec: Union[str, dict[str, Any]], disable_any_whitespace: bool = False, - no_additional_properties: bool = False, + 
disable_additional_properties: bool = False, ) -> str: def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str: - if no_additional_properties: + if disable_additional_properties: grammar_spec = process_for_additional_properties(grammar_spec) return llguidance.LLMatcher.grammar_from_json_schema( grammar_spec, diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index ecaeb6e4ee806..50a7d1683acd9 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -9,7 +9,7 @@ import torch import vllm.envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader @@ -37,16 +37,8 @@ class XgrammarBackend(StructuredOutputBackend): scheduler_config=vllm_config.scheduler_config, lora_config=vllm_config.lora_config) # type: ignore[arg-type] - self.disable_any_whitespace = False - backend_options = GuidedDecodingParams( - backend=vllm_config.decoding_config.guided_decoding_backend - ).backend_options() - for option in backend_options: - if option == "disable-any-whitespace": - self.disable_any_whitespace = True - else: - raise ValueError( - f"Unsupported option for the xgrammar backend: {option}") + self.disable_any_whitespace = \ + vllm_config.decoding_config.disable_any_whitespace tokenizer = tokenizer_group.get_lora_tokenizer(None) self.vocab_size = vllm_config.model_config.get_vocab_size()
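
Below is a minimal usage sketch of the dedicated options introduced by this
patch, mirroring the updated tests; the model name, prompt and schema are
illustrative only, not part of the change itself.

    from vllm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    # Engine-level backend selection: the old "xgrammar:disable-any-whitespace"
    # string form is replaced by a dedicated keyword argument.
    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        guided_decoding_backend="xgrammar",
        guided_decoding_disable_any_whitespace=True,
    )

    # Request-level options: the old "xgrammar:no-fallback" string form is
    # replaced by the disable_fallback flag on GuidedDecodingParams.
    # (Prompt and JSON schema here are illustrative.)
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=100,
        guided_decoding=GuidedDecodingParams(
            json={"type": "object",
                  "properties": {"name": {"type": "string"}}},
            backend="xgrammar",
            disable_fallback=True,
        ),
    )
    outputs = llm.generate("Give an example JSON with a name field: ",
                           sampling_params)
    print(outputs[0].outputs[0].text)

    # Equivalent server-side flags added by this patch:
    #   vllm serve Qwen/Qwen2.5-1.5B-Instruct \
    #     --guided-decoding-backend xgrammar \
    #     --guided-decoding-disable-any-whitespace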