mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 08:49:08 +08:00
Frontend: Adding LM Format Enforcer support to V1 engine (#22564)
Signed-off-by: Noam Gat <noamgat@gmail.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
504d914314
commit
39971db3aa
@ -18,7 +18,7 @@ prometheus_client >= 0.18.0
|
|||||||
pillow # Required for image processing
|
pillow # Required for image processing
|
||||||
prometheus-fastapi-instrumentator >= 7.0.0
|
prometheus-fastapi-instrumentator >= 7.0.0
|
||||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||||
lm-format-enforcer >= 0.10.11, < 0.11
|
lm-format-enforcer == 0.11.3
|
||||||
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
||||||
outlines_core == 0.2.10 ; platform_machine != "s390x"
|
outlines_core == 0.2.10 ; platform_machine != "s390x"
|
||||||
outlines == 0.1.11 ; platform_machine == "s390x"
|
outlines == 0.1.11 ; platform_machine == "s390x"
|
||||||
|
|||||||
@ -41,8 +41,11 @@ EAGLE_SPEC_CONFIG = {
|
|||||||
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
|
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
||||||
|
("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
|
||||||
|
None),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
||||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
||||||
|
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
|
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
|
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
|
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
|
||||||
@ -148,7 +151,8 @@ def test_structured_output(
|
|||||||
|
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
assert generated_text is not None
|
assert generated_text is not None
|
||||||
assert "\n" not in generated_text
|
if guided_decoding_backend != 'lm-format-enforcer':
|
||||||
|
assert "\n" not in generated_text
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
output_json = json.loads(generated_text)
|
output_json = json.loads(generated_text)
|
||||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||||
@ -225,7 +229,7 @@ def test_structured_output(
|
|||||||
parsed_json = json.loads(generated_text)
|
parsed_json = json.loads(generated_text)
|
||||||
assert isinstance(parsed_json, dict)
|
assert isinstance(parsed_json, dict)
|
||||||
|
|
||||||
if guided_decoding_backend != "outlines":
|
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
|
||||||
#
|
#
|
||||||
# Test 4: Generate SQL statement using EBNF grammar
|
# Test 4: Generate SQL statement using EBNF grammar
|
||||||
#
|
#
|
||||||
@ -439,7 +443,7 @@ def test_structured_output(
|
|||||||
output_json = json.loads(generated_text)
|
output_json = json.loads(generated_text)
|
||||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||||
|
|
||||||
if guided_decoding_backend != "outlines":
|
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
|
||||||
#
|
#
|
||||||
# Test 11: Generate structured output using structural_tag format
|
# Test 11: Generate structured output using structural_tag format
|
||||||
#
|
#
|
||||||
|
|||||||
@ -3057,7 +3057,8 @@ def get_served_model_name(model: str,
|
|||||||
return served_model_name
|
return served_model_name
|
||||||
|
|
||||||
|
|
||||||
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
|
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
|
||||||
|
"lm-format-enforcer"]
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
|
|||||||
@ -21,6 +21,8 @@ from vllm.v1.engine import EngineCoreRequest
|
|||||||
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
|
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
|
||||||
from vllm.v1.structured_output.backend_guidance import (
|
from vllm.v1.structured_output.backend_guidance import (
|
||||||
validate_guidance_grammar)
|
validate_guidance_grammar)
|
||||||
|
from vllm.v1.structured_output.backend_lm_format_enforcer import (
|
||||||
|
validate_structured_output_request_lm_format_enforcer)
|
||||||
from vllm.v1.structured_output.backend_outlines import (
|
from vllm.v1.structured_output.backend_outlines import (
|
||||||
validate_structured_output_request_outlines)
|
validate_structured_output_request_outlines)
|
||||||
from vllm.v1.structured_output.backend_xgrammar import (
|
from vllm.v1.structured_output.backend_xgrammar import (
|
||||||
@ -200,6 +202,9 @@ class Processor:
|
|||||||
elif engine_level_backend == "outlines":
|
elif engine_level_backend == "outlines":
|
||||||
# outlines backend
|
# outlines backend
|
||||||
validate_structured_output_request_outlines(params)
|
validate_structured_output_request_outlines(params)
|
||||||
|
elif engine_level_backend == "lm-format-enforcer":
|
||||||
|
# lm format enforcer backend
|
||||||
|
validate_structured_output_request_lm_format_enforcer(params)
|
||||||
else:
|
else:
|
||||||
# NOTE: engine_level_backend must be "auto" here, because we have
|
# NOTE: engine_level_backend must be "auto" here, because we have
|
||||||
# checked supported_backends above.
|
# checked supported_backends above.
|
||||||
|
|||||||
@ -108,6 +108,14 @@ class StructuredOutputManager:
|
|||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
)
|
)
|
||||||
|
elif backend == "lm-format-enforcer":
|
||||||
|
from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501
|
||||||
|
LMFormatEnforcerBackend)
|
||||||
|
self.backend = LMFormatEnforcerBackend(
|
||||||
|
self.vllm_config,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported structured output backend: {backend}")
|
f"Unsupported structured output backend: {backend}")
|
||||||
|
|||||||
167
vllm/v1/structured_output/backend_lm_format_enforcer.py
Normal file
167
vllm/v1/structured_output/backend_lm_format_enforcer.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import LazyLoader
|
||||||
|
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||||
|
StructuredOutputGrammar,
|
||||||
|
StructuredOutputOptions)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import lmformatenforcer
|
||||||
|
import lmformatenforcer.integrations.vllm as lmfe_vllm
|
||||||
|
else:
|
||||||
|
lmformatenforcer = LazyLoader("lmformatenforcer", globals(),
|
||||||
|
"lmformatenforcer")
|
||||||
|
lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(),
|
||||||
|
"lmformatenforcer.integrations.vllm")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def _cached_build_vllm_token_enforcer_tokenizer_data(
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData:
|
||||||
|
return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
|
||||||
|
tokenizer, use_bitmask=True, vocab_size=vocab_size)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
|
||||||
|
token_enforcer: lmformatenforcer.TokenEnforcer
|
||||||
|
current_tokens_prefix: list[int] = field(default_factory=list)
|
||||||
|
|
||||||
|
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||||
|
original_len = len(self.current_tokens_prefix)
|
||||||
|
for token in tokens:
|
||||||
|
if not self.token_enforcer.get_allowed_tokens(
|
||||||
|
self.current_tokens_prefix).is_token_allowed(token):
|
||||||
|
# Rollback partial updates to ensure atomicity.
|
||||||
|
del self.current_tokens_prefix[original_len:]
|
||||||
|
return False
|
||||||
|
self.current_tokens_prefix.append(token)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_tokens(self, tokens: list[int]) -> list[int]:
|
||||||
|
for prefix_length in range(len(tokens)):
|
||||||
|
prefix = tokens[:prefix_length]
|
||||||
|
next_token = tokens[prefix_length]
|
||||||
|
if not self.token_enforcer.get_allowed_tokens(
|
||||||
|
self.current_tokens_prefix +
|
||||||
|
prefix).is_token_allowed(next_token):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
return tokens[:prefix_length]
|
||||||
|
|
||||||
|
def rollback(self, num_tokens: int) -> None:
|
||||||
|
self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens]
|
||||||
|
|
||||||
|
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
|
||||||
|
allowed_tokens = self.token_enforcer.get_allowed_tokens(
|
||||||
|
self.current_tokens_prefix)
|
||||||
|
bitmask[batch_index] = allowed_tokens.allowed_tokens
|
||||||
|
|
||||||
|
def is_terminated(self) -> bool:
|
||||||
|
# We are considered terminated if the prefix ends with eos_token_id
|
||||||
|
return_value = len(
|
||||||
|
self.current_tokens_prefix) > 0 and self.current_tokens_prefix[
|
||||||
|
-1] == self.token_enforcer.eos_token_id
|
||||||
|
return return_value
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.current_tokens_prefix = []
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMFormatEnforcerBackend(StructuredOutputBackend):
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
|
||||||
|
self.tokenizer, self.vocab_size)
|
||||||
|
|
||||||
|
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||||
|
grammar_spec: str) -> StructuredOutputGrammar:
|
||||||
|
character_level_parser: lmformatenforcer.CharacterLevelParser
|
||||||
|
if request_type == StructuredOutputOptions.JSON:
|
||||||
|
spec_dict = json.loads(grammar_spec)
|
||||||
|
character_level_parser = lmformatenforcer.JsonSchemaParser(
|
||||||
|
spec_dict)
|
||||||
|
elif request_type == StructuredOutputOptions.JSON_OBJECT:
|
||||||
|
character_level_parser = lmformatenforcer.JsonSchemaParser(None)
|
||||||
|
elif request_type == StructuredOutputOptions.REGEX:
|
||||||
|
character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
|
||||||
|
elif request_type == StructuredOutputOptions.CHOICE:
|
||||||
|
choices = ast.literal_eval(grammar_spec)
|
||||||
|
character_level_parser = lmformatenforcer.UnionParser(
|
||||||
|
[lmformatenforcer.StringParser(choice) for choice in choices])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid request type for LM Format Enforcer backend"
|
||||||
|
f"({request_type!s})")
|
||||||
|
max_rollback_tokens = (
|
||||||
|
self.vllm_config.speculative_config.num_speculative_tokens
|
||||||
|
if self.vllm_config.speculative_config is not None else 0)
|
||||||
|
|
||||||
|
if max_rollback_tokens > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"LM Format Enforcer backend does not support speculative tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
token_enforcer = lmformatenforcer.TokenEnforcer(
|
||||||
|
tokenizer_data=self.tokenizer_data,
|
||||||
|
parser=character_level_parser,
|
||||||
|
)
|
||||||
|
return LMFormatEnforcerGrammar(token_enforcer)
|
||||||
|
|
||||||
|
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
|
||||||
|
return torch.full(
|
||||||
|
(max_num_seqs, (self.vocab_size + 31) // 32),
|
||||||
|
-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def destroy(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def validate_structured_output_request_lm_format_enforcer(
|
||||||
|
params: SamplingParams):
|
||||||
|
if params.guided_decoding is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
gd_params = params.guided_decoding
|
||||||
|
|
||||||
|
if gd_params.regex:
|
||||||
|
return
|
||||||
|
elif gd_params.json:
|
||||||
|
if isinstance(gd_params.json, str):
|
||||||
|
try:
|
||||||
|
# make sure schema is valid json
|
||||||
|
json.loads(gd_params.json)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError("Invalid JSON grammar specification.") from e
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
json.dumps(gd_params.json)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Error serializing guided decoding jsonschema: {e}"
|
||||||
|
) from e
|
||||||
|
return
|
||||||
|
elif gd_params.choice:
|
||||||
|
return
|
||||||
|
elif gd_params.grammar:
|
||||||
|
raise ValueError("LM Format Enforcer guided decoding backend "
|
||||||
|
"does not support grammar specifications")
|
||||||
Loading…
x
Reference in New Issue
Block a user