diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index bddd224548c89..b99fb6a778295 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import json
 import re
+from typing import Any
 
 import jsonschema
 import pytest
@@ -10,17 +13,27 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
 
 
+@pytest.fixture(
+    params=["Qwen/Qwen2.5-1.5B-Instruct",
+            "mistralai/Ministral-8B-Instruct-2410"])
+def model_name(request):
+    return request.param
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_completion(monkeypatch, sample_json_schema,
-                                guided_decoding_backend: str):
+def test_guided_json_completion(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_json_schema: dict[str, Any],
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=1000,
                                      guided_decoding=GuidedDecodingParams(
@@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
+def test_guided_json_object(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=100,
                                      n=2,
@@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
-                                        guided_decoding_backend: str):
+def test_guided_json_unsupported_schema(
+    monkeypatch: pytest.MonkeyPatch,
+    unsupported_json_schema: dict[str, Any],
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=1000,
                                      guided_decoding=GuidedDecodingParams(
@@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
-                             guided_decoding_backend: str):
+def test_guided_grammar_ebnf(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_sql_ebnf: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
-                             guided_decoding_backend: str):
+def test_guided_grammar_lark(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_sql_lark: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_ebnf_invalid(monkeypatch,
-                                     guided_decoding_backend: str):
+def test_guided_grammar_ebnf_invalid(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
+def test_guided_regex(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_regex: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
@@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_choice_completion(monkeypatch, sample_guided_choice,
-                                  guided_decoding_backend: str):
+def test_guided_choice_completion(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_guided_choice: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 3428b8522e506..45fec1122cce3 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
 
@@ -40,8 +41,40 @@ class StructuredOutputManager:
 
         tokenizer_group.ping()
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=self.vocab_size)
+        if isinstance(tokenizer, MistralTokenizer):
+            # NOTE: ideally, xgrammar should handle this accordingly.
+            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
+            try:
+                encoded_vocab = [
+                    token for token, _ in sorted(
+                        tokenizer.get_vocab().items(),
+                        key=lambda x: x[1],
+                    )
+                ]
+                stop_token_ids = None
+                if hasattr(
+                        tokenizer,
+                        "eos_token_id",
+                ) and tokenizer.eos_token_id is not None:
+                    stop_token_ids = [tokenizer.eos_token_id]
+            except AttributeError as e:
+                raise ValueError(
+                    f"Cannot get the vocabulary of the tokenizer "
+                    f"{type(tokenizer)}. The tokenizer should have a "
+                    "get_vocab method.") from e
+            tokenizer_info = xgr.TokenizerInfo(
+                encoded_vocab=encoded_vocab,
+                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
+                vocab_type=xgr.VocabType.BYTE_FALLBACK,
+                vocab_size=self.vocab_size,
+                stop_token_ids=stop_token_ids,
+                add_prefix_space=True,
+            )
+        else:
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                vocab_size=self.vocab_size,
+            )
         self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
 
         # The default max_workers if not specified is the number of CPUs * 5,
@@ -51,7 +84,9 @@ class StructuredOutputManager:
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
         self._grammar_bitmask = xgr.allocate_token_bitmask(
-            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.vocab_size,
+        )
 
         self.init_complete = True
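
For context, a minimal usage sketch (not part of the diff) of the code path this change enables: guided decoding on the V1 engine with a model that uses a Mistral tokenizer, mirroring the updated tests. The model name, prompt, and regex below are illustrative assumptions, not values taken from the PR.

# Illustrative sketch only: exercises the new MistralTokenizer branch in
# StructuredOutputManager via the V1 engine; values below are assumptions.
import os

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# The tests set this with monkeypatch.setenv; here we set it directly.
os.environ["VLLM_USE_V1"] = "1"

llm = LLM(model="mistralai/Ministral-8B-Instruct-2410", max_model_len=1024)
sampling_params = SamplingParams(
    temperature=0.8,
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(regex=r"\w+@\w+\.com\n"),
)
outputs = llm.generate(
    prompts="Generate an example email address for Alan Turing: ",
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)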