diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index f5be8dca05f1d..230dd83834202 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -200,5 +200,6 @@ steps:
     - pytest -v -s v1/core
     - pytest -v -s v1/engine
+    - pytest -v -s v1/entrypoints
     - pytest -v -s v1/sample
     - pytest -v -s v1/worker
     - pytest -v -s v1/structured_output
diff --git a/requirements/common.txt b/requirements/common.txt
index 8d9108687a2b3..d08ef253828b1 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
 outlines == 0.1.11
 lark == 1.2.2
-xgrammar == 0.1.15; platform_machine == "x86_64" or platform_machine == "aarch64"
+xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs
diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index 531c3a8c13b2a..85a53a178ca75 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 import pickle
 
 import pytest
@@ -208,8 +209,6 @@ def test_guided_decoding_backend_options():
 
 
 def test_pickle_xgrammar_tokenizer_data():
-
-    # TODO: move to another test file for xgrammar
     try:
         import xgrammar as xgr
     except ImportError:
@@ -217,7 +216,11 @@ def test_pickle_xgrammar_tokenizer_data():
     from vllm.model_executor.guided_decoding.xgrammar_decoding import (
         TokenizerData)
 
-    tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
+    tokenizer_data = TokenizerData(
+        metadata=
+        '{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
+        encoded_vocab=['!', '"', '#', '$', '%'],
+    )
     pickled = pickle.dumps(tokenizer_data)
 
     assert pickled is not None
@@ -225,4 +228,5 @@ def test_pickle_xgrammar_tokenizer_data():
     depickled: TokenizerData = pickle.loads(pickled)
 
     assert depickled is not None
-    assert depickled.vocab_type == xgr.VocabType.RAW
+    assert json.loads(
+        depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index 98983fa05b83f..b4eb475c23baa 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -18,9 +18,6 @@
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
-# Undo after https://github.com/vllm-project/vllm/pull/14868
-pytest.skip(allow_module_level=True)
-
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 6b9a855eeccce..c21df044d48f6 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -9,7 +9,6 @@ from vllm.model_executor.guided_decoding.reasoner import get_reasoner
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
-from vllm.platforms import CpuArchEnum
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -26,7 +25,7 @@ def maybe_backend_fallback(
 
     def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
                           fallback: str) -> None:
-        """Change the backend to the specified fallback with a warning log, 
+        """Change the backend to the specified fallback with a warning log,
         or raise a ValueError if the `no-fallback` option is specified."""
         if guided_params.no_fallback():
             raise ValueError(message)
@@ -53,19 +52,12 @@ def maybe_backend_fallback(
     if guided_params.backend_name == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)
-        # xgrammar only has x86 wheels for linux, fallback to outlines
-        from vllm.platforms import current_platform
-        if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
-            fallback_or_error(guided_params,
-                              "xgrammar is only supported on x86 CPUs.",
-                              "outlines")
         # xgrammar doesn't support regex, fallback to outlines
         if guided_params.regex is not None:
             fallback_or_error(
                 guided_params,
                 "xgrammar does not support regex guided decoding.",
                 "outlines")
-
         # xgrammar doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_xgrammar_unsupported_json_features(guided_params.json)):
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index 9405ef93e145e..bc156223953e0 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -9,13 +9,11 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, List
 
 import torch
-from transformers import PreTrainedTokenizerFast
 
 from vllm.logger import init_logger
 
 try:
     import xgrammar as xgr
-    from xgrammar.base import _core as xgr_core
     xgr_installed = True
 except ImportError:
     xgr_installed = False
@@ -35,7 +33,6 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 
-# TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
@@ -52,18 +49,8 @@ def get_local_xgrammar_guided_decoding_logits_processor(
 @dataclass(frozen=True)
 class TokenizerData:
     """Immutable container for cached tokenizer data."""
+    metadata: str
     encoded_vocab: list[str] = field(default_factory=list)
-    stop_token_ids: list[int] | None = None
-    # These fields are mutually exclusive: `backend_str` is used to create a
-    # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
-    # used within the constructor of TokenizeInfo
-    backend_str: str | None = None
-    vocab_type: xgr.VocabType | None = None
-
-    def __post_init__(self):
-        # Check for mutual exclusive
-        assert not (self.backend_str and self.vocab_type), \
-            "backend_str and vocab_type are mutual exclusive"
 
 
 class TokenizerDataCache:
@@ -71,46 +58,52 @@ class TokenizerDataCache:
     _cache: dict[int, TokenizerData] = {}
 
     @classmethod
-    def get_tokenizer_data(cls,
-                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
-        tokenizer_hash = hash(tokenizer)
+    def get_tokenizer_data(
+        cls,
+        tokenizer: PreTrainedTokenizer,
+        /,
+        *,
+        tokenizer_hash: int,
+        vocab_size: int,
+    ) -> TokenizerData:
 
         if tokenizer_hash not in cls._cache:
-            # Vendored from xgrammar logic since we cannot pickle the tokenizer
-            # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                # NOTE: We will need to use lm_head's vocab_size
+                # to determine correct special_token_ids for this tokenizer.
+                # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
+                vocab_size=vocab_size,
+            )
+            metadata = json.loads(tokenizer_info.dump_metadata())
+
+            # Vendored from xgrammar logic to get encoded_vocab
+            # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
             try:
-                encoded_vocab = [
-                    token for token, _ in sorted(tokenizer.get_vocab().items(),
-                                                 key=lambda x: x[1])
-                ]
+                vocab_dict = tokenizer.get_vocab()
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
                     f"{type(tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e
 
-            stop_token_ids = None
-            backend_str = ""
-            vocab_type = xgr.VocabType.RAW
+            # maintain tokenizer's indexing
+            encoded_vocab = [""] * tokenizer_info.vocab_size
+            for token, idx in vocab_dict.items():
+                if idx < tokenizer_info.vocab_size:
+                    encoded_vocab[idx] = token
 
-            if stop_token_ids is None and hasattr(
-                    tokenizer,
-                    "eos_token_id") and tokenizer.eos_token_id is not None:
-                stop_token_ids = [tokenizer.eos_token_id]
-
-            if isinstance(tokenizer, PreTrainedTokenizerFast):
-                backend_str = tokenizer.backend_tokenizer.to_str()
-                vocab_type = None
-
-            elif isinstance(tokenizer, MistralTokenizer):
+            if isinstance(tokenizer, MistralTokenizer):
                 # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
-                vocab_type = xgr.VocabType.BYTE_FALLBACK
+                metadata.update({
+                    "vocab_type": xgr.VocabType.BYTE_FALLBACK.value,
+                    "add_prefix_space": True
+                })
 
             cls._cache[tokenizer_hash] = TokenizerData(
                 encoded_vocab=encoded_vocab,
-                stop_token_ids=stop_token_ids,
-                backend_str=backend_str,
-                vocab_type=vocab_type)
+                metadata=json.dumps(metadata),
+            )
 
         return cls._cache[tokenizer_hash]
 
 
@@ -129,30 +122,15 @@ class GrammarCompilerCache:
         cache_key = str(config.tokenizer_hash)
 
         if cache_key not in cls._cache:
-            assert config.tokenizer_data is not None
-            assert config.tokenizer_data.encoded_vocab is not None
-
             config_data = config.tokenizer_data
 
             # In TokenizerDataCache.get_tokenizer_data, a serializable
             # tokenizer_data is created and cached. This data is used to build
             # a tokenizer_info and create an xgrammar compiler.
-            # - If tokenizer_data has backend_str set, use
-            #   xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
-            # - Otherwise, use the default constructor with vocab_type.
-            # - xgr_core.TokenizerInfo.from_huggingface !=
-            #   xgr.TokenizerInfo.from_huggingface.
-            if config_data.backend_str:
-                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                    xgr_core.TokenizerInfo.from_huggingface(
-                        config_data.encoded_vocab, config_data.backend_str,
-                        config.vocab_size, config_data.stop_token_ids))
-            else:
-                tokenizer_info = xgr.TokenizerInfo(
-                    config_data.encoded_vocab,
-                    config_data.vocab_type,
-                    vocab_size=config.vocab_size,
-                    stop_token_ids=config_data.stop_token_ids)
+            tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
+                encoded_vocab=config_data.encoded_vocab,
+                metadata=config_data.metadata,
+            )
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)
 
@@ -163,13 +141,12 @@
 class GrammarConfig:
     """Serializable configuration for grammar compilation"""
     tokenizer_hash: int
-    vocab_size: int
+    tokenizer_data: TokenizerData
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
     any_whitespace: bool = True
     max_threads: int = 8
-    tokenizer_data: TokenizerData | None = None
 
     @classmethod
     def from_guided_params(cls,
@@ -179,7 +156,11 @@ class GrammarConfig:
                            max_threads: int = 8) -> GrammarConfig:
 
         tokenizer_hash = hash(tokenizer)
-        tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
+        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
+            tokenizer,
+            tokenizer_hash=tokenizer_hash,
+            vocab_size=model_config.hf_text_config.vocab_size,
+        )
 
         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -218,7 +199,6 @@ class GrammarConfig:
                 raise ValueError(str(err)) from err
 
             return cls(json_str=json_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data,
@@ -246,14 +226,12 @@ class GrammarConfig:
                 raise ValueError(str(err)) from err
 
             return cls(grammar_str=grammar_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data)
         elif guided_params.json_object:
             return cls(
                 json_object=True,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -267,7 +245,6 @@ class GrammarConfig:
 
             return cls(
                 grammar_str=choice_str,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -291,6 +268,13 @@ class GrammarConfig:
         grammar = ('root ::= ' +
                    ' | '.join(f'"{c}"' for c in escaped_choices))
         return grammar
+
+    @staticmethod
+    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
+        return xgr.TokenizerInfo.from_vocab_and_metadata(
+            encoded_vocab=tokenizer_data.encoded_vocab,
+            metadata=tokenizer_data.metadata,
+        )
 
 
 @dataclass
@@ -299,11 +283,16 @@ class XGrammarLogitsProcessor:
     reasoner: Reasoner | None = None
 
     ctx: xgr.CompiledGrammar | None = None
+    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
     matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = field(default=1)
     prefilled: bool = field(default=False)
 
+    def __post_init__(self):
+        self.tokenizer_info = self.config.tokenizer_info(
+            self.config.tokenizer_data)
+
     def __getstate__(self) -> dict[str, Any]:
         return {'config': self.config, 'reasoner': self.reasoner}
 
@@ -311,6 +300,8 @@ def __setstate__(self, state: dict[str, Any]):
         self.config = state['config']
         self.reasoner = state['reasoner']
 
+        self.tokenizer_info = GrammarConfig.tokenizer_info(
+            self.config.tokenizer_data)
         self.ctx = None
         self.matchers = []
         self.batch_size = 1
@@ -352,7 +343,7 @@ class XGrammarLogitsProcessor:
                 xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
             ]
             self.token_bitmask = xgr.allocate_token_bitmask(
-                self.batch_size, self.config.vocab_size)
+                self.batch_size, self.tokenizer_info.vocab_size)
 
         if not self.prefilled:
             # Have not sampled a token yet
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 77bafdee85ce2..5ed7b832aac54 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -40,7 +40,7 @@ class StructuredOutputManager:
             tokenizer_group.ping()
 
             tokenizer = tokenizer_group.get_lora_tokenizer(None)
-            self.vocab_size = len(tokenizer.get_vocab())
+            self.vocab_size = self.vllm_config.model_config.get_vocab_size()
             if isinstance(tokenizer, MistralTokenizer):
                 # NOTE: ideally, xgrammar should handle this accordingly.
                 # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
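
Reviewer note: the net effect of the xgrammar_decoding.py changes is that TokenizerData shrinks to two picklable fields, the JSON metadata string produced by TokenizerInfo.dump_metadata() plus the encoded vocab, and a TokenizerInfo can be rebuilt from them on the far side of a pickle boundary via TokenizerInfo.from_vocab_and_metadata(). Below is a minimal sketch of that round-trip, assuming xgrammar >= 0.1.16 as pinned above; the model name and the vocab_size of 151665 are illustrative only, borrowed from the updated test, and this mirrors (but is not) the literal vLLM code path.

# Sketch of the serialize/rebuild flow enabled by this PR.
import pickle

import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

# Build TokenizerInfo once, passing the lm_head vocab_size so special/stop
# token ids are derived correctly (see the NOTE in the diff above).
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer,
                                                    vocab_size=151665)

# The two fields now stored in TokenizerData:
metadata = tokenizer_info.dump_metadata()  # plain JSON string
encoded_vocab = [""] * tokenizer_info.vocab_size
for token, idx in tokenizer.get_vocab().items():
    if idx < tokenizer_info.vocab_size:
        encoded_vocab[idx] = token  # keep the tokenizer's own indexing

payload = pickle.dumps((metadata, encoded_vocab))

# In another process (cf. XGrammarLogitsProcessor.__setstate__), rebuild
# without touching the original tokenizer object:
metadata, encoded_vocab = pickle.loads(payload)
rebuilt = xgr.TokenizerInfo.from_vocab_and_metadata(
    encoded_vocab=encoded_vocab, metadata=metadata)
assert rebuilt.vocab_size == tokenizer_info.vocab_size

This is also why GrammarConfig drops its separate vocab_size field: the rebuilt TokenizerInfo already carries vocab_size, which XGrammarLogitsProcessor now uses when allocating the token bitmask.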