mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 02:01:27 +08:00
[Bugfix] Fix guided decoding with tokenizer mode mistral (#11046)
This commit is contained in:
parent
866fa4550d
commit
8b79f9e107
@ -224,8 +224,12 @@ steps:
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers
|
||||
- vllm/model_executor/guided_decoding
|
||||
- tests/test_logits_processor
|
||||
command: pytest -v -s test_logits_processor.py
|
||||
- tests/model_executor/test_guided_processors
|
||||
commands:
|
||||
- pytest -v -s test_logits_processor.py
|
||||
- pytest -v -s model_executor/test_guided_processors.py
|
||||
|
||||
- label: Speculative decoding tests # 30min
|
||||
source_file_dependencies:
|
||||
|
||||
@ -14,12 +14,13 @@ aiohttp
|
||||
openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.9 # Required for fastapi >= 0.113.0
|
||||
pillow # Required for image processing
|
||||
prometheus_client >= 0.18.0
|
||||
pillow # Required for image processing
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer >= 0.10.9, < 0.11
|
||||
outlines == 0.1.11
|
||||
lark == 1.2.2
|
||||
xgrammar >= 0.1.6; platform_machine == "x86_64"
|
||||
typing_extensions >= 4.10
|
||||
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
|
||||
|
||||
@ -1,13 +1,19 @@
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
get_guided_decoding_logits_processor,
|
||||
get_local_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||
JSONLogitsProcessor, RegexLogitsProcessor)
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
|
||||
|
||||
|
||||
def test_guided_logits_processors(sample_regex, sample_json_schema):
|
||||
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
||||
@ -38,14 +44,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend",
|
||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
|
||||
@pytest.mark.parametrize("is_local", [True, False])
|
||||
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
sample_regex,
|
||||
sample_json_schema):
|
||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
||||
|
||||
config = ModelConfig(
|
||||
MODEL_NAME,
|
||||
task="generate",
|
||||
tokenizer=MODEL_NAME,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
token_ids = tokenizer.encode(
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||
regex_lp = await get_guided_decoding_logits_processor(
|
||||
regex_request, tokenizer)
|
||||
|
||||
regex_lp = get_local_guided_decoding_logits_processor(
|
||||
regex_request, tokenizer, config) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
regex_request, tokenizer, config)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
@ -59,7 +80,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = await get_guided_decoding_logits_processor(
|
||||
json_request, tokenizer)
|
||||
json_request, tokenizer, config)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
@ -84,3 +105,24 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
|
||||
with pytest.raises(ValueError,
|
||||
match="You can only use one kind of guided"):
|
||||
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
|
||||
|
||||
|
||||
def test_pickle_xgrammar_tokenizer_data():
|
||||
|
||||
# TODO: move to another test file for xgrammar
|
||||
try:
|
||||
import xgrammar as xgr
|
||||
except ImportError:
|
||||
pytest.skip("Could not import xgrammar to run test")
|
||||
|
||||
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
|
||||
TokenizerData)
|
||||
tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
|
||||
pickled = pickle.dumps(tokenizer_data)
|
||||
|
||||
assert pickled is not None
|
||||
|
||||
depickled: TokenizerData = pickle.loads(pickled)
|
||||
|
||||
assert depickled is not None
|
||||
assert depickled.vocab_type == xgr.VocabType.RAW
|
||||
|
||||
@ -3,17 +3,20 @@
|
||||
Run `pytest tests/models/test_mistral.py`.
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
|
||||
import jsonschema
|
||||
import jsonschema.exceptions
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
|
||||
MistralToolParser)
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
MODELS = [
|
||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
]
|
||||
|
||||
MISTRAL_FORMAT_MODELS = [
|
||||
@ -126,6 +129,45 @@ MSGS = [
|
||||
}
|
||||
]
|
||||
|
||||
SAMPLE_JSON_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"age": {
|
||||
"type": "integer"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"maxLength": 10
|
||||
},
|
||||
"minItems": 3
|
||||
},
|
||||
"work_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "string"
|
||||
},
|
||||
"duration": {
|
||||
"type": "number"
|
||||
},
|
||||
"position": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["company", "position"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["name", "age", "skills", "work_history"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@ -251,3 +293,43 @@ def test_mistral_function_calling(
|
||||
assert parsed_message.tool_calls[
|
||||
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
|
||||
assert parsed_message.content is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("guided_backend",
|
||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||
def test_mistral_guided_decoding(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
guided_backend: str,
|
||||
) -> None:
|
||||
with vllm_runner(model, dtype='bfloat16',
|
||||
tokenizer_mode="mistral") as vllm_model:
|
||||
|
||||
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
|
||||
backend=guided_backend)
|
||||
params = SamplingParams(max_tokens=512,
|
||||
temperature=0.7,
|
||||
guided_decoding=guided_decoding)
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
f"Give an example JSON for an employee profile that "
|
||||
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
|
||||
}]
|
||||
outputs = vllm_model.model.chat(messages, sampling_params=params)
|
||||
|
||||
generated_text = outputs[0].outputs[0].text
|
||||
json_response = json.loads(generated_text)
|
||||
assert outputs is not None
|
||||
|
||||
try:
|
||||
jsonschema.validate(instance=json_response,
|
||||
schema=SAMPLE_JSON_SCHEMA)
|
||||
except jsonschema.exceptions.ValidationError:
|
||||
pytest.fail("Generated response is not valid with JSON schema")
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
@ -16,6 +16,7 @@ except ImportError:
|
||||
|
||||
from vllm.model_executor.guided_decoding.xgrammar_utils import (
|
||||
convert_lark_to_gbnf, grammar_is_likely_lark)
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor(
|
||||
return XGrammarLogitsProcessor(config)
|
||||
|
||||
|
||||
class TokenizerData(NamedTuple):
|
||||
@dataclass(frozen=True)
|
||||
class TokenizerData:
|
||||
"""Immutable container for cached tokenizer data."""
|
||||
encoded_vocab: list[str]
|
||||
stop_token_ids: list[int] | None
|
||||
backend_str: str
|
||||
encoded_vocab: list[str] = field(default_factory=list)
|
||||
stop_token_ids: list[int] | None = None
|
||||
# These fields are mutually exclusive: `backend_str` is used to create a
|
||||
# TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
|
||||
# used within the constructor of TokenizeInfo
|
||||
backend_str: str | None = None
|
||||
vocab_type: xgr.VocabType | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Check for mutual exclusive
|
||||
assert not (self.backend_str and self.vocab_type), \
|
||||
"backend_str and vocab_type are mutual exclusive"
|
||||
|
||||
|
||||
class TokenizerDataCache:
|
||||
@ -68,18 +79,27 @@ class TokenizerDataCache:
|
||||
"get_vocab method.") from e
|
||||
|
||||
stop_token_ids = None
|
||||
backend_str = xgr.VocabType.RAW
|
||||
backend_str = ""
|
||||
vocab_type = xgr.VocabType.RAW
|
||||
|
||||
if stop_token_ids is None and hasattr(
|
||||
tokenizer,
|
||||
"eos_token_id") and tokenizer.eos_token_id is not None:
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
|
||||
if isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
backend_str = tokenizer.backend_tokenizer.to_str()
|
||||
if stop_token_ids is None and hasattr(
|
||||
tokenizer,
|
||||
"eos_token_id") and tokenizer.eos_token_id is not None:
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
vocab_type = None
|
||||
|
||||
elif isinstance(tokenizer, MistralTokenizer):
|
||||
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
|
||||
vocab_type = xgr.VocabType.BYTE_FALLBACK
|
||||
|
||||
cls._cache[tokenizer_hash] = TokenizerData(
|
||||
encoded_vocab=encoded_vocab,
|
||||
stop_token_ids=stop_token_ids,
|
||||
backend_str=backend_str)
|
||||
backend_str=backend_str,
|
||||
vocab_type=vocab_type)
|
||||
|
||||
return cls._cache[tokenizer_hash]
|
||||
|
||||
@ -98,11 +118,30 @@ class GrammarCompilerCache:
|
||||
cache_key = str(config.tokenizer_hash)
|
||||
|
||||
if cache_key not in cls._cache:
|
||||
assert config.encoded_vocab is not None
|
||||
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
|
||||
xgr_core.TokenizerInfo.from_huggingface(
|
||||
config.encoded_vocab, config.backend_str,
|
||||
config.vocab_size, config.stop_token_ids))
|
||||
assert config.tokenizer_data is not None
|
||||
assert config.tokenizer_data.encoded_vocab is not None
|
||||
|
||||
config_data = config.tokenizer_data
|
||||
|
||||
# In TokenizerDataCache.get_tokenizer_data, a serializable
|
||||
# tokenizer_data is created and cached. This data is used to build
|
||||
# a tokenizer_info and create an xgrammar compiler.
|
||||
# - If tokenizer_data has backend_str set, use
|
||||
# xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
|
||||
# - Otherwise, use the default constructor with vocab_type.
|
||||
# - xgr_core.TokenizerInfo.from_huggingface !=
|
||||
# xgr.TokenizerInfo.from_huggingface.
|
||||
if config_data.backend_str:
|
||||
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
|
||||
xgr_core.TokenizerInfo.from_huggingface(
|
||||
config_data.encoded_vocab, config_data.backend_str,
|
||||
config.vocab_size, config_data.stop_token_ids))
|
||||
else:
|
||||
tokenizer_info = xgr.TokenizerInfo(
|
||||
config_data.encoded_vocab,
|
||||
config_data.vocab_type,
|
||||
vocab_size=config.vocab_size,
|
||||
stop_token_ids=config_data.stop_token_ids)
|
||||
cls._cache[cache_key] = xgr.GrammarCompiler(
|
||||
tokenizer_info, max_threads=config.max_threads)
|
||||
|
||||
@ -118,10 +157,7 @@ class GrammarConfig:
|
||||
grammar_str: str | None = None
|
||||
json_object: bool | None = None
|
||||
max_threads: int = 8
|
||||
# Only populated if tokenizer_hash not in cache
|
||||
encoded_vocab: list[str] | None = None
|
||||
stop_token_ids: list[int] | None = None
|
||||
backend_str: str | None = None
|
||||
tokenizer_data: TokenizerData | None = None
|
||||
|
||||
@classmethod
|
||||
def from_guided_params(cls,
|
||||
@ -132,9 +168,6 @@ class GrammarConfig:
|
||||
|
||||
tokenizer_hash = hash(tokenizer)
|
||||
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
|
||||
encoded_vocab = tokenizer_data.encoded_vocab
|
||||
stop_token_ids = tokenizer_data.stop_token_ids
|
||||
backend_str = tokenizer_data.backend_str
|
||||
|
||||
if guided_params.json:
|
||||
if not isinstance(guided_params.json, str):
|
||||
@ -152,11 +185,9 @@ class GrammarConfig:
|
||||
|
||||
return cls(json_str=json_str,
|
||||
vocab_size=model_config.hf_text_config.vocab_size,
|
||||
encoded_vocab=encoded_vocab,
|
||||
stop_token_ids=stop_token_ids,
|
||||
backend_str=backend_str,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads)
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data)
|
||||
elif guided_params.grammar:
|
||||
# XGrammar only supports GBNF grammars, so we must convert Lark
|
||||
if grammar_is_likely_lark(guided_params.grammar):
|
||||
@ -181,19 +212,17 @@ class GrammarConfig:
|
||||
|
||||
return cls(grammar_str=grammar_str,
|
||||
vocab_size=model_config.hf_text_config.vocab_size,
|
||||
encoded_vocab=encoded_vocab,
|
||||
stop_token_ids=stop_token_ids,
|
||||
backend_str=backend_str,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads)
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data)
|
||||
elif guided_params.json_object:
|
||||
return cls(json_object=True,
|
||||
vocab_size=model_config.hf_text_config.vocab_size,
|
||||
encoded_vocab=encoded_vocab,
|
||||
stop_token_ids=stop_token_ids,
|
||||
backend_str=backend_str,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads)
|
||||
return cls(
|
||||
json_object=True,
|
||||
vocab_size=model_config.hf_text_config.vocab_size,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently only support JSON and EBNF grammar mode for xgrammar"
|
||||
@ -269,10 +298,14 @@ class XGrammarLogitsProcessor:
|
||||
# fill_next_token_bitmask so we move it to the device of scores
|
||||
device_type = scores.device.type
|
||||
if device_type != "cuda":
|
||||
scores = scores.to("cpu")
|
||||
scores = scores.to("cpu").unsqueeze(0)
|
||||
|
||||
# Note: In this method, if the tensors have different dimensions
|
||||
# on CPU device fails, but on GPU it runs without error. Hence the
|
||||
# unsqueeze above for scores, to match the token bitmask shape
|
||||
xgr.apply_token_bitmask_inplace(scores,
|
||||
self.token_bitmask.to(scores.device))
|
||||
if device_type != "cuda":
|
||||
scores = scores.to(device_type)
|
||||
scores = scores.to(device_type).squeeze()
|
||||
|
||||
return scores
|
||||
|
||||
@ -132,7 +132,7 @@ def get_tokenizer(
|
||||
if is_from_mistral_org and tokenizer_mode != "mistral":
|
||||
warnings.warn(
|
||||
'It is strongly recommended to run mistral models with '
|
||||
'`--tokenizer_mode "mistral"` to ensure correct '
|
||||
'`--tokenizer-mode "mistral"` to ensure correct '
|
||||
'encoding and decoding.',
|
||||
FutureWarning,
|
||||
stacklevel=2)
|
||||
|
||||
@ -314,12 +314,15 @@ class MistralTokenizer:
|
||||
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.decode(regular_tokens)) # type: ignore
|
||||
self.tokenizer.decode(regular_tokens)) # type: ignore
|
||||
|
||||
decoded = ''.join(decoded_list)
|
||||
|
||||
return decoded
|
||||
|
||||
# WARN: Outlines logits processors can overwrite this method.
|
||||
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
|
||||
# for more.
|
||||
def decode(self,
|
||||
ids: Union[List[int], int],
|
||||
skip_special_tokens: bool = True) -> str:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user