[Frontend] Add backend-specific options for guided decoding (#13505)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent 6a417b8600
commit bfbc0b32c6

@@ -16,7 +16,7 @@ The following parameters are supported, which must be added as extra parameters:

 - `guided_json`: the output will follow the JSON schema.
 - `guided_grammar`: the output will follow the context-free grammar.
 - `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding.
-- `guided_decoding_backend`: used to select the guided decoding backend to use.
+- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma-separated list following a colon after the backend name. For example `"xgrammar:no-fallback"` will not allow vLLM to fall back to a different backend on error.

 You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server) page.
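
Illustration (not part of the diff): a minimal sketch of a request that uses the new colon syntax, assuming an OpenAI-compatible vLLM server is already running on localhost:8000 and serving Qwen/Qwen2.5-3B-Instruct; the model name, port, and schema below are placeholders, not part of this commit.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[{"role": "user", "content": "Give me a JSON person record."}],
    extra_body={
        # Constrain the output to this JSON schema.
        "guided_json": {"type": "object",
                        "properties": {"name": {"type": "string"}},
                        "required": ["name"]},
        # Select the backend, with an option after the colon: fail with a
        # 400 error instead of falling back to another backend.
        "guided_decoding_backend": "xgrammar:no-fallback",
    },
)
print(completion.choices[0].message.content)
```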

@@ -2,7 +2,7 @@

 from enum import Enum

-from openai import OpenAI
+from openai import BadRequestError, OpenAI
 from pydantic import BaseModel

 client = OpenAI(

@@ -94,3 +94,26 @@ completion = client.chat.completions.create(
     extra_body={"guided_grammar": simplified_sql_grammar},
 )
 print(completion.choices[0].message.content)
+
+# Extra backend options
+prompt = ("Generate an email address for Alan Turing, who works in Enigma. "
+          "End in .com and new line. Example result: "
+          "alan.turing@enigma.com\n")
+
+try:
+    # The no-fallback option forces vLLM to use xgrammar, so when it fails
+    # you get a 400 with the reason why
+    completion = client.chat.completions.create(
+        model="Qwen/Qwen2.5-3B-Instruct",
+        messages=[{
+            "role": "user",
+            "content": prompt,
+        }],
+        extra_body={
+            "guided_regex": r"\w+@\w+\.com\n",
+            "stop": ["\n"],
+            "guided_decoding_backend": "xgrammar:no-fallback"
+        },
+    )
+except BadRequestError as e:
+    print("This error is expected:", e)

@@ -280,6 +280,22 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
                      guided_options_request=dict(guided_regex=sample_regex))


+@pytest.mark.skip_global_cleanup
+def test_disable_guided_decoding_fallback(sample_regex, llm):
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     guided_decoding=GuidedDecodingParams(
+                                         regex=sample_regex,
+                                         backend="xgrammar:no-fallback"))
+
+    with pytest.raises(
+            ValueError,
+            match="xgrammar does not support regex guided decoding"):
+        llm.generate(prompts="This should fail",
+                     sampling_params=sampling_params,
+                     use_tqdm=True)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 def test_guided_json_object(llm, guided_decoding_backend: str):

@@ -109,6 +109,16 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
         GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")


+def test_guided_decoding_backend_options():
+    """Test backend-specific options"""
+    params = GuidedDecodingParams(
+        backend="xgrammar:option-1,option-2,option-3")
+    assert params.backend_options() == ["option-1", "option-2", "option-3"]
+
+    no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
+    assert no_fallback.no_fallback()
+
+
 def test_pickle_xgrammar_tokenizer_data():

     # TODO: move to another test file for xgrammar

@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import CpuArchEnum
+from vllm.sampling_params import GuidedDecodingParams
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,

@@ -2631,7 +2632,9 @@ class DecodingConfig:

     def __post_init__(self):
         valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
-        backend = self.guided_decoding_backend
+
+        backend = GuidedDecodingParams(
+            backend=self.guided_decoding_backend).backend_name
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                              f"must be one of {valid_guided_backends}")

@@ -372,14 +372,17 @@ class EngineArgs:
            '--guided-decoding-backend',
            type=str,
            default='xgrammar',
-            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines, '
            'https://github.com/mlc-ai/xgrammar, and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
-            ' parameter.')
+            ' parameter.\n'
+            'Backend-specific options can be supplied in a comma-separated '
+            'list following a colon after the backend name. Valid backends and '
+            'all available options are: [xgrammar:no-fallback, '
+            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,

@@ -22,47 +22,56 @@ logger = init_logger(__name__)

 def maybe_backend_fallback(
         guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
+
+    def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
+                          fallback: str) -> None:
+        """Change the backend to the specified fallback with a warning log,
+        or raise a ValueError if the `no-fallback` option is specified."""
+        if guided_params.no_fallback():
+            raise ValueError(message)
+
+        logger.warning("%s Falling back to use %s instead.", message, fallback)
+        guided_params.backend = fallback
+
     # lm-format-enforcer doesn't support grammar, fallback to xgrammar
-    if guided_params.backend == "lm-format-enforcer":
+    if guided_params.backend_name == "lm-format-enforcer":
         if guided_params.grammar is not None:
-            logger.warning(
-                "lm-format-enforcer does not support grammar guided decoding. "
-                "Falling back to use xgrammar instead.")
-            guided_params.backend = "xgrammar"
+            fallback_or_error(
+                guided_params,
+                "lm-format-enforcer does not support grammar guided decoding.",
+                "xgrammar")

         # lm-format-enforcer doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_lmf_unsupported_json_features(guided_params.json)):
-            logger.warning(
+            fallback_or_error(
+                guided_params,
                 "lm-format-enforcer does not support advanced JSON schema "
-                "features like patterns or numeric ranges. "
-                "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+                "features like patterns or numeric ranges.", "outlines")

-    if guided_params.backend == "xgrammar":
+    if guided_params.backend_name == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)
         # xgrammar only has x86 wheels for linux, fallback to outlines
         from vllm.platforms import current_platform
         if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
-            logger.warning("xgrammar is only supported on x86 CPUs. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(guided_params,
+                              "xgrammar is only supported on x86 CPUs.",
+                              "outlines")

         # xgrammar doesn't support regex, fallback to outlines
         if guided_params.regex is not None:
-            logger.warning("xgrammar does not support regex guided decoding. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(
+                guided_params,
+                "xgrammar does not support regex guided decoding.", "outlines")

         # xgrammar doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_xgrammar_unsupported_json_features(guided_params.json)):
-            logger.warning(
+            fallback_or_error(
+                guided_params,
                 "xgrammar does not support advanced JSON schema features like "
-                "patterns or numeric ranges. "
-                "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+                "enums, patterns or numeric ranges.", "outlines")

     # xgrammar only supports GBNF grammars, so we must convert Lark.
     # We must check if the grammar is likely Lark and if that

@@ -72,25 +81,23 @@ def maybe_backend_fallback(
             try:
                 convert_lark_to_gbnf(guided_params.grammar)
             except Exception:
-                logger.warning(
+                fallback_or_error(
+                    guided_params,
                     "xgrammar does not support Lark grammars and the "
-                    "grammar failed to convert to GBNF. "
-                    "Falling back to use outlines instead.")
-                guided_params.backend = "outlines"
+                    "grammar failed to convert to GBNF.", "outlines")

         # If the xgrammar module cannot be imported successfully,
         # we should still allow users to use guided decoding with a fallback.
         elif not xgr_installed:
-            logger.warning("xgrammar module cannot be imported successfully. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(
+                guided_params,
+                "xgrammar module cannot be imported successfully.", "outlines")

-    if (guided_params.backend == "outlines"
+    if (guided_params.backend_name == "outlines"
             and guided_params.json_object is not None):
         # outlines doesn't support json_object, fallback to xgrammar
-        logger.warning("outlines does not support json_object. "
-                       "Falling back to use xgrammar instead.")
-        guided_params.backend = "xgrammar"
+        fallback_or_error(guided_params,
+                          "outlines does not support json_object.", "xgrammar")

     return guided_params
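
Illustration (not part of the diff): putting the hunks above together, the intended behavior of `maybe_backend_fallback` can be sketched like this, assuming an x86 host so the earlier platform check passes.

```python
from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# Default behavior: xgrammar cannot handle regex, so the backend is switched
# to outlines with a warning log.
params = maybe_backend_fallback(
    GuidedDecodingParams(regex=r"\w+", backend="xgrammar"))
print(params.backend)  # -> "outlines"

# With the no-fallback option, the same situation raises instead of switching.
try:
    maybe_backend_fallback(
        GuidedDecodingParams(regex=r"\w+", backend="xgrammar:no-fallback"))
except ValueError as e:
    print("expected:", e)
```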

@@ -100,18 +107,18 @@ async def get_guided_decoding_logits_processor(
         model_config: ModelConfig) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines':
+    if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'lm-format-enforcer':
+    if guided_params.backend_name == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'xgrammar':
+    if guided_params.backend_name == 'xgrammar':
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(

@@ -127,18 +134,18 @@ def get_local_guided_decoding_logits_processor(
         model_config: ModelConfig) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines':
+    if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'lm-format-enforcer':
+    if guided_params.backend_name == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'xgrammar':
+    if guided_params.backend_name == 'xgrammar':
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(

@@ -64,6 +64,25 @@ class GuidedDecodingParams:
             whitespace_pattern=whitespace_pattern,
         )

+    @property
+    def backend_name(self) -> str:
+        """Return the backend name without any options.
+
+        For example if the backend is "xgrammar:no-fallback", returns "xgrammar"
+        """
+        return (self.backend or "").split(":")[0]
+
+    def backend_options(self) -> List[str]:
+        """Return the backend options as a list of strings."""
+        if not self.backend or ":" not in self.backend:
+            return []
+        return self.backend.split(":")[1].split(",")
+
+    def no_fallback(self) -> bool:
+        """Returns True if the "no-fallback" option is supplied for the guided
+        decoding backend"""
+        return "no-fallback" in self.backend_options()
+
     def __post_init__(self):
         """Validate that some fields are mutually exclusive."""
         guide_count = sum([