mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 04:22:13 +08:00
[Bugfix] Fix guided decoding with tokenizer mode mistral (#11046)
This commit is contained in:
parent
866fa4550d
commit
8b79f9e107
@ -224,8 +224,12 @@ steps:
|
|||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/layers
|
- vllm/model_executor/layers
|
||||||
|
- vllm/model_executor/guided_decoding
|
||||||
- tests/test_logits_processor
|
- tests/test_logits_processor
|
||||||
command: pytest -v -s test_logits_processor.py
|
- tests/model_executor/test_guided_processors
|
||||||
|
commands:
|
||||||
|
- pytest -v -s test_logits_processor.py
|
||||||
|
- pytest -v -s model_executor/test_guided_processors.py
|
||||||
|
|
||||||
- label: Speculative decoding tests # 30min
|
- label: Speculative decoding tests # 30min
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
@ -14,12 +14,13 @@ aiohttp
|
|||||||
openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
|
openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
pydantic >= 2.9 # Required for fastapi >= 0.113.0
|
pydantic >= 2.9 # Required for fastapi >= 0.113.0
|
||||||
pillow # Required for image processing
|
|
||||||
prometheus_client >= 0.18.0
|
prometheus_client >= 0.18.0
|
||||||
|
pillow # Required for image processing
|
||||||
prometheus-fastapi-instrumentator >= 7.0.0
|
prometheus-fastapi-instrumentator >= 7.0.0
|
||||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||||
lm-format-enforcer >= 0.10.9, < 0.11
|
lm-format-enforcer >= 0.10.9, < 0.11
|
||||||
outlines == 0.1.11
|
outlines == 0.1.11
|
||||||
|
lark == 1.2.2
|
||||||
xgrammar >= 0.1.6; platform_machine == "x86_64"
|
xgrammar >= 0.1.6; platform_machine == "x86_64"
|
||||||
typing_extensions >= 4.10
|
typing_extensions >= 4.10
|
||||||
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
|
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
|
||||||
|
|||||||
@ -1,13 +1,19 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
from vllm.model_executor.guided_decoding import (
|
from vllm.model_executor.guided_decoding import (
|
||||||
get_guided_decoding_logits_processor)
|
get_guided_decoding_logits_processor,
|
||||||
|
get_local_guided_decoding_logits_processor)
|
||||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||||
JSONLogitsProcessor, RegexLogitsProcessor)
|
JSONLogitsProcessor, RegexLogitsProcessor)
|
||||||
from vllm.sampling_params import GuidedDecodingParams
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
|
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
|
||||||
|
|
||||||
|
|
||||||
def test_guided_logits_processors(sample_regex, sample_json_schema):
|
def test_guided_logits_processors(sample_regex, sample_json_schema):
|
||||||
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
||||||
@ -38,14 +44,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("backend",
|
@pytest.mark.parametrize("backend",
|
||||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||||
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
|
@pytest.mark.parametrize("is_local", [True, False])
|
||||||
|
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||||
|
sample_regex,
|
||||||
sample_json_schema):
|
sample_json_schema):
|
||||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
|
||||||
|
config = ModelConfig(
|
||||||
|
MODEL_NAME,
|
||||||
|
task="generate",
|
||||||
|
tokenizer=MODEL_NAME,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=False,
|
||||||
|
seed=0,
|
||||||
|
dtype="bfloat16",
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
token_ids = tokenizer.encode(
|
token_ids = tokenizer.encode(
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||||
regex_lp = await get_guided_decoding_logits_processor(
|
|
||||||
regex_request, tokenizer)
|
regex_lp = get_local_guided_decoding_logits_processor(
|
||||||
|
regex_request, tokenizer, config) if is_local else \
|
||||||
|
await get_guided_decoding_logits_processor(
|
||||||
|
regex_request, tokenizer, config)
|
||||||
assert regex_lp is not None
|
assert regex_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
@ -59,7 +80,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
|
|||||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
backend=backend)
|
backend=backend)
|
||||||
json_lp = await get_guided_decoding_logits_processor(
|
json_lp = await get_guided_decoding_logits_processor(
|
||||||
json_request, tokenizer)
|
json_request, tokenizer, config)
|
||||||
assert json_lp is not None
|
assert json_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
@ -84,3 +105,24 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
|
|||||||
with pytest.raises(ValueError,
|
with pytest.raises(ValueError,
|
||||||
match="You can only use one kind of guided"):
|
match="You can only use one kind of guided"):
|
||||||
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
|
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickle_xgrammar_tokenizer_data():
    """Round-trip a ``TokenizerData`` through pickle and check ``vocab_type``.

    The test is skipped when xgrammar cannot be imported.
    """
    # TODO: move to another test file for xgrammar
    try:
        import xgrammar as xgr
    except ImportError:
        pytest.skip("Could not import xgrammar to run test")

    from vllm.model_executor.guided_decoding.xgrammar_decoding import (
        TokenizerData)

    # Serialize a container that only carries a vocab type.
    serialized = pickle.dumps(TokenizerData(vocab_type=xgr.VocabType.RAW))
    assert serialized is not None

    # Deserialize and make sure the vocab type survived the round trip.
    restored: TokenizerData = pickle.loads(serialized)
    assert restored is not None
    assert restored.vocab_type == xgr.VocabType.RAW
|
||||||
|
|||||||
@ -3,17 +3,20 @@
|
|||||||
Run `pytest tests/models/test_mistral.py`.
|
Run `pytest tests/models/test_mistral.py`.
|
||||||
"""
|
"""
|
||||||
import copy
|
import copy
|
||||||
|
import json
|
||||||
|
|
||||||
|
import jsonschema
|
||||||
|
import jsonschema.exceptions
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm import SamplingParams
|
|
||||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
|
||||||
MistralToolParser)
|
MistralToolParser)
|
||||||
|
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||||
|
|
||||||
from ...utils import check_logprobs_close
|
from ...utils import check_logprobs_close
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
MISTRAL_FORMAT_MODELS = [
|
MISTRAL_FORMAT_MODELS = [
|
||||||
@ -126,6 +129,45 @@ MSGS = [
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
SAMPLE_JSON_SCHEMA = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"age": {
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"skills": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"maxLength": 10
|
||||||
|
},
|
||||||
|
"minItems": 3
|
||||||
|
},
|
||||||
|
"work_history": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"company": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"duration": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["company", "position"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name", "age", "skills", "work_history"]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
@ -251,3 +293,43 @@ def test_mistral_function_calling(
|
|||||||
assert parsed_message.tool_calls[
|
assert parsed_message.tool_calls[
|
||||||
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
|
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
|
||||||
assert parsed_message.content is None
|
assert parsed_message.content is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
                         ["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
    vllm_runner,
    model: str,
    guided_backend: str,
) -> None:
    """Check JSON-guided chat output under ``tokenizer_mode="mistral"``.

    For each guided decoding backend, ask the model for a JSON document
    constrained by ``SAMPLE_JSON_SCHEMA`` and verify that the generated text
    parses as JSON and validates against that schema.

    Args:
        vllm_runner: Fixture used as a context manager yielding the model.
        model: Name of the model under test.
        guided_backend: Guided decoding backend passed to
            ``GuidedDecodingParams``.
    """
    with vllm_runner(model, dtype='bfloat16',
                     tokenizer_mode="mistral") as vllm_model:

        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
                                               backend=guided_backend)
        params = SamplingParams(max_tokens=512,
                                temperature=0.7,
                                guided_decoding=guided_decoding)

        messages = [{
            "role": "system",
            "content": "you are a helpful assistant"
        }, {
            "role":
            "user",
            "content":
            "Give an example JSON for an employee profile that "
            f"fits this schema: {SAMPLE_JSON_SCHEMA}"
        }]
        outputs = vllm_model.model.chat(messages, sampling_params=params)

    # Guard before first use: the original asserted only after indexing
    # `outputs`, at which point the check could no longer catch anything.
    assert outputs is not None

    generated_text = outputs[0].outputs[0].text
    json_response = json.loads(generated_text)

    try:
        jsonschema.validate(instance=json_response,
                            schema=SAMPLE_JSON_SCHEMA)
    except jsonschema.exceptions.ValidationError:
        pytest.fail("Generated response is not valid with JSON schema")
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizerFast
|
||||||
@ -16,6 +16,7 @@ except ImportError:
|
|||||||
|
|
||||||
from vllm.model_executor.guided_decoding.xgrammar_utils import (
|
from vllm.model_executor.guided_decoding.xgrammar_utils import (
|
||||||
convert_lark_to_gbnf, grammar_is_likely_lark)
|
convert_lark_to_gbnf, grammar_is_likely_lark)
|
||||||
|
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor(
|
|||||||
return XGrammarLogitsProcessor(config)
|
return XGrammarLogitsProcessor(config)
|
||||||
|
|
||||||
|
|
||||||
class TokenizerData(NamedTuple):
|
@dataclass(frozen=True)
|
||||||
|
class TokenizerData:
|
||||||
"""Immutable container for cached tokenizer data."""
|
"""Immutable container for cached tokenizer data."""
|
||||||
encoded_vocab: list[str]
|
encoded_vocab: list[str] = field(default_factory=list)
|
||||||
stop_token_ids: list[int] | None
|
stop_token_ids: list[int] | None = None
|
||||||
backend_str: str
|
# These fields are mutually exclusive: `backend_str` is used to create a
|
||||||
|
# TokenizerInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
|
||||||
|
# used within the constructor of TokenizerInfo
|
||||||
|
backend_str: str | None = None
|
||||||
|
vocab_type: xgr.VocabType | None = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Check for mutual exclusivity
|
||||||
|
assert not (self.backend_str and self.vocab_type), \
|
||||||
|
"backend_str and vocab_type are mutual exclusive"
|
||||||
|
|
||||||
|
|
||||||
class TokenizerDataCache:
|
class TokenizerDataCache:
|
||||||
@ -68,18 +79,27 @@ class TokenizerDataCache:
|
|||||||
"get_vocab method.") from e
|
"get_vocab method.") from e
|
||||||
|
|
||||||
stop_token_ids = None
|
stop_token_ids = None
|
||||||
backend_str = xgr.VocabType.RAW
|
backend_str = ""
|
||||||
|
vocab_type = xgr.VocabType.RAW
|
||||||
|
|
||||||
|
if stop_token_ids is None and hasattr(
|
||||||
|
tokenizer,
|
||||||
|
"eos_token_id") and tokenizer.eos_token_id is not None:
|
||||||
|
stop_token_ids = [tokenizer.eos_token_id]
|
||||||
|
|
||||||
if isinstance(tokenizer, PreTrainedTokenizerFast):
|
if isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||||
backend_str = tokenizer.backend_tokenizer.to_str()
|
backend_str = tokenizer.backend_tokenizer.to_str()
|
||||||
if stop_token_ids is None and hasattr(
|
vocab_type = None
|
||||||
tokenizer,
|
|
||||||
"eos_token_id") and tokenizer.eos_token_id is not None:
|
elif isinstance(tokenizer, MistralTokenizer):
|
||||||
stop_token_ids = [tokenizer.eos_token_id]
|
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
|
||||||
|
vocab_type = xgr.VocabType.BYTE_FALLBACK
|
||||||
|
|
||||||
cls._cache[tokenizer_hash] = TokenizerData(
|
cls._cache[tokenizer_hash] = TokenizerData(
|
||||||
encoded_vocab=encoded_vocab,
|
encoded_vocab=encoded_vocab,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
backend_str=backend_str)
|
backend_str=backend_str,
|
||||||
|
vocab_type=vocab_type)
|
||||||
|
|
||||||
return cls._cache[tokenizer_hash]
|
return cls._cache[tokenizer_hash]
|
||||||
|
|
||||||
@ -98,11 +118,30 @@ class GrammarCompilerCache:
|
|||||||
cache_key = str(config.tokenizer_hash)
|
cache_key = str(config.tokenizer_hash)
|
||||||
|
|
||||||
if cache_key not in cls._cache:
|
if cache_key not in cls._cache:
|
||||||
assert config.encoded_vocab is not None
|
assert config.tokenizer_data is not None
|
||||||
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
|
assert config.tokenizer_data.encoded_vocab is not None
|
||||||
xgr_core.TokenizerInfo.from_huggingface(
|
|
||||||
config.encoded_vocab, config.backend_str,
|
config_data = config.tokenizer_data
|
||||||
config.vocab_size, config.stop_token_ids))
|
|
||||||
|
# In TokenizerDataCache.get_tokenizer_data, a serializable
|
||||||
|
# tokenizer_data is created and cached. This data is used to build
|
||||||
|
# a tokenizer_info and create an xgrammar compiler.
|
||||||
|
# - If tokenizer_data has backend_str set, use
|
||||||
|
# xgr_core.TokenizerInfo.from_huggingface (a C++ binding).
|
||||||
|
# - Otherwise, use the default constructor with vocab_type.
|
||||||
|
# - xgr_core.TokenizerInfo.from_huggingface !=
|
||||||
|
# xgr.TokenizerInfo.from_huggingface.
|
||||||
|
if config_data.backend_str:
|
||||||
|
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
|
||||||
|
xgr_core.TokenizerInfo.from_huggingface(
|
||||||
|
config_data.encoded_vocab, config_data.backend_str,
|
||||||
|
config.vocab_size, config_data.stop_token_ids))
|
||||||
|
else:
|
||||||
|
tokenizer_info = xgr.TokenizerInfo(
|
||||||
|
config_data.encoded_vocab,
|
||||||
|
config_data.vocab_type,
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
stop_token_ids=config_data.stop_token_ids)
|
||||||
cls._cache[cache_key] = xgr.GrammarCompiler(
|
cls._cache[cache_key] = xgr.GrammarCompiler(
|
||||||
tokenizer_info, max_threads=config.max_threads)
|
tokenizer_info, max_threads=config.max_threads)
|
||||||
|
|
||||||
@ -118,10 +157,7 @@ class GrammarConfig:
|
|||||||
grammar_str: str | None = None
|
grammar_str: str | None = None
|
||||||
json_object: bool | None = None
|
json_object: bool | None = None
|
||||||
max_threads: int = 8
|
max_threads: int = 8
|
||||||
# Only populated if tokenizer_hash not in cache
|
tokenizer_data: TokenizerData | None = None
|
||||||
encoded_vocab: list[str] | None = None
|
|
||||||
stop_token_ids: list[int] | None = None
|
|
||||||
backend_str: str | None = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_guided_params(cls,
|
def from_guided_params(cls,
|
||||||
@ -132,9 +168,6 @@ class GrammarConfig:
|
|||||||
|
|
||||||
tokenizer_hash = hash(tokenizer)
|
tokenizer_hash = hash(tokenizer)
|
||||||
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
|
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
|
||||||
encoded_vocab = tokenizer_data.encoded_vocab
|
|
||||||
stop_token_ids = tokenizer_data.stop_token_ids
|
|
||||||
backend_str = tokenizer_data.backend_str
|
|
||||||
|
|
||||||
if guided_params.json:
|
if guided_params.json:
|
||||||
if not isinstance(guided_params.json, str):
|
if not isinstance(guided_params.json, str):
|
||||||
@ -152,11 +185,9 @@ class GrammarConfig:
|
|||||||
|
|
||||||
return cls(json_str=json_str,
|
return cls(json_str=json_str,
|
||||||
vocab_size=model_config.hf_text_config.vocab_size,
|
vocab_size=model_config.hf_text_config.vocab_size,
|
||||||
encoded_vocab=encoded_vocab,
|
|
||||||
stop_token_ids=stop_token_ids,
|
|
||||||
backend_str=backend_str,
|
|
||||||
tokenizer_hash=tokenizer_hash,
|
tokenizer_hash=tokenizer_hash,
|
||||||
max_threads=max_threads)
|
max_threads=max_threads,
|
||||||
|
tokenizer_data=tokenizer_data)
|
||||||
elif guided_params.grammar:
|
elif guided_params.grammar:
|
||||||
# XGrammar only supports GBNF grammars, so we must convert Lark
|
# XGrammar only supports GBNF grammars, so we must convert Lark
|
||||||
if grammar_is_likely_lark(guided_params.grammar):
|
if grammar_is_likely_lark(guided_params.grammar):
|
||||||
@ -181,19 +212,17 @@ class GrammarConfig:
|
|||||||
|
|
||||||
return cls(grammar_str=grammar_str,
|
return cls(grammar_str=grammar_str,
|
||||||
vocab_size=model_config.hf_text_config.vocab_size,
|
vocab_size=model_config.hf_text_config.vocab_size,
|
||||||
encoded_vocab=encoded_vocab,
|
|
||||||
stop_token_ids=stop_token_ids,
|
|
||||||
backend_str=backend_str,
|
|
||||||
tokenizer_hash=tokenizer_hash,
|
tokenizer_hash=tokenizer_hash,
|
||||||
max_threads=max_threads)
|
max_threads=max_threads,
|
||||||
|
tokenizer_data=tokenizer_data)
|
||||||
elif guided_params.json_object:
|
elif guided_params.json_object:
|
||||||
return cls(json_object=True,
|
return cls(
|
||||||
vocab_size=model_config.hf_text_config.vocab_size,
|
json_object=True,
|
||||||
encoded_vocab=encoded_vocab,
|
vocab_size=model_config.hf_text_config.vocab_size,
|
||||||
stop_token_ids=stop_token_ids,
|
tokenizer_hash=tokenizer_hash,
|
||||||
backend_str=backend_str,
|
max_threads=max_threads,
|
||||||
tokenizer_hash=tokenizer_hash,
|
tokenizer_data=tokenizer_data,
|
||||||
max_threads=max_threads)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Currently only support JSON and EBNF grammar mode for xgrammar"
|
"Currently only support JSON and EBNF grammar mode for xgrammar"
|
||||||
@ -269,10 +298,14 @@ class XGrammarLogitsProcessor:
|
|||||||
# fill_next_token_bitmask so we move it to the device of scores
|
# fill_next_token_bitmask so we move it to the device of scores
|
||||||
device_type = scores.device.type
|
device_type = scores.device.type
|
||||||
if device_type != "cuda":
|
if device_type != "cuda":
|
||||||
scores = scores.to("cpu")
|
scores = scores.to("cpu").unsqueeze(0)
|
||||||
|
|
||||||
|
# Note: if the tensors have different dimensions, this method
|
||||||
|
# fails on CPU, but on GPU it runs without error. Hence the
|
||||||
|
# unsqueeze above for scores, to match the token bitmask shape
|
||||||
xgr.apply_token_bitmask_inplace(scores,
|
xgr.apply_token_bitmask_inplace(scores,
|
||||||
self.token_bitmask.to(scores.device))
|
self.token_bitmask.to(scores.device))
|
||||||
if device_type != "cuda":
|
if device_type != "cuda":
|
||||||
scores = scores.to(device_type)
|
scores = scores.to(device_type).squeeze()
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
@ -132,7 +132,7 @@ def get_tokenizer(
|
|||||||
if is_from_mistral_org and tokenizer_mode != "mistral":
|
if is_from_mistral_org and tokenizer_mode != "mistral":
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'It is strongly recommended to run mistral models with '
|
'It is strongly recommended to run mistral models with '
|
||||||
'`--tokenizer_mode "mistral"` to ensure correct '
|
'`--tokenizer-mode "mistral"` to ensure correct '
|
||||||
'encoding and decoding.',
|
'encoding and decoding.',
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
stacklevel=2)
|
stacklevel=2)
|
||||||
|
|||||||
@ -314,12 +314,15 @@ class MistralTokenizer:
|
|||||||
|
|
||||||
if regular_tokens:
|
if regular_tokens:
|
||||||
decoded_list.append(
|
decoded_list.append(
|
||||||
self.decode(regular_tokens)) # type: ignore
|
self.tokenizer.decode(regular_tokens)) # type: ignore
|
||||||
|
|
||||||
decoded = ''.join(decoded_list)
|
decoded = ''.join(decoded_list)
|
||||||
|
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
|
# WARN: Outlines logits processors can overwrite this method.
|
||||||
|
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
|
||||||
|
# for more.
|
||||||
def decode(self,
|
def decode(self,
|
||||||
ids: Union[List[int], int],
|
ids: Union[List[int], int],
|
||||||
skip_special_tokens: bool = True) -> str:
|
skip_special_tokens: bool = True) -> str:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user