Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 14:25:01 +08:00)
[Bugfix] Guided decoding falls back to outlines when it fails to import xgrammar (#12976)
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
parent deb6c1c6b4
commit 14ecab5be2
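The fix follows the usual optional-dependency pattern: record at import time whether xgrammar could be loaded, and have the backend-selection logic warn and fall back to outlines when it could not. Below is a minimal, self-contained sketch of that pattern; the names (Params, pick_backend, grammar_lib_installed) are illustrative placeholders rather than vLLM's actual API, which the diff hunks further down show.

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

# Module-level availability flag, set once at import time
# (the role xgr_installed plays in the real diff below).
try:
    import xgrammar  # noqa: F401  # optional dependency
    grammar_lib_installed = True
except ImportError:
    grammar_lib_installed = False


@dataclass
class Params:
    # Illustrative stand-in for vLLM's guided-decoding params object.
    backend: str = "xgrammar"


def pick_backend(params: Params) -> Params:
    # Fall back to outlines when the preferred backend cannot be imported.
    if params.backend == "xgrammar" and not grammar_lib_installed:
        logger.warning("xgrammar module cannot be imported successfully. "
                       "Falling back to use outlines instead.")
        params.backend = "outlines"
    return params

Probing the import once at module load keeps the per-request check down to a cheap boolean test instead of re-raising and catching ImportError on every call.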
@@ -40,6 +40,8 @@ def maybe_backend_fallback(
         guided_params.backend = "outlines"
 
     if guided_params.backend == "xgrammar":
+        from vllm.model_executor.guided_decoding.xgrammar_decoding import (
+            xgr_installed)
         # xgrammar only has x86 wheels for linux, fallback to outlines
         from vllm.platforms import current_platform
         if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
@@ -77,6 +79,13 @@ def maybe_backend_fallback(
                 "Falling back to use outlines instead.")
             guided_params.backend = "outlines"
 
+        # If the xgrammar module cannot be imported successfully,
+        # we should still allow users to use guided decoding with a fallback.
+        elif not xgr_installed:
+            logger.warning("xgrammar module cannot be imported successfully. "
+                           "Falling back to use outlines instead.")
+            guided_params.backend = "outlines"
+
     if (guided_params.backend == "outlines"
             and guided_params.json_object is not None):
         # outlines doesn't support json_object, fallback to xgrammar
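With this change, asking for the xgrammar backend on a host where the module cannot be imported should degrade to outlines instead of erroring out later. A rough way to exercise the new branch is sketched below; it assumes the import location implied by the diff (maybe_backend_fallback in vllm.model_executor.guided_decoding) and that GuidedDecodingParams accepts backend and json keyword arguments, so treat it as an untested sketch rather than a verified snippet.

# Assumed API: exercise the fallback on a host without a usable xgrammar install.
from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(
    backend="xgrammar",
    json={"type": "object", "properties": {"name": {"type": "string"}}},
)
params = maybe_backend_fallback(params)

# Expect "outlines" here when `import xgrammar` fails (or on a non-x86 CPU).
print(params.backend)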
@@ -14,7 +14,9 @@ from transformers import PreTrainedTokenizerFast
 try:
     import xgrammar as xgr
     from xgrammar.base import _core as xgr_core
+    xgr_installed = True
 except ImportError:
+    xgr_installed = False
     pass
 
 from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
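Because xgr_installed is a plain module-level attribute, a given host's behaviour can be checked directly, which is also how the backend-selection code above consumes it. A quick check, assuming vLLM itself is importable:

# Prints False on hosts where `import xgrammar` raises ImportError.
from vllm.model_executor.guided_decoding.xgrammar_decoding import xgr_installed
print(xgr_installed)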