mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 08:09:37 +08:00
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to Python source files, as recommended to the project by the Linux Foundation. These headers provide a concise, human- and machine-readable way to communicate the license of each source file. They help avoid any ambiguity about the license of the code and can also easily be used by tools that help manage license compliance.
The Linux Foundation runs license scans against the codebase to help ensure we are in compliance with the licenses of the code we use, including dependencies. Having these headers in place helps that tool do its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
145 lines · 5.1 KiB · Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import os
|
|
from enum import Enum
|
|
from json import dumps as json_dumps
|
|
from re import escape as regex_escape
|
|
from typing import Tuple, Union
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
|
|
|
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
|
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
|
from vllm.sampling_params import GuidedDecodingParams
|
|
|
|
|
|
class GuidedDecodingMode(Enum):
    """
    The kind of constraint a guided-decoding request asks for.

    Each member's value is its lowercase name.
    """

    # Constrain output to match a JSON schema.
    JSON = "json"
    # Constrain output to match a regular expression.
    REGEX = "regex"
    # Constrain output to one of a fixed list of literal choices.
    CHOICE = "choice"
    # Constrain output to a context-free grammar.
    GRAMMAR = "grammar"
|
|
|
|
|
|
# JSON grammar (Lark syntax) used as the guide for ``json_object`` requests.
# Adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that the start rule was changed from ``start: value``
# to ``start: object | array``, so scalar values are not allowed as the root
# of the JSON. Starting with scalars as the root seems to cause llama to
# generate without stop.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null

array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""
|
|
|
|
# Shared executor on which logits processors (and their FSMs) are generated.
# Left as None at import time and created on first use by the async
# logits-processor factory below.
global_thread_pool: Union[concurrent.futures.ThreadPoolExecutor,
                          None] = None

# Upper bound on the number of workers in ``global_thread_pool``. It is not
# yet clear that more workers provide a benefit, and they could starve other
# processes on the machine, so the pool is capped here for now; raise it
# later if testing shows the pool is a bottleneck.
_MAX_THREADPOOL_WORKERS = 16
|
|
|
|
|
|
async def get_outlines_guided_decoding_logits_processor(
    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Build the outlines logits processor requested by ``guided_params``.

    Construction runs on a shared module-level thread pool via
    ``loop.run_in_executor``, so it does not block the running event loop.
    The pool is created lazily on the first call that needs it. Its size is
    ``os.cpu_count()`` (or 2 if that is unknown), capped at
    ``_MAX_THREADPOOL_WORKERS``.

    NOTE(review): an earlier docstring said processors are cached by
    (guide, tokenizer) and shallow-copied on a cache hit. No such cache is
    visible in this function; presumably it lives in
    ``outlines_logits_processors`` — confirm.

    Args:
        guided_params: Guided decoding parameters for the request.
        tokenizer: Tokenizer the logits processor is built for.

    Returns:
        A JSON, regex or CFG logits processor, or ``None`` when
        ``guided_params`` does not request guided decoding.

    Raises:
        Any exception raised while constructing the processor in the
        executor is re-raised here by the ``await``.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(guided_params)
    if not guide or not mode:
        return None

    if global_thread_pool is None:
        # os.cpu_count() can return None; fall back to 2 workers in that case.
        max_workers = min(os.cpu_count() or 2, _MAX_THREADPOOL_WORKERS)
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers)
    loop = asyncio.get_running_loop()

    return await loop.run_in_executor(global_thread_pool,
                                      _get_logits_processor, guide, tokenizer,
                                      mode, guided_params.whitespace_pattern)
|
|
|
|
|
|
def get_local_outlines_guided_decoding_logits_processor(
    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Synchronously build the outlines logits processor for ``guided_params``.

    This is the blocking counterpart of
    ``get_outlines_guided_decoding_logits_processor``. The processor is
    constructed on the calling thread instead of a thread pool.

    Args:
        guided_params: Guided decoding parameters for the request.
        tokenizer: Tokenizer the logits processor is built for.

    Returns:
        A JSON, regex or CFG logits processor, or ``None`` when
        ``guided_params`` does not request guided decoding.
    """
    guide, mode = _get_guide_and_mode(guided_params)
    if guide and mode:
        return _get_logits_processor(guide, tokenizer, mode,
                                     guided_params.whitespace_pattern)
    return None
|
|
|
|
|
|
def _get_guide_and_mode(
    guided_params: GuidedDecodingParams
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
    """
    Translate ``guided_params`` into an outlines guide and its mode.

    Fields are checked in priority order: ``json``, ``regex``, ``choice``,
    ``grammar``, ``json_object``. The first truthy one determines the result.

    Returns:
        ``(guide, mode)``, where ``guide`` is a JSON schema string, a regex or
        a grammar. Returns ``(None, None)`` when none of the fields is set.
    """
    schema = guided_params.json
    if schema:
        # A dict schema is serialized so the guide is a hashable string.
        guide = json_dumps(schema) if isinstance(schema, dict) else schema
        return guide, GuidedDecodingMode.JSON

    if guided_params.regex:
        return guided_params.regex, GuidedDecodingMode.REGEX

    if guided_params.choice:
        # Choices are matched literally. Escape each one, join them into a
        # single alternation group, and treat the result as a regex.
        alternation = "|".join(
            regex_escape(str(option)) for option in guided_params.choice)
        return f"({alternation})", GuidedDecodingMode.CHOICE

    if guided_params.grammar:
        return guided_params.grammar, GuidedDecodingMode.GRAMMAR

    if guided_params.json_object:
        # Free-form JSON objects are constrained by the built-in JSON grammar.
        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR

    return None, None
|
|
|
|
|
|
def _get_logits_processor(
    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
    whitespace_pattern: Union[str, None]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
    """
    Instantiate the logits processor that enforces ``guide`` under ``mode``.

    ``whitespace_pattern`` is passed only to the JSON processor. CHOICE
    guides are regexes, so they share the regex processor.

    Raises:
        ValueError: If ``mode`` is not one of the handled modes.
    """
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
    if mode in (GuidedDecodingMode.REGEX, GuidedDecodingMode.CHOICE):
        return RegexLogitsProcessor(guide, tokenizer)
    if mode == GuidedDecodingMode.GRAMMAR:
        return CFGLogitsProcessor(guide, tokenizer)
    raise ValueError(f"Unknown guided decoding mode {mode}")
|