mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 13:47:15 +08:00
[V1][Core] Support MistralTokenizer for Structured Output (#14625)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
80e78d02ac
commit
77a318bd01
@ -1,7 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
@ -10,17 +13,27 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_name():
|
||||
return [
|
||||
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_json_completion(monkeypatch, sample_json_schema,
|
||||
guided_decoding_backend: str):
|
||||
def test_guided_json_completion(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_json_schema: dict[str, Any],
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema,
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
|
||||
def test_guided_json_object(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
max_tokens=100,
|
||||
n=2,
|
||||
@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
|
||||
guided_decoding_backend: str):
|
||||
def test_guided_json_unsupported_schema(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
unsupported_json_schema: dict[str, Any],
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
|
||||
guided_decoding_backend: str):
|
||||
def test_guided_grammar_ebnf(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_sql_ebnf: str,
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
|
||||
guided_decoding_backend: str):
|
||||
def test_guided_grammar_lark(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_sql_lark: str,
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_grammar_ebnf_invalid(monkeypatch,
|
||||
guided_decoding_backend: str):
|
||||
def test_guided_grammar_ebnf_invalid(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch,
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
|
||||
def test_guided_regex(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_regex: str,
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_choice_completion(monkeypatch, sample_guided_choice,
|
||||
guided_decoding_backend: str):
|
||||
def test_guided_choice_completion(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_guided_choice: str,
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
llm = LLM(model=model_name, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
|
||||
|
||||
@ -40,8 +41,40 @@ class StructuredOutputManager:
|
||||
tokenizer_group.ping()
|
||||
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
|
||||
tokenizer, vocab_size=self.vocab_size)
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# NOTE: ideally, xgrammar should handle this accordingly.
|
||||
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
|
||||
try:
|
||||
encoded_vocab = [
|
||||
token for token, _ in sorted(
|
||||
tokenizer.get_vocab().items(),
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
]
|
||||
stop_token_ids = None
|
||||
if hasattr(
|
||||
tokenizer,
|
||||
"eos_token_id",
|
||||
) and tokenizer.eos_token_id is not None:
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
f"Cannot get the vocabulary of the tokenizer "
|
||||
f"{type(tokenizer)}. The tokenizer should have a "
|
||||
"get_vocab method.") from e
|
||||
tokenizer_info = xgr.TokenizerInfo(
|
||||
encoded_vocab=encoded_vocab,
|
||||
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
|
||||
vocab_type=xgr.VocabType.BYTE_FALLBACK,
|
||||
vocab_size=self.vocab_size,
|
||||
stop_token_ids=stop_token_ids,
|
||||
add_prefix_space=True,
|
||||
)
|
||||
else:
|
||||
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
|
||||
tokenizer,
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
|
||||
|
||||
# The default max_workers if not specified is the number of CPUs * 5,
|
||||
@ -51,7 +84,9 @@ class StructuredOutputManager:
|
||||
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._grammar_bitmask = xgr.allocate_token_bitmask(
|
||||
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
|
||||
self.vllm_config.scheduler_config.max_num_seqs,
|
||||
self.vocab_size,
|
||||
)
|
||||
|
||||
self.init_complete = True
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user