mirror of https://git.datalinker.icu/vllm-project/vllm.git
Support logit bias for OpenAI API (#3027)
This commit is contained in:
parent 4bd18ec0c7
commit e0ade06d63
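For context, a minimal client-side sketch of what this commit enables (not part of the diff): once logit_bias is forwarded to the sampler, the official openai client can pass it straight to a vLLM OpenAI-compatible server. The base_url/port and api_key below are assumptions for a locally started vllm.entrypoints.openai.api_server instance; the model name matches MODEL_NAME from the test file changed below.

import openai

# assumed local server address and dummy api key
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    # keys are token ids as strings, values are biases in [-100, 100]
    # per the OpenAI API spec; +100 strongly favors token id 1000
    logit_bias={"1000": 100},
)
print(completion.choices[0].text)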
@@ -9,6 +9,8 @@ import ray  # using Ray for overall ease of process management, parallel requests
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests

+from vllm.transformers_utils.tokenizer import get_tokenizer
+
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
@@ -310,5 +312,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
     assert texts[0] == texts[1]


+async def test_logits_bias(server, client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 5
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    token_id = 1000
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token_id): 100},
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
+                                add_special_tokens=False)["input_ids"]
+    assert all([
+        response == expected
+        for response, expected in zip(response_tokens, expected_tokens)
+    ])
+
+    # Test ban
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+    )
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    first_response = completion.choices[0].text
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token): -100
+                    for token in response_tokens},
+    )
+    assert first_response != completion.choices[0].text
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
@@ -8,6 +8,8 @@ from pydantic import BaseModel, Field
 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams

+import torch
+

 class ErrorResponse(BaseModel):
     object: str = "error"
@@ -88,6 +90,21 @@ class ChatCompletionRequest(BaseModel):
     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
             raise ValueError("Top logprobs must be set when logprobs is.")
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
@@ -111,6 +128,7 @@ class ChatCompletionRequest(BaseModel):
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )

@@ -149,6 +167,20 @@ class CompletionRequest(BaseModel):
     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0

+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             best_of=self.best_of,
@@ -172,6 +204,7 @@ class CompletionRequest(BaseModel):
             spaces_between_special_tokens=(self.spaces_between_special_tokens),
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )

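The two protocol hunks above wire a per-request logits processor into SamplingParams. Below is a standalone sketch (not part of the diff) of the contract those processors follow: a callable that receives the tokens generated so far and the raw next-token logits, and returns the (possibly modified) logits. The token ids, biases, and vocabulary size are made up for illustration; the clamping mirrors the code added above.

from typing import List

import torch

# hypothetical request payload: token id (as string) -> bias in [-100, 100]
logit_bias = {"42": 100, "7": -100}


def logit_bias_logits_processor(token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    for token_id, bias in logit_bias.items():
        bias = min(100, max(-100, bias))  # clamp per the OpenAI API spec
        logits[int(token_id)] += bias
    return logits


logits = torch.zeros(32000)  # one score per vocabulary entry
logits = logit_bias_logits_processor([], logits)
assert logits[42] == 100 and logits[7] == -100

Once to_sampling_params returns SamplingParams(logits_processors=[...]), the engine is expected to invoke each processor on the logits of every decoding step before sampling, which is what makes the bias effective for the whole generation.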
@@ -39,19 +39,13 @@ class OpenAIServingChat(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.

-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret

-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
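With the rejection above removed, logit_bias on the chat endpoint is now forwarded to the sampler instead of returning an error. A hedged client sketch (base_url, api_key, and the bias target are assumptions for a local vLLM server):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Say something."}],
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": 100},  # push sampling toward token id 1000
)
print(chat.choices[0].message.content)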
@@ -264,10 +264,9 @@ class OpenAIServingCompletion(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.

-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - suffix (the language models we currently support do not support
               suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -277,9 +276,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")

         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
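With the completion-side restriction removed as well, a practical question is how to build the logit_bias mapping from real token ids. A small sketch (an assumption, not from the commit) using the same get_tokenizer helper as the test above, here to discourage a phrase; pass the resulting dict as logit_bias to client.completions.create(...) as in the earlier example:

from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name="HuggingFaceH4/zephyr-7b-beta")
# token ids for an illustrative phrase we want to ban
banned_ids = tokenizer(" World", add_special_tokens=False)["input_ids"]
logit_bias = {str(token_id): -100 for token_id in banned_ids}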