[Frontend][Core] Update Outlines Integration from FSM to Guide (#4109)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
Breno Faria, 2024-06-06 01:49:12 +02:00, committed by GitHub
parent 3a6ae1d33c
commit 7b0a0dfb22
4 changed files with 49 additions and 48 deletions

View File

@@ -17,6 +17,6 @@ prometheus_client >= 0.18.0
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.1
-outlines == 0.0.34 # Requires torch >= 2.1.0
+outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
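For reviewers who want to verify the bump locally, a minimal sanity check (a sketch, not part of this diff) that the installed outlines actually exposes the new Guide API; `importlib.metadata` is used rather than assuming outlines has a version attribute:

```python
# outlines >= 0.0.43 ships the Guide API that vLLM imports below;
# 0.0.34 only had the older FSM classes, so this import fails there.
from importlib.metadata import version

from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write  # noqa: F401

print("outlines", version("outlines"))
```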

View File

@@ -63,7 +63,6 @@ def test_guided_logits_processors():
                                   tokenizer,
                                   whitespace_pattern=None)
 
-    regex_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {TEST_REGEX}")
     tensor = torch.rand(32000)
@@ -72,7 +71,6 @@ def test_guided_logits_processors():
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
 
-    json_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
     tensor = torch.rand(32000)
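With `init_state()` gone, a freshly constructed processor is immediately callable. A sketch of the resulting test flow (the constructor call mirrors the one earlier in this test, which is not visible in this hunk, so treat its exact arguments as assumed):

```python
# No init_state() between construction and the first call: the
# per-sequence state dict is now created in BaseLogitsProcessor.__init__.
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)

token_ids = tokenizer.encode(
    f"Give an example IPv4 address with this regex: {TEST_REGEX}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
```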

View File

@@ -1,8 +1,6 @@
 import asyncio
 import concurrent.futures
-from copy import copy
 from enum import Enum
-from functools import lru_cache
 from json import dumps as json_dumps
 from re import escape as regex_escape
 from typing import Tuple, Union
@@ -54,8 +52,10 @@ global_thread_pool = None  # used for generating logits processor fsm
 async def get_outlines_guided_decoding_logits_processor(
-        request: Union[CompletionRequest, ChatCompletionRequest],
-        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
+    request: Union[CompletionRequest,
+                   ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
+           None]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -64,7 +64,7 @@ async def get_outlines_guided_decoding_logits_processor(
     """
     global global_thread_pool
     guide, mode = _get_guide_and_mode(request)
-    if not guide:
+    if not guide or not mode:
         return None
 
     if global_thread_pool is None:
@@ -72,15 +72,9 @@ async def get_outlines_guided_decoding_logits_processor(
             max_workers=2)
     loop = asyncio.get_running_loop()
 
-    result = await loop.run_in_executor(global_thread_pool,
-                                        _get_cached_logits_processor, guide,
-                                        tokenizer, mode,
-                                        request.guided_whitespace_pattern)
-
-    logits_processor = copy(result)
-    # reset logits processor's internal state
-    logits_processor.init_state()
-    return logits_processor
+    return await loop.run_in_executor(global_thread_pool,
+                                      _get_logits_processor, guide, tokenizer,
+                                      mode, request.guided_whitespace_pattern)
 
 
 def _get_guide_and_mode(
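Because construction now yields a ready-to-use, per-request processor, the caller no longer copies or resets anything. A hedged sketch of a call site (variable names are illustrative, not part of this diff):

```python
async def example(request, tokenizer, sampling_params):
    # Hypothetical caller: request is a CompletionRequest or
    # ChatCompletionRequest, tokenizer the model's tokenizer.
    logits_processor = await get_outlines_guided_decoding_logits_processor(
        request, tokenizer)
    if logits_processor is not None:
        # The processor takes effect by riding along on the sampling
        # parameters; vLLM then invokes it once per decoding step.
        sampling_params.logits_processors = [logits_processor]
```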
@@ -115,11 +109,10 @@ def _get_guide_and_mode(
     return None, None
 
 
-@lru_cache(maxsize=32)
-def _get_cached_logits_processor(guide: str,
-                                 tokenizer: PreTrainedTokenizerBase,
-                                 mode: GuidedDecodingMode,
-                                 whitespace_pattern: Union[str, None]):
+def _get_logits_processor(
+    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
+    whitespace_pattern: Union[str, None]
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
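`_get_logits_processor` itself is now cheap and uncached; the expensive compilation is memoized inside each processor class (next file). A hedged example of exercising the dispatch directly, with a made-up pattern:

```python
# Illustrative only: the server reaches this through the executor above.
processor = _get_logits_processor(
    guide=r"\d{1,3}(\.\d{1,3}){3}",  # hypothetical IPv4-style regex
    tokenizer=tokenizer,             # a PreTrainedTokenizerBase instance
    mode=GuidedDecodingMode.REGEX,
    whitespace_pattern=None)
```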

View File

@@ -21,7 +21,7 @@ from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Union
 
 import torch
-from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
+from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
 from outlines.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
@@ -29,28 +29,32 @@ from transformers import PreTrainedTokenizerBase
 class BaseLogitsProcessor:
 
-    def __init__(self):
-        # Child class should use initialize in their init.
-        self.fsm: FSM
-
-    def init_state(self):
-        """Initialize the FSM states."""
-        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
+    def __init__(self, guide: Guide):
+        self._guide: Guide = guide
+        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
 
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
         seq_id = hash(tuple(input_ids))
 
-        if len(input_ids) == 0:
-            self.init_state()
-        else:
+        if len(input_ids) > 0:
             last_token = input_ids[-1]
             last_seq_id = hash(tuple(input_ids[:-1]))
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[last_seq_id], last_token)
+            self._fsm_state[seq_id] = self._guide.get_next_state(
+                state=self._fsm_state[last_seq_id], token_id=last_token)
 
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        instruction = self._guide.get_next_instruction(
+            state=self._fsm_state[seq_id])
+
+        if type(instruction) == Generate:
+            allowed_tokens = instruction.tokens
+        elif type(instruction) == Write:
+            # TODO: support fast forward tokens
+            allowed_tokens = [instruction.tokens[0]]
+        else:
+            raise TypeError(
+                f"Unsupported instruction type {type(instruction)}")
 
         mask = torch.full((scores.shape[-1], ),
                           -math.inf,
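This is the heart of the migration: the old `fsm.allowed_token_ids(state)` call becomes "ask the Guide for an instruction, then interpret it". The interpretation branch above, isolated as a standalone sketch:

```python
from typing import List

from outlines.fsm.guide import Generate, Write


def tokens_from_instruction(instruction) -> List[int]:
    # Generate: any of these tokens keeps the output inside the guide.
    # Write: the guide can only continue with a fixed token sequence; until
    # fast-forward is supported, only its first token is emitted per step.
    if type(instruction) == Generate:
        return instruction.tokens
    elif type(instruction) == Write:
        return [instruction.tokens[0]]
    raise TypeError(f"Unsupported instruction type {type(instruction)}")
```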
@@ -62,6 +66,13 @@ class BaseLogitsProcessor:
 class RegexLogitsProcessor(BaseLogitsProcessor):
 
+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, regex_string: str,
+                   tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return RegexGuide(regex_string, tokenizer)
+
     def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the regex-structured generation.
 
@@ -73,9 +84,8 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer
 
         """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        super().__init__(
+            RegexLogitsProcessor._get_guide(regex_string, tokenizer))
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
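Note the effect of moving `@lru_cache` onto the `_get_guide` classmethod: identical `(regex_string, tokenizer)` pairs share one compiled `RegexGuide`, while every processor instance still gets its own mutable `_fsm_state`. A sketch of that invariant (it assumes a concrete `tokenizer` in scope; `lru_cache` requires its arguments to be hashable):

```python
# Same pattern + tokenizer -> one cached Guide, two independent states.
p1 = RegexLogitsProcessor(r"[0-9]+", tokenizer)
p2 = RegexLogitsProcessor(r"[0-9]+", tokenizer)
assert p1._guide is p2._guide              # compiled once, shared
assert p1._fsm_state is not p2._fsm_state  # per-request mutable state
```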
@@ -115,6 +125,12 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
 class CFGLogitsProcessor(BaseLogitsProcessor):
 
+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return CFGGuide(cfg, tokenizer)
+
     def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the context free grammar generation.
 
@@ -126,17 +142,11 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer
 
         """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = CFGFSM(cfg, tokenizer)
-        self.fsm = fsm
-
-    def init_state(self):
-        """Initialize state with a CFGFSM copy."""
-        super().init_state()
-        self.fsm = self.fsm.copy()
+        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
+        self._guide = self._guide.copy()
 
 
-@lru_cache
+@lru_cache(maxsize=32)
 def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
     """Adapt vLLM's tokenizer to use to compile the FSM.