[V1] Support disable_any_whitespace for guidance backend (#15584)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent 541d1df486
commit 7329ff5468
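This adds `guidance:disable-any-whitespace`, mirroring the existing `xgrammar:disable-any-whitespace` option: when set, JSON grammars are compiled with `whitespace_flexible` off, so the model cannot pad structured output with arbitrary whitespace (without it, some models fill responses with runs of `\n`). A minimal usage sketch follows; the backend string, `LLM` keyword argument, and import paths are taken from the diff below, while the schema, prompt, and sampling values are illustrative placeholders.

# Usage sketch (illustrative): enable the guidance backend with
# any-whitespace disabled. Schema, prompt, and sampling values are
# placeholders; the backend string comes from this commit.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="guidance:disable-any-whitespace")
sampling_params = SamplingParams(
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate(prompts="Reply with a JSON object containing a name.",
                       sampling_params=sampling_params)

generated_text = outputs[0].outputs[0].text
assert "\n" not in generated_text  # compact JSON, no whitespace padding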
--- a/tests/entrypoints/llm/test_guided_generate.py
+++ b/tests/entrypoints/llm/test_guided_generate.py
@@ -6,7 +6,6 @@ import weakref
 
 import jsonschema
 import pytest
-from pydantic import BaseModel
 
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
@@ -15,7 +14,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 GUIDED_DECODING_BACKENDS = [
-    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+    "outlines",
+    "lm-format-enforcer",
+    "xgrammar:disable-any-whitespace",
+    "guidance:disable-any-whitespace",
 ]
 
 
@@ -322,59 +324,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
         print(generated_text)
         assert generated_text is not None
 
+        if 'disable-any-whitespace' in guided_decoding_backend:
+            assert "\n" not in generated_text
+
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)
-
-
-@pytest.mark.skip_global_cleanup
-def test_json_with_any_whitespace_disabled(llm):
-
-    class ResponseSchema(BaseModel):
-        clarifying_question: str
-        cost_per_serving: str
-        calories: str
-        type_dish_ids: str
-        type_meal_ids: str
-        product_ids: list[str]
-        exclude_product_ids: list[str]
-        allergen_ids: list[str]
-        total_cooking_time: str
-        kitchen_ids: str
-        holiday_ids: str
-
-    # Note: Without this setting, the response is sometimes full of `\n`
-    # for some models. This option prevents that.
-    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
-
-    schema = ResponseSchema.model_json_schema()
-    guided_params = GuidedDecodingParams(json=schema,
-                                         backend=\
-                                            guided_decoding_backend)
-    sampling_params = SamplingParams(max_tokens=2000,
-                                     frequency_penalty=0,
-                                     presence_penalty=-1.1,
-                                     repetition_penalty=1.3,
-                                     guided_decoding=guided_params)
-
-    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
-              "are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
-              "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
-    outputs = llm.generate(prompts=prompt,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        assert "\n" not in generated_text
-
-        # Parse to verify it is valid JSON
-        parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
-        jsonschema.validate(instance=parsed_json, schema=schema)
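As context for the `\n` assertion added above: disabling any-whitespace forces compact JSON, the same property that distinguishes `json.dumps` output with and without pretty-printing. A standalone illustration using only the stdlib:

# Compact JSON contains no newlines; pretty-printed JSON does. The new test
# assertion checks the former on model output.
import json

obj = {"name": "example"}
compact = json.dumps(obj, separators=(",", ":"))  # '{"name":"example"}'
pretty = json.dumps(obj, indent=2)                # spans multiple lines

assert "\n" not in compact
assert "\n" in pretty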
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -15,7 +15,9 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
+GUIDED_DECODING_BACKENDS_V1 = [
+    "xgrammar:disable-any-whitespace", "guidance:disable-any-whitespace"
+]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
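With the option folded into each backend name, every parametrized V1 case now runs with any-whitespace disabled; pytest expands the cross-product of this list with MODELS_TO_TEST. A standalone sketch of the resulting matrix:

# The two lists above yield one test case per (backend, model) pair,
# four combinations in total.
import itertools

backends = [
    "xgrammar:disable-any-whitespace", "guidance:disable-any-whitespace"
]
models = ["Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"]

for backend, model in itertools.product(backends, models):
    print(f"{model} with {backend}")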
@@ -55,50 +57,8 @@ def test_guided_json_completion(
 
         generated_text = output.outputs[0].text
         assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=sample_json_schema)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion_disable_any_whitespace(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_json_schema: dict[str, Any],
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    if guided_decoding_backend != "xgrammar":
-        pytest.skip("disable-any-whitespace is only supported for xgrammar.")
-    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
-
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        assert "\n" not in generated_text
+        if 'disable-any-whitespace' in guided_decoding_backend:
+            assert "\n" not in generated_text
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -142,7 +102,7 @@ def test_guided_json_object(
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
         allowed_types: tuple[type, ...] = (dict, )
-        if guided_decoding_backend == "xgrammar":
+        if guided_decoding_backend.startswith("xgrammar"):
             # TODO - we are currently too permissive with xgrammar and
             # allow # any valid json (typically comes back as a list or
             # object). We can fix this by specifying a jsonschema of
@@ -170,7 +130,7 @@ def test_guided_json_unsupported_schema(
         temperature=1.0,
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
-    if guided_decoding_backend == "xgrammar":
+    if guided_decoding_backend.startswith("xgrammar"):
         with pytest.raises(ValueError,
                            match="The provided JSON schema contains features "
                            "not supported by xgrammar."):
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1561,7 +1561,8 @@ class EngineArgs:
 
         # Xgrammar and Guidance are supported.
         SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
+            "guidance:disable-any-whitespace", "auto"
         ]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",
--- a/vllm/model_executor/guided_decoding/guidance_decoding.py
+++ b/vllm/model_executor/guided_decoding/guidance_decoding.py
@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
     """
 
     grm = ""
+    any_whitespace = 'disable-any-whitespace' not in \
+        guided_params.backend_options()
     if guided_params.json:
         grm = llguidance.LLMatcher.grammar_from_json_schema(
             guided_params.json,
-            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern},
+            defaults={
+                "whitespace_flexible": any_whitespace,
+            })
     elif guided_params.json_object:
         grm = llguidance.LLMatcher.grammar_from_json_schema(
             '{"type": "object"}',
-            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern},
+            defaults={
+                "whitespace_flexible": any_whitespace,
+            })
     elif guided_params.regex:
         grm = llguidance.grammar_from("regex", guided_params.regex)
     elif guided_params.choice:
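The `backend_options()` call above relies on the colon-separated backend syntax, `name:opt1,opt2`, with the option list after the first colon. The real parsing lives in `GuidedDecodingParams`; the hypothetical helper below only restates that convention for illustration:

# Hypothetical stand-in for GuidedDecodingParams.backend_options(), shown
# only to document the "backend:options" convention used in this commit.
def split_backend(backend: str) -> tuple[str, list[str]]:
    name, _, opts = backend.partition(":")
    return name, opts.split(",") if opts else []

assert split_backend("guidance:disable-any-whitespace") == \
    ("guidance", ["disable-any-whitespace"])
assert split_backend("xgrammar") == ("xgrammar", [])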
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -121,7 +121,8 @@ class Processor:
             return
 
         supported_backends = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
+            "guidance:disable-any-whitespace", "auto"
        ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
@@ -140,11 +141,10 @@ class Processor:
             raise ValueError("Structured output is not supported on TPU.")
 
         # Request content validation
-
-        if engine_level_backend == "xgrammar":
+        if engine_level_backend.startswith("xgrammar"):
             # xgrammar with no fallback
             validate_structured_output_request_xgrammar(params)
-            params.guided_decoding.backend = "xgrammar"
+            params.guided_decoding.backend = engine_level_backend
         elif engine_level_backend == "auto":
             # "auto" is an opt-in to opinionated behavior where we try to
             # choose a backend based on request contents. This is not the
@@ -158,12 +158,13 @@ class Processor:
                 # are not supported in xgrammar. Fall back to guidance.
                 params.guided_decoding.backend = "guidance"
 
-        if params.guided_decoding.backend == "guidance":
+        if engine_level_backend.startswith("guidance"):
             # TODO ideally we would have the LLTokenizer here as Lark syntax
             # allows <|special_token|> and similar, see
             # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
             # Without tokenizer these are disallowed in grammars.
             validate_guidance_grammar(params, tokenizer=None)
+            params.guided_decoding.backend = engine_level_backend
 
     def process_inputs(
         self,
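Once options ride along in the backend string, exact-match checks such as `engine_level_backend == "guidance"` stop firing, which is why the validation above moves to `startswith` and then writes the full engine-level string back into `params.guided_decoding.backend`. A standalone illustration of the failure mode:

# Why startswith(): an option-suffixed backend string is not equal to the
# bare backend name.
engine_level_backend = "guidance:disable-any-whitespace"

assert engine_level_backend != "guidance"           # old check would miss
assert engine_level_backend.startswith("guidance")  # new check matches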
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend):
         tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
+        self.disable_any_whitespace = (
+            "disable-any-whitespace"
+            in vllm_config.decoding_config.guided_decoding_backend)
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend):
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
-            request_type, grammar_spec)
+            request_type, grammar_spec, self.disable_any_whitespace)
 
         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar):
 
 
 def serialize_guidance_grammar(request_type: StructuredOutputOptions,
-                               grammar_spec: str) -> str:
+                               grammar_spec: str,
+                               disable_any_whitespace: bool = False) -> str:
     if request_type == StructuredOutputOptions.JSON:
-        # TODO: make whitespace_flexible configurable
         return llguidance.LLMatcher.grammar_from_json_schema(
-            grammar_spec, defaults={
-                "whitespace_flexible": True,
+            grammar_spec,
+            defaults={
+                "whitespace_flexible": not disable_any_whitespace,
             })
     elif request_type == StructuredOutputOptions.JSON_OBJECT:
         return llguidance.LLMatcher.grammar_from_json_schema(
-            '{"type": "object"}', defaults={
-                "whitespace_flexible": True,
+            '{"type": "object"}',
+            defaults={
+                "whitespace_flexible": not disable_any_whitespace,
             })
     else:
         if request_type == StructuredOutputOptions.REGEX:
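For reference, the new parameter in isolation: `defaults` seeds JSON-schema compilation options that the schema itself does not set, so `whitespace_flexible` remains on unless disable-any-whitespace was requested. A minimal sketch, assuming `llguidance` is installed; `make_json_grammar` is an illustrative wrapper, not part of this commit:

# Compile a JSON grammar the way serialize_guidance_grammar now does,
# with whitespace flexibility keyed off the disable flag.
import llguidance


def make_json_grammar(schema: str, disable_any_whitespace: bool) -> str:
    return llguidance.LLMatcher.grammar_from_json_schema(
        schema, defaults={"whitespace_flexible": not disable_any_whitespace})


# With the flag set, the resulting grammar only accepts compact JSON.
grammar = make_json_grammar('{"type": "object"}', disable_any_whitespace=True)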