[Bugfix] Backend option to disable xgrammar any_whitespace (#12744)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
parent d08b285adf
commit 4cb6fa0a9c
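Before the diff, a quick sketch of how the new option is used from the offline API. This is illustrative only; the model name and prompt are placeholder assumptions, not part of this commit:

    # Illustrative sketch, not from this commit. Assumes a schema-constrained
    # request through the offline LLM API; the model name is a placeholder.
    from pydantic import BaseModel

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams


    class Answer(BaseModel):
        city: str
        population: int


    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
    sampling_params = SamplingParams(
        max_tokens=256,
        guided_decoding=GuidedDecodingParams(
            json=Answer.model_json_schema(),
            # The option added by this commit: build the JSON grammar
            # without optional whitespace between structural tokens.
            backend="xgrammar:disable-any-whitespace",
        ),
    )
    outputs = llm.generate("Name the largest city in France as JSON.",
                           sampling_params)
    print(outputs[0].outputs[0].text)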
@@ -6,6 +6,7 @@ import weakref
 
 import jsonschema
 import pytest
+from pydantic import BaseModel
 
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
@@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
     # Parse to verify it is valid JSON
     parsed_json = json.loads(generated_text)
     assert isinstance(parsed_json, dict)
+
+
+@pytest.mark.skip_global_cleanup
+def test_json_with_any_whitespace_disabled(llm):
+
+    class ResponseSchema(BaseModel):
+        clarifying_question: str
+        cost_per_serving: str
+        calories: str
+        type_dish_ids: str
+        type_meal_ids: str
+        product_ids: list[str]
+        exclude_product_ids: list[str]
+        allergen_ids: list[str]
+        total_cooking_time: str
+        kitchen_ids: str
+        holiday_ids: str
+
+    # Note: Without this setting, the response is sometimes full of `\n`
+    # for some models. This option prevents that.
+    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
+
+    schema = ResponseSchema.model_json_schema()
+    guided_params = GuidedDecodingParams(json=schema,
+                                         backend=guided_decoding_backend)
+    sampling_params = SamplingParams(max_tokens=2000,
+                                     frequency_penalty=0,
+                                     presence_penalty=-1.1,
+                                     repetition_penalty=1.3,
+                                     guided_decoding=guided_params)
+
+    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. "
+              "You are a helpful assistant.<|im_end|>\n<|im_start|>user\nI "
+              "want a quick launch fast with $10.<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    outputs = llm.generate(prompts=prompt,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        assert "\n" not in generated_text
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
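The behavioral difference the test checks can be seen with xgrammar directly; a minimal sketch (assuming xgrammar is installed, using the same from_json_schema API the validation code later in this diff calls):

    # Sketch: the two grammar modes side by side, using xgrammar directly.
    import xgrammar as xgr

    schema = '{"type": "object", "properties": {"a": {"type": "string"}}}'

    # Default: the grammar permits optional whitespace (spaces, newlines)
    # between JSON tokens, which some models exploit to emit runs of '\n'.
    g_ws = xgr.Grammar.from_json_schema(schema, any_whitespace=True)

    # With disable-any-whitespace: no inter-token whitespace is allowed,
    # so the constrained output is single-line JSON.
    g_no_ws = xgr.Grammar.from_json_schema(schema, any_whitespace=False)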
@@ -385,6 +385,7 @@ class EngineArgs:
             'Backend-specific options can be supplied in a comma-separated '
             'list following a colon after the backend name. Valid backends and '
             'all available options are: [xgrammar:no-fallback, '
+            'xgrammar:disable-any-whitespace, '
             'outlines:no-fallback, lm-format-enforcer:no-fallback]')
         parser.add_argument(
             '--logits-processor-pattern',
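The option string documented above reaches the engine through the guided-decoding backend argument; a sketch of setting it programmatically (assuming the EngineArgs field matches the CLI flag this help text belongs to):

    # Sketch: passing the option through EngineArgs instead of the CLI.
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
        guided_decoding_backend="xgrammar:disable-any-whitespace",
    )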
@@ -19,6 +19,7 @@ except ImportError:
     xgr_installed = False
     pass
 
+from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
                                                        grammar_is_likely_lark)
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@@ -29,6 +30,8 @@ if TYPE_CHECKING:
     from vllm.config import ModelConfig
     from vllm.sampling_params import GuidedDecodingParams
 
+logger = init_logger(__name__)
+
 
 # TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
@@ -161,6 +164,7 @@ class GrammarConfig:
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
+    any_whitespace: bool = True
     max_threads: int = 8
     tokenizer_data: TokenizerData | None = None
 
@@ -180,11 +184,33 @@ class GrammarConfig:
         else:
             json_str = guided_params.json
 
+            any_whitespace = ('disable-any-whitespace'
+                              not in guided_params.backend_options())
+
+            # Warn if the model is one known to produce runaway whitespace
+            # when used with xgrammar and any_whitespace enabled.
+            # References:
+            # https://github.com/vllm-project/vllm/pull/12744
+            # https://github.com/mlc-ai/xgrammar/issues/212
+            model_with_warn = None
+
+            if 'Mistral' in model_config.model:
+                model_with_warn = 'Mistral'
+            elif 'Qwen' in model_config.model:
+                model_with_warn = 'Qwen'
+
+            if model_with_warn is not None and any_whitespace:
+                msg = (f"{model_with_warn} model detected, consider setting "
+                       "`guided_backend=xgrammar:disable-any-whitespace` "
+                       "to prevent runaway whitespace generation.")
+                logger.info_once(msg)
+
             # Validate the schema and raise ValueError here if it is invalid.
             # This is to avoid exceptions in model execution, which will crash
             # the engine worker process.
             try:
-                xgr.Grammar.from_json_schema(json_str)
+                xgr.Grammar.from_json_schema(json_str,
+                                             any_whitespace=any_whitespace)
             except RuntimeError as err:
                 raise ValueError(str(err)) from err
 
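backend_options() is assumed here to return the comma-separated options after the colon in the backend string; a minimal sketch of that contract (a hypothetical stand-in, not the actual vLLM implementation):

    # Hypothetical stand-in for GuidedDecodingParams.backend_options().
    def backend_options(backend: str | None) -> list[str]:
        """'xgrammar:a,b' -> ['a', 'b']; 'xgrammar' or None -> []."""
        if backend is None or ':' not in backend:
            return []
        return backend.split(':', 1)[1].split(',')

    assert backend_options('xgrammar:disable-any-whitespace') == \
        ['disable-any-whitespace']
    assert backend_options('xgrammar') == []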
@@ -192,7 +218,8 @@ class GrammarConfig:
                 vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
-                tokenizer_data=tokenizer_data)
+                tokenizer_data=tokenizer_data,
+                any_whitespace=any_whitespace)
         elif guided_params.grammar:
             # XGrammar only supports GBNF grammars, so we must convert Lark
             if grammar_is_likely_lark(guided_params.grammar):
@@ -290,7 +317,10 @@ class XGrammarLogitsProcessor:
         if self.ctx is None:
             compiler = GrammarCompilerCache.get_compiler(self.config)
             if self.config.json_str is not None:
-                self.ctx = compiler.compile_json_schema(self.config.json_str)
+                any_whitespace = self.config.any_whitespace
+                self.ctx = compiler.compile_json_schema(
+                    self.config.json_str,
+                    any_whitespace=any_whitespace)
             elif self.config.grammar_str is not None:
                 self.ctx = compiler.compile_grammar(self.config.grammar_str)
             elif self.config.json_object:
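For reference, compile_json_schema is a method of xgrammar's GrammarCompiler; a standalone sketch of the compile path (tokenizer and model names are placeholder assumptions, and the xgrammar API is assumed to match the version vLLM pins here):

    # Sketch: compiling a whitespace-free JSON grammar with xgrammar directly.
    import xgrammar as xgr
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
    compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

    schema_str = '{"type": "object", "properties": {"a": {"type": "string"}}}'
    compiled = compiler.compile_json_schema(schema_str, any_whitespace=False)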