mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:25:01 +08:00
[Bugfix] Backend option to disable xgrammar any_whitespace (#12744)
Signed-off-by: Wallas Santos <wallashss@ibm.com> Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
d08b285adf
commit
4cb6fa0a9c
@ -6,6 +6,7 @@ import weakref
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.entrypoints.llm import LLM
|
||||
@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_json_with_any_whitespace_disabled(llm):
|
||||
|
||||
class ResponseSchema(BaseModel):
|
||||
clarifying_question: str
|
||||
cost_per_serving: str
|
||||
calories: str
|
||||
type_dish_ids: str
|
||||
type_meal_ids: str
|
||||
product_ids: list[str]
|
||||
exclude_product_ids: list[str]
|
||||
allergen_ids: list[str]
|
||||
total_cooking_time: str
|
||||
kitchen_ids: str
|
||||
holiday_ids: str
|
||||
|
||||
# Note: Without this setting, the response is sometimes full of `\n`
|
||||
# for some models. This option prevents that.
|
||||
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
|
||||
|
||||
schema = ResponseSchema.model_json_schema()
|
||||
guided_params = GuidedDecodingParams(json=schema,
|
||||
backend=\
|
||||
guided_decoding_backend)
|
||||
sampling_params = SamplingParams(max_tokens=2000,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=-1.1,
|
||||
repetition_penalty=1.3,
|
||||
guided_decoding=guided_params)
|
||||
|
||||
prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
|
||||
"are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
|
||||
"quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
|
||||
outputs = llm.generate(prompts=prompt,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
assert "\n" not in generated_text
|
||||
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
jsonschema.validate(instance=parsed_json, schema=schema)
|
||||
|
||||
@ -385,6 +385,7 @@ class EngineArgs:
|
||||
'Backend-specific options can be supplied in a comma-separated '
|
||||
'list following a colon after the backend name. Valid backends and '
|
||||
'all available options are: [xgrammar:no-fallback, '
|
||||
'xgrammar:disable-any-whitespace, '
|
||||
'outlines:no-fallback, lm-format-enforcer:no-fallback]')
|
||||
parser.add_argument(
|
||||
'--logits-processor-pattern',
|
||||
|
||||
@ -19,6 +19,7 @@ except ImportError:
|
||||
xgr_installed = False
|
||||
pass
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
|
||||
grammar_is_likely_lark)
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
@ -29,6 +30,8 @@ if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# TODO: passing batch size to max threads here
|
||||
def get_local_xgrammar_guided_decoding_logits_processor(
|
||||
@ -161,6 +164,7 @@ class GrammarConfig:
|
||||
json_str: str | None = None
|
||||
grammar_str: str | None = None
|
||||
json_object: bool | None = None
|
||||
any_whitespace: bool = True
|
||||
max_threads: int = 8
|
||||
tokenizer_data: TokenizerData | None = None
|
||||
|
||||
@ -180,11 +184,33 @@ class GrammarConfig:
|
||||
else:
|
||||
json_str = guided_params.json
|
||||
|
||||
any_whitespace = 'disable-any-whitespace' not in \
|
||||
guided_params.backend_options()
|
||||
|
||||
# Check and log if model with xgrammar and whitespace have history
|
||||
# of runaway generation of whitespaces.
|
||||
# References:
|
||||
# https://github.com/vllm-project/vllm/pull/12744
|
||||
# https://github.com/mlc-ai/xgrammar/issues/212
|
||||
model_with_warn = None
|
||||
|
||||
if 'Mistral' in model_config.model:
|
||||
model_with_warn = 'Mistral'
|
||||
elif 'Qwen' in model_config.model:
|
||||
model_with_warn = 'Qwen'
|
||||
|
||||
if model_with_warn is not None and any_whitespace:
|
||||
msg = (f"{model_with_warn} "
|
||||
f"model detected, consider set "
|
||||
f"`guided_backend=xgrammar:disable-any-whitespace` "
|
||||
f"to prevent runaway generation of whitespaces.")
|
||||
logger.info_once(msg)
|
||||
# Validate the schema and raise ValueError here if it is invalid.
|
||||
# This is to avoid exceptions in model execution, which will crash
|
||||
# the engine worker process.
|
||||
try:
|
||||
xgr.Grammar.from_json_schema(json_str)
|
||||
xgr.Grammar.from_json_schema(json_str,
|
||||
any_whitespace=any_whitespace)
|
||||
except RuntimeError as err:
|
||||
raise ValueError(str(err)) from err
|
||||
|
||||
@ -192,7 +218,8 @@ class GrammarConfig:
|
||||
vocab_size=model_config.hf_text_config.vocab_size,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data)
|
||||
tokenizer_data=tokenizer_data,
|
||||
any_whitespace=any_whitespace)
|
||||
elif guided_params.grammar:
|
||||
# XGrammar only supports GBNF grammars, so we must convert Lark
|
||||
if grammar_is_likely_lark(guided_params.grammar):
|
||||
@ -290,7 +317,10 @@ class XGrammarLogitsProcessor:
|
||||
if self.ctx is None:
|
||||
compiler = GrammarCompilerCache.get_compiler(self.config)
|
||||
if self.config.json_str is not None:
|
||||
self.ctx = compiler.compile_json_schema(self.config.json_str)
|
||||
any_whitespace = self.config.any_whitespace
|
||||
self.ctx = compiler\
|
||||
.compile_json_schema(self.config.json_str,
|
||||
any_whitespace=any_whitespace)
|
||||
elif self.config.grammar_str is not None:
|
||||
self.ctx = compiler.compile_grammar(self.config.grammar_str)
|
||||
elif self.config.json_object:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user