mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 02:27:12 +08:00
[V1] Add disable-any-whitespace option support for xgrammar (#15316)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
2f4bd358f1
commit
eb63ea1e18
@ -57,6 +57,50 @@ def test_guided_json_completion(
|
|||||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
|
||||||
|
def test_guided_json_completion_disable_any_whitespace(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
sample_json_schema: dict[str, Any],
|
||||||
|
guided_decoding_backend: str,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
if guided_decoding_backend != "xgrammar":
|
||||||
|
pytest.skip("disable-any-whitespace is only supported for xgrammar.")
|
||||||
|
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
|
||||||
|
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
llm = LLM(model=model_name,
|
||||||
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
|
||||||
|
outputs = llm.generate(prompts=[
|
||||||
|
f"Give an example JSON for an employee profile "
|
||||||
|
f"that fits this schema: {sample_json_schema}"
|
||||||
|
] * 2,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
assert "\n" not in generated_text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend",
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
GUIDED_DECODING_BACKENDS_V1)
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
@ -301,7 +345,6 @@ def test_guided_choice_completion(
|
|||||||
prompts="The best language for type-safe systems programming is ",
|
prompts="The best language for type-safe systems programming is ",
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True)
|
use_tqdm=True)
|
||||||
|
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
assert output is not None
|
assert output is not None
|
||||||
|
|||||||
@ -1486,7 +1486,9 @@ class EngineArgs:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Only support Xgrammar for guided decoding so far.
|
# Only support Xgrammar for guided decoding so far.
|
||||||
SUPPORTED_GUIDED_DECODING = ["xgrammar", "xgrammar:nofallback"]
|
SUPPORTED_GUIDED_DECODING = [
|
||||||
|
"xgrammar", "xgrammar:disable-any-whitespace"
|
||||||
|
]
|
||||||
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
|
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
|
||||||
_raise_or_fallback(feature_name="--guided-decoding-backend",
|
_raise_or_fallback(feature_name="--guided-decoding-backend",
|
||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
|
|||||||
@ -120,7 +120,7 @@ class Processor:
|
|||||||
if not params.guided_decoding or not self.decoding_config:
|
if not params.guided_decoding or not self.decoding_config:
|
||||||
return
|
return
|
||||||
|
|
||||||
supported_backends = ["xgrammar"]
|
supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
|
||||||
engine_level_backend = self.decoding_config.guided_decoding_backend
|
engine_level_backend = self.decoding_config.guided_decoding_backend
|
||||||
if engine_level_backend not in supported_backends:
|
if engine_level_backend not in supported_backends:
|
||||||
raise ValueError(f"Only {supported_backends} structured output is "
|
raise ValueError(f"Only {supported_backends} structured output is "
|
||||||
|
|||||||
@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend):
|
|||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig):
|
def __init__(self, vllm_config: VllmConfig):
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
self.disable_any_whitespace = (
|
||||||
|
"disable-any-whitespace"
|
||||||
|
in vllm_config.decoding_config.guided_decoding_backend)
|
||||||
tokenizer_group = init_tokenizer_from_configs(
|
tokenizer_group = init_tokenizer_from_configs(
|
||||||
model_config=vllm_config.model_config,
|
model_config=vllm_config.model_config,
|
||||||
scheduler_config=vllm_config.scheduler_config,
|
scheduler_config=vllm_config.scheduler_config,
|
||||||
@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend):
|
|||||||
def compile_grammar(self, request_type: StructuredOutputOptions,
|
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||||
grammar_spec: str) -> StructuredOutputGrammar:
|
grammar_spec: str) -> StructuredOutputGrammar:
|
||||||
if request_type == StructuredOutputOptions.JSON:
|
if request_type == StructuredOutputOptions.JSON:
|
||||||
ctx = self.compiler.compile_json_schema(grammar_spec,
|
ctx = self.compiler.compile_json_schema(
|
||||||
any_whitespace=False)
|
grammar_spec, any_whitespace=not self.disable_any_whitespace)
|
||||||
elif request_type == StructuredOutputOptions.JSON_OBJECT:
|
elif request_type == StructuredOutputOptions.JSON_OBJECT:
|
||||||
ctx = self.compiler.compile_builtin_json_grammar()
|
ctx = self.compiler.compile_builtin_json_grammar()
|
||||||
elif request_type == StructuredOutputOptions.GRAMMAR:
|
elif request_type == StructuredOutputOptions.GRAMMAR:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user