commit 703e42ee4b (parent 29a8d6a554)

Add guided decoding for OpenAI API server (#2819)

Co-authored-by: br3no <breno@veltefaria.de>
Co-authored-by: simon-mo <simon.mo@hey.com>
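In brief: the OpenAI-compatible server gains three mutually exclusive request parameters — guided_json, guided_regex, and guided_choice — validated in protocol.py, attached to each request's SamplingParams in the chat and completion serving paths, and implemented by Outlines-backed FSM logits processors that are compiled off the event loop and cached.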
requirements.txt
@@ -12,4 +12,5 @@ pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 pynvml == 11.5.0
 triton >= 2.1.0
+outlines >= 0.0.27
 cupy-cuda12x == 12.1.0  # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
tests/entrypoints/test_guided_processors.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+# This unit test should be moved to a new
+# tests/test_guided_decoding directory.
+
+from transformers import AutoTokenizer
+import torch
+
+from vllm.model_executor.guided_logits_processors import (RegexLogitsProcessor,
+                                                          JSONLogitsProcessor)
+
+TEST_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "string"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work history"]
+}
+
+TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
+             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
+
+
+def test_guided_logits_processors():
+    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
+    json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
+
+    regex_LP.init_state()
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    regex_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    json_LP.init_state()
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    json_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
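Note: the processors exercised here can also be attached to vLLM's offline entrypoint via SamplingParams(logits_processors=...), the same hook the serving-side changes below rely on. A minimal sketch, assuming a machine that can load the model (names mirror the test above):

    from transformers import AutoTokenizer

    from vllm import LLM, SamplingParams
    from vllm.model_executor.guided_logits_processors import RegexLogitsProcessor

    IP_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
                r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    processor = RegexLogitsProcessor(IP_REGEX, tokenizer)
    processor.init_state()  # reset per-sequence FSM state before the request

    llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")
    params = SamplingParams(max_tokens=20, logits_processors=[processor])
    output = llm.generate("An example IPv4 address: ", params)
    print(output[0].outputs[0].text)  # generation is constrained to IP_REGEX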
tests/entrypoints/test_openai_server.py
@@ -9,12 +9,64 @@ import ray  # using Ray for overall ease of process management, parallel requests
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests
+
+# imports for guided decoding tests
+import json
+import jsonschema
+import re
+
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
+
+TEST_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "string"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work history"]
+}
+
+TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
+             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
+
+TEST_CHOICE = [
+    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
+    "Swift", "Kotlin"
+]
+
 pytestmark = pytest.mark.asyncio

@@ -325,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
         max_tokens=max_tokens,
         temperature=0.0,
         logit_bias={str(token_id): 100},
+        seed=42,
     )
     assert completion.choices[0].text is not None and len(
         completion.choices[0].text) >= 5
@@ -358,5 +411,189 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text
+
+
+async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=
+        f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
+        n=3,
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 3
+    for i in range(3):
+        assert completion.choices[i].text is not None
+        output_json = json.loads(completion.choices[i].text)
+        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
+
+
+async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "Give an example JSON for an employee profile that " + \
+            f"fits this schema: {TEST_SCHEMA}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json1 = json.loads(message.content)
+    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+
+    messages.append({"role": "assistant", "content": message.content})
+    messages.append({
+        "role":
+        "user",
+        "content":
+        "Give me another one with a different name and age"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json2 = json.loads(message.content)
+    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    assert json1["name"] != json2["name"]
+    assert json1["age"] != json2["age"]
+
+
+async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
+        n=3,
+        temperature=1.0,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 3
+    for i in range(3):
+        assert completion.choices[i].text is not None
+        assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
+
+
+async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example IP address with this regex: {TEST_REGEX}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+    ip1 = chat_completion.choices[0].message.content
+    assert ip1 is not None
+    assert re.fullmatch(TEST_REGEX, ip1) is not None
+
+    messages.append({"role": "assistant", "content": ip1})
+    messages.append({"role": "user", "content": "Give me a different one"})
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+    ip2 = chat_completion.choices[0].message.content
+    assert ip2 is not None
+    assert re.fullmatch(TEST_REGEX, ip2) is not None
+    assert ip1 != ip2
+
+
+async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt="The best language for type-safe systems programming is ",
+        n=2,
+        temperature=1.0,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 2
+    for i in range(2):
+        assert completion.choices[i].text in TEST_CHOICE
+
+
+async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+    choice1 = chat_completion.choices[0].message.content
+    assert choice1 in TEST_CHOICE
+
+    messages.append({"role": "assistant", "content": choice1})
+    messages.append({
+        "role": "user",
+        "content": "I disagree, pick another one"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+    choice2 = chat_completion.choices[0].message.content
+    assert choice2 in TEST_CHOICE
+    assert choice1 != choice2
+
+
+async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Give an example JSON that fits this schema: 42",
+            extra_body=dict(guided_json=42))
+
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.chat.completions.create(model=MODEL_NAME,
+                                                 messages=messages,
+                                                 extra_body=dict(guided_regex={
+                                                     1: "Python",
+                                                     2: "C++"
+                                                 }))
+
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Give an example string that fits this regex",
+            extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
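Note: the official openai client's `extra_body` simply merges the given keys into the top level of the request JSON, so the guided_* parameters can equally be sent with a raw HTTP call. A minimal sketch (server address and model name are illustrative):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "HuggingFaceH4/zephyr-7b-beta",
            "prompt": "The best language for type-safe systems programming is ",
            "max_tokens": 10,
            "guided_choice": ["Python", "Rust", "C++"],
        })
    print(resp.json()["choices"][0]["text"])  # constrained to one of the choices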
vllm/engine/async_llm_engine.py
@@ -333,6 +333,9 @@ class AsyncLLMEngine:
         return (self.background_loop is not None
                 and not self.background_loop.done())

+    def get_tokenizer(self):
+        return self.engine.tokenizer.tokenizer
+
     def start_background_loop(self) -> None:
         """Start the background loop."""
         if self.is_running:
vllm/entrypoints/openai/protocol.py
@@ -3,7 +3,7 @@
 import time
 from typing import Dict, List, Literal, Optional, Union

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator

 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams
@@ -86,6 +86,9 @@ class ChatCompletionRequest(BaseModel):
     min_p: Optional[float] = 0.0
     include_stop_str_in_output: Optional[bool] = False
     length_penalty: Optional[float] = 1.0
+    guided_json: Optional[Union[str, dict, BaseModel]] = None
+    guided_regex: Optional[str] = None
+    guided_choice: Optional[List[str]] = None

     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
@@ -131,6 +134,20 @@ class ChatCompletionRequest(BaseModel):
             logits_processors=logits_processors,
         )

+    @model_validator(mode="before")
+    @classmethod
+    def check_guided_decoding_count(cls, data):
+        guide_count = sum([
+            "guided_json" in data and data["guided_json"] is not None,
+            "guided_regex" in data and data["guided_regex"] is not None,
+            "guided_choice" in data and data["guided_choice"] is not None
+        ])
+        if guide_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        return data
+
+
 class CompletionRequest(BaseModel):
     model: str
@@ -163,6 +180,9 @@ class CompletionRequest(BaseModel):
     min_p: Optional[float] = 0.0
     include_stop_str_in_output: Optional[bool] = False
     length_penalty: Optional[float] = 1.0
+    guided_json: Optional[Union[str, dict, BaseModel]] = None
+    guided_regex: Optional[str] = None
+    guided_choice: Optional[List[str]] = None

     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0
@@ -207,6 +227,20 @@ class CompletionRequest(BaseModel):
             logits_processors=logits_processors,
         )

+    @model_validator(mode="before")
+    @classmethod
+    def check_guided_decoding_count(cls, data):
+        guide_count = sum([
+            "guided_json" in data and data["guided_json"] is not None,
+            "guided_regex" in data and data["guided_regex"] is not None,
+            "guided_choice" in data and data["guided_choice"] is not None
+        ])
+        if guide_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        return data
+
+
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
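Note: because check_guided_decoding_count runs as a before-validator, a request that mixes guides fails at parse time rather than mid-generation. A minimal sketch of what that looks like from Python (assuming vllm is importable; field values are illustrative):

    from pydantic import ValidationError
    from vllm.entrypoints.openai.protocol import CompletionRequest

    try:
        CompletionRequest(model="m",
                          prompt="p",
                          guided_regex=r"\d+",
                          guided_json={"type": "object"})
    except ValidationError as err:
        print(err)  # wraps: "You can only use one kind of guided decoding ..."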
vllm/entrypoints/openai/serving_chat.py
@@ -12,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (
     UsageInfo)
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
+from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor

 logger = init_logger(__name__)

@@ -62,6 +63,14 @@ class OpenAIServingChat(OpenAIServing):
                 prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            guided_decode_logits_processor = (
+                await get_guided_decoding_logits_processor(
+                    request, self.engine.get_tokenizer()))
+            if guided_decode_logits_processor:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logits_processor)
         except ValueError as e:
             return self.create_error_response(str(e))

vllm/entrypoints/openai/serving_completion.py
@@ -16,6 +16,7 @@ from .protocol import (
 )
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
+from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor

 logger = init_logger(__name__)

@@ -286,6 +287,14 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            guided_decode_logit_processor = (
+                await get_guided_decoding_logits_processor(
+                    request, self.engine.get_tokenizer()))
+            if guided_decode_logit_processor is not None:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logit_processor)
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

             for i, prompt in enumerate(prompts):
vllm/model_executor/guided_decoding.py (new file, 99 lines)
@@ -0,0 +1,99 @@
+import asyncio
+import concurrent.futures
+from copy import copy
+from enum import Enum
+from functools import lru_cache
+from json import dumps as json_dumps
+from re import escape as regex_escape
+from typing import Union, Tuple
+from pydantic import BaseModel
+
+from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
+from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor
+
+
+class GuidedDecodingMode(Enum):
+    JSON = "json"
+    REGEX = "regex"
+    CHOICE = "choice"
+
+
+global_thread_pool = None  # used for generating logits processor fsm
+
+
+async def get_guided_decoding_logits_processor(
+        request: Union[CompletionRequest, ChatCompletionRequest],
+        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    We cache logit processors by (guide, tokenizer), and on cache hit
+    we make a shallow copy to reuse the same underlying FSM.
+    """
+    global global_thread_pool
+    guide, mode = _get_guide_and_mode(request)
+    if not guide:
+        return None
+
+    if global_thread_pool is None:
+        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
+            max_workers=2)
+    loop = asyncio.get_running_loop()
+
+    result = await loop.run_in_executor(global_thread_pool,
+                                        _get_cached_logits_processor, guide,
+                                        tokenizer, mode)
+
+    logits_processor = copy(result)
+    # reset logits processor's internal state
+    logits_processor.init_state()
+    return logits_processor
+
+
+def _get_guide_and_mode(
+    request: Union[CompletionRequest, ChatCompletionRequest]
+) -> Tuple[str, GuidedDecodingMode]:
+
+    if request.guided_json:
+        if not isinstance(request.guided_json, (str, dict, BaseModel)):
+            raise TypeError("JSON schema must be str, dict, or BaseModel")
+
+        json = request.guided_json
+        if isinstance(json, dict):
+            # turn dict into hashable string
+            json = json_dumps(json, sort_keys=True)
+        elif isinstance(json, BaseModel):
+            # use pydantic signature so that different model classes
+            # with the same fields will get hashed the same
+            json = str(json.__signature__)
+        return json, GuidedDecodingMode.JSON
+
+    elif request.guided_regex:
+        if not isinstance(request.guided_regex, str):
+            raise TypeError("Regex must be string")
+        return request.guided_regex, GuidedDecodingMode.REGEX
+
+    elif request.guided_choice:
+        if not isinstance(request.guided_choice, list):
+            raise TypeError("Choices must be a list")
+
+        # choice just uses regex
+        choices = [
+            regex_escape(str(choice)) for choice in request.guided_choice
+        ]
+        choices_regex = "(" + "|".join(choices) + ")"
+        return choices_regex, GuidedDecodingMode.CHOICE
+
+    else:
+        return None, None
+
+
+@lru_cache(maxsize=32)
+def _get_cached_logits_processor(guide: str, tokenizer,
+                                 mode: GuidedDecodingMode):
+    if mode == GuidedDecodingMode.JSON:
+        return JSONLogitsProcessor(guide, tokenizer)
+    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
+        return RegexLogitsProcessor(guide, tokenizer)
+    else:
+        raise ValueError(f"Unknown guided decoding mode {mode}")
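Note: `_get_cached_logits_processor` is wrapped in lru_cache, which is why `_get_guide_and_mode` canonicalizes dict schemas with json_dumps(..., sort_keys=True): the cache key must be hashable, and sorting keys lets structurally equal schemas share one compiled FSM. A standalone sketch of that behavior (`compile_guide` is a hypothetical stand-in for the expensive FSM build):

    from functools import lru_cache
    from json import dumps as json_dumps

    @lru_cache(maxsize=32)
    def compile_guide(guide: str) -> str:
        return f"compiled:{guide}"  # stands in for FSM compilation

    key1 = json_dumps({"b": 1, "a": 2}, sort_keys=True)
    key2 = json_dumps({"a": 2, "b": 1}, sort_keys=True)
    assert key1 == key2                                # canonical, hashable key
    assert compile_guide(key1) is compile_guide(key2)  # second call is a cache hit

The shallow copy(result) plus init_state() afterwards then gives each request its own fsm_state dict while sharing the compiled FSM.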
vllm/model_executor/guided_logits_processors.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+# Copyright 2024- the Outlines developers
+# This file is adapted from
+# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import math
+from collections import defaultdict
+from typing import Union, DefaultDict, Dict, List, Optional
+
+import torch
+from pydantic import BaseModel
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_schema
+
+
+class RegexLogitsProcessor:
+
+    def __init__(self, regex_string: str, tokenizer):
+        """Compile the FSM that drives the regex-structured generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            The model's tokenizer
+
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = RegexFSM(regex_string, tokenizer)
+        self.fsm = fsm
+
+    def init_state(self):
+        """Initialize the FSM states."""
+        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
+
+    def __call__(self, input_ids: List[int],
+                 scores: torch.Tensor) -> torch.Tensor:
+        """Use the FSM to bias the logits before sampling the next token."""
+
+        seq_id = hash(tuple(input_ids))
+
+        if len(input_ids) == 0:
+            self.init_state()
+        else:
+            last_token = input_ids[-1]
+            last_seq_id = hash(tuple(input_ids[:-1]))
+            self.fsm_state[seq_id] = self.fsm.next_state(
+                self.fsm_state[last_seq_id], last_token)
+
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+
+        mask = torch.full((scores.shape[-1], ),
+                          -math.inf,
+                          device=scores.device)
+        mask[allowed_tokens] = 0
+        scores.add_(mask)
+
+        return scores
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt vLLM's tokenizer to use to compile the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. In addition we need to handle the missing spaces to
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+
+class JSONLogitsProcessor(RegexLogitsProcessor):
+
+    def __init__(self,
+                 schema: Union[str, Dict, BaseModel],
+                 tokenizer,
+                 whitespace_pattern: Optional[str] = None):
+        """Compile the FSM that drives the JSON-guided generation.
+
+        Parameters
+        ----------
+        schema
+            A JSON schema that encodes the structure we want the model to generate
+        tokenizer
+            The model's tokenizer
+        whitespace_pattern
+            Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
+            Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
+
+        """
+        if isinstance(schema, type(BaseModel)):
+            schema_str = json.dumps(schema.model_json_schema())
+        elif isinstance(schema, Dict):
+            schema_str = json.dumps(schema)
+        elif isinstance(schema, str):
+            schema_str = schema
+        else:
+            raise ValueError(
+                f"Cannot parse schema {schema}. The schema must be either " +
+                "a Pydantic object, a dictionary or a string that contains the JSON "
+                + "Schema specification")
+        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
+        super().__init__(regex_string, tokenizer)
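Note: the core of `__call__` above is additive masking — every token the FSM disallows in the current state receives -inf before sampling, so its probability is exactly zero after softmax. A self-contained sketch with toy numbers (a 4-token vocabulary, not a real model):

    import math
    import torch

    scores = torch.tensor([1.0, 2.0, 3.0, 4.0])  # toy logits over a 4-token vocab
    allowed_tokens = [1, 3]  # stand-in for fsm.allowed_token_ids(state)

    mask = torch.full((scores.shape[-1], ), -math.inf)
    mask[allowed_tokens] = 0
    scores.add_(mask)  # in-place, mirroring the processor

    probs = torch.softmax(scores, dim=-1)
    assert probs[0] == 0 and probs[2] == 0  # disallowed tokens are unreachable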