Support logit bias for OpenAI API (#3027)

This commit is contained in:
Dylan Hawk 2024-02-26 19:51:53 -08:00 committed by GitHub
parent 4bd18ec0c7
commit e0ade06d63
4 changed files with 83 additions and 12 deletions
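For context, a minimal usage sketch (not part of the diff) of what this commit enables: passing logit_bias through the official OpenAI client to a vLLM OpenAI-compatible server. The base URL, API key, and model name below are placeholders for illustration and assume an openai>=1.0 client.

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Strongly favor token id 1000: with bias +100 and temperature 0,
# the server should keep emitting that token.
completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": 100},
)
print(completion.choices[0].text)

# A bias of -100 effectively bans the token instead.
completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": -100},
)
print(completion.choices[0].text)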

View File

@@ -9,6 +9,8 @@ import ray # using Ray for overall ease of process management, parallel request
import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests
from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for up to 600 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
@@ -310,5 +312,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
assert texts[0] == texts[1]
async def test_logits_bias(server, client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 5
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Test exclusive selection
token_id = 1000
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token_id): 100},
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
response_tokens = tokenizer(completion.choices[0].text,
add_special_tokens=False)["input_ids"]
expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
add_special_tokens=False)["input_ids"]
assert all([
response == expected
for response, expected in zip(response_tokens, expected_tokens)
])
# Test ban
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
)
response_tokens = tokenizer(completion.choices[0].text,
add_special_tokens=False)["input_ids"]
first_response = completion.choices[0].text
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token): -100
for token in response_tokens},
)
assert first_response != completion.choices[0].text
if __name__ == "__main__":
pytest.main([__file__])
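The tests above exercise logit_bias through the HTTP server; the request-schema change further down simply turns the request's logit_bias into a logits processor on SamplingParams. A minimal offline sketch of that same mechanism, assuming vLLM's Python API; the model name and token id are placeholders:

import torch
from typing import List
from vllm import LLM, SamplingParams

def bias_token_1000(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    # Same shape of processor the server builds from logit_bias:
    # add a bias directly to the chosen token's logit.
    logits[1000] += 100
    return logits

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=5, temperature=0.0,
                   logits_processors=[bias_token_1000]),
)
print(outputs[0].outputs[0].text)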

View File

@@ -8,6 +8,8 @@ from pydantic import BaseModel, Field
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
import torch
class ErrorResponse(BaseModel):
object: str = "error"
@@ -88,6 +90,21 @@ class ChatCompletionRequest(BaseModel):
def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
@@ -111,6 +128,7 @@ class ChatCompletionRequest(BaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
)
@@ -149,6 +167,20 @@ class CompletionRequest(BaseModel):
def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
best_of=self.best_of,
@@ -172,6 +204,7 @@ class CompletionRequest(BaseModel):
spaces_between_special_tokens=(self.spaces_between_special_tokens),
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
)
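A toy illustration (not from the commit) of the clamping rule used in the processors above: each per-token bias is clipped to the OpenAI range of -100 to 100 before being added to that token's logit.

import torch

logit_bias = {"3": 250.0, "5": -250.0}
logits = torch.zeros(8)  # pretend vocabulary of 8 tokens
for token_id, bias in logit_bias.items():
    bias = min(100, max(-100, bias))  # clamp per OpenAI API spec
    logits[int(token_id)] += bias
print(logits)  # token 3 ends up at +100, token 5 at -100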

View File

@@ -39,19 +39,13 @@ class OpenAIServingChat(OpenAIServing):
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return self.create_error_response(
"logit_bias is not currently supported")
try:
prompt = self.tokenizer.apply_chat_template(
conversation=request.messages,
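Since the error check is removed from the chat handler as well, logit_bias is also accepted on the chat endpoint. A brief usage sketch (not part of the diff), again with placeholder server URL, API key, and model name:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Say hello"}],
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": -100},  # discourage token id 1000
)
print(chat.choices[0].message.content)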

View File

@@ -264,10 +264,9 @@ class OpenAIServingCompletion(OpenAIServing):
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
@@ -277,9 +276,6 @@ class OpenAIServingCompletion(OpenAIServing):
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
return self.create_error_response(
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"