[Frontend][Core] Update Outlines Integration from FSM to Guide (#4109)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Breno Faria <breno.faria@intrafind.com>

parent 3a6ae1d33c
commit 7b0a0dfb22
```diff
@@ -17,6 +17,6 @@ prometheus_client >= 0.18.0
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.1
-outlines == 0.0.34 # Requires torch >= 2.1.0
+outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
```
```diff
@@ -63,7 +63,6 @@ def test_guided_logits_processors():
                                   tokenizer,
                                   whitespace_pattern=None)
-    regex_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {TEST_REGEX}")
     tensor = torch.rand(32000)
@@ -72,7 +71,6 @@ def test_guided_logits_processors():
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
-    json_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
     tensor = torch.rand(32000)
```
```diff
@@ -1,8 +1,6 @@
 import asyncio
 import concurrent.futures
-from copy import copy
 from enum import Enum
-from functools import lru_cache
 from json import dumps as json_dumps
 from re import escape as regex_escape
 from typing import Tuple, Union
@@ -54,8 +52,10 @@ global_thread_pool = None  # used for generating logits processor fsm


 async def get_outlines_guided_decoding_logits_processor(
-        request: Union[CompletionRequest, ChatCompletionRequest],
-        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
+    request: Union[CompletionRequest,
+                   ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
+           None]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -64,7 +64,7 @@ async def get_outlines_guided_decoding_logits_processor(
     """
     global global_thread_pool
     guide, mode = _get_guide_and_mode(request)
-    if not guide:
+    if not guide or not mode:
         return None

     if global_thread_pool is None:
```
```diff
@@ -72,15 +72,9 @@ async def get_outlines_guided_decoding_logits_processor(
             max_workers=2)
     loop = asyncio.get_running_loop()

-    result = await loop.run_in_executor(global_thread_pool,
-                                        _get_cached_logits_processor, guide,
-                                        tokenizer, mode,
-                                        request.guided_whitespace_pattern)
-
-    logits_processor = copy(result)
-    # reset logits processor's internal state
-    logits_processor.init_state()
-    return logits_processor
+    return await loop.run_in_executor(global_thread_pool,
+                                      _get_logits_processor, guide, tokenizer,
+                                      mode, request.guided_whitespace_pattern)


 def _get_guide_and_mode(
```
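The copy/`init_state` dance disappears here because processors are no longer shared cached objects; the coroutine simply hands the expensive build off to the worker pool. Below is a minimal, self-contained sketch of that `run_in_executor` pattern; `build_processor` is a hypothetical stand-in for `_get_logits_processor`, and nothing beyond the pattern itself comes from the commit:

```python
import asyncio
import concurrent.futures


def build_processor(guide: str) -> str:
    # Stand-in for the expensive, blocking guide compilation that must not
    # run on the event loop thread.
    return f"logits processor for {guide!r}"


global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)


async def main() -> None:
    loop = asyncio.get_running_loop()
    # Offload the blocking call to the pool; awaiting keeps the loop free.
    processor = await loop.run_in_executor(global_thread_pool,
                                           build_processor, r"\d+")
    print(processor)


asyncio.run(main())
```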
```diff
@@ -115,11 +109,10 @@ def _get_guide_and_mode(
     return None, None


-@lru_cache(maxsize=32)
-def _get_cached_logits_processor(guide: str,
-                                 tokenizer: PreTrainedTokenizerBase,
-                                 mode: GuidedDecodingMode,
-                                 whitespace_pattern: Union[str, None]):
+def _get_logits_processor(
+    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
+    whitespace_pattern: Union[str, None]
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
```
```diff
@@ -21,7 +21,7 @@ from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Union

 import torch
-from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
+from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
 from outlines.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
```
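For reference, here is the old-to-new correspondence as it appears at the call sites changed later in this diff (a reading aid reconstructed from the diff itself, not part of the commit):

```python
# Old outlines.fsm.fsm API           ->  New outlines.fsm.guide API
#
# RegexFSM(regex, tokenizer)         ->  RegexGuide(regex, tokenizer)
# CFGFSM(cfg, tokenizer)             ->  CFGGuide(cfg, tokenizer)
# fsm.next_state(state, token)       ->  guide.get_next_state(state=..., token_id=...)
# fsm.allowed_token_ids(state)       ->  guide.get_next_instruction(state=...),
#                                        which returns a Generate or Write
#                                        instruction carrying .tokens
```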
```diff
@@ -29,28 +29,32 @@ from transformers import PreTrainedTokenizerBase

 class BaseLogitsProcessor:

-    def __init__(self):
-        # Child class should use initialize in their init.
-        self.fsm: FSM
-
-    def init_state(self):
-        """Initialize the FSM states."""
-        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
+    def __init__(self, guide: Guide):
+        self._guide: Guide = guide
+        self._fsm_state: DefaultDict[int, int] = defaultdict(int)

     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
         seq_id = hash(tuple(input_ids))

-        if len(input_ids) == 0:
-            self.init_state()
-        else:
+        if len(input_ids) > 0:
             last_token = input_ids[-1]
             last_seq_id = hash(tuple(input_ids[:-1]))
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[last_seq_id], last_token)
+            self._fsm_state[seq_id] = self._guide.get_next_state(
+                state=self._fsm_state[last_seq_id], token_id=last_token)

-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        instruction = self._guide.get_next_instruction(
+            state=self._fsm_state[seq_id])
+
+        if type(instruction) == Generate:
+            allowed_tokens = instruction.tokens
+        elif type(instruction) == Write:
+            # TODO: support fast forward tokens
+            allowed_tokens = [instruction.tokens[0]]
+        else:
+            raise TypeError(
+                f"Unsupported instruction type {type(instruction)}")

         mask = torch.full((scores.shape[-1], ),
                           -math.inf,
```
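The trailing context shows the start of the masking step. A minimal sketch of what that step does to one row of logits, once the guide has produced `allowed_tokens` (the vocab size and token ids below are illustrative, not from the commit):

```python
import math

import torch

vocab_size = 32000
scores = torch.randn(vocab_size)   # raw logits for the next-token position
allowed_tokens = [42, 7, 13]       # token ids the guide permits here

# Drive every logit outside the allowed set to -inf, so the sampler can only
# pick tokens that keep the output inside the guide's language.
mask = torch.full((scores.shape[-1], ), -math.inf)
mask[allowed_tokens] = 0
scores = scores + mask
```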
```diff
@@ -62,6 +66,13 @@ class BaseLogitsProcessor:

 class RegexLogitsProcessor(BaseLogitsProcessor):

+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, regex_string: str,
+                   tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return RegexGuide(regex_string, tokenizer)
+
     def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the regex-structured generation.
```
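Caching now lives on a `_get_guide` classmethod instead of on the processor factory, so the expensive guide compilation is shared across requests while each processor instance stays cheap to construct. A toy sketch of that decorator stacking, with hypothetical names (`Processor`, `_get_guide`'s string return) standing in for the real classes:

```python
from functools import lru_cache


class Processor:

    @classmethod
    @lru_cache(maxsize=32)          # cache is keyed on (cls, pattern)
    def _get_guide(cls, pattern: str) -> str:
        print(f"compiling guide for {pattern!r}")   # runs once per pattern
        return f"guide:{pattern}"

    def __init__(self, pattern: str):
        self.guide = Processor._get_guide(pattern)


Processor(r"\d+")    # prints: compiling guide for '\\d+'
Processor(r"\d+")    # cache hit: no recompile, same guide object
```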
```diff
@@ -73,9 +84,8 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer

         """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        super().__init__(
+            RegexLogitsProcessor._get_guide(regex_string, tokenizer))


 class JSONLogitsProcessor(RegexLogitsProcessor):
@@ -115,6 +125,12 @@ class JSONLogitsProcessor(RegexLogitsProcessor):

 class CFGLogitsProcessor(BaseLogitsProcessor):

+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return CFGGuide(cfg, tokenizer)
+
     def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the context free grammar generation.
```
```diff
@@ -126,17 +142,11 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer

         """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = CFGFSM(cfg, tokenizer)
-        self.fsm = fsm
-
-    def init_state(self):
-        """Initialize state with a CFGFSM copy."""
-        super().init_state()
-        self.fsm = self.fsm.copy()
+        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
+        self._guide = self._guide.copy()


-@lru_cache
+@lru_cache(maxsize=32)
 def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
     """Adapt vLLM's tokenizer to use to compile the FSM.
```
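Two details worth noting in this last hunk: the CFG constructor copies the cached guide, presumably because `CFGGuide` carries mutable parser state that must not leak between sequences, and `_adapt_tokenizer` goes from a bare `@lru_cache` (default maxsize 128) to an explicit `@lru_cache(maxsize=32)`. A quick illustration of that default, using throwaway functions:

```python
from functools import lru_cache


@lru_cache                  # bare form: maxsize defaults to 128
def f(x: int) -> int:
    return x * 2


@lru_cache(maxsize=32)      # the diff's explicit, tighter bound
def g(x: int) -> int:
    return x * 2


print(f.cache_info().maxsize)   # 128
print(g.cache_info().maxsize)   # 32
```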