[Frontend] New allowed_token_ids decoding request parameter (#6753)
This commit is contained in:
commit 9f69d8245a (parent 9a7e2d0534)
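The new parameter is exposed through the OpenAI-compatible completions API as a vLLM-specific extension, as exercised by the test below. A minimal client-side sketch, assuming a vLLM OpenAI server is already running locally; the base_url, api_key and model name are illustrative placeholders, not part of this change:

# Minimal sketch: restrict the next sampled token to an explicit id set.
# base_url, api_key and the model name are placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="facebook/opt-125m",
    prompt="Hello, my name is",
    max_tokens=1,
    temperature=0.0,
    # vLLM extension: only these token ids may be generated
    extra_body=dict(allowed_token_ids=[21555, 21557, 21558]),
)
print(completion.choices[0].text)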
@@ -541,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text


+@pytest.mark.asyncio
+async def test_allowed_token_ids(client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 1
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    allowed_ids = [21555, 21557, 21558]
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        seed=42,
+        extra_body=dict(allowed_token_ids=allowed_ids),
+        logprobs=1,
+    )
+    response_tokens = completion.choices[0].logprobs.tokens
+    assert len(response_tokens) == 1
+    assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
vllm/entrypoints/openai/logits_processors.py  (new file, 74 lines added)
@@ -0,0 +1,74 @@
+from functools import lru_cache
+from typing import Dict, FrozenSet, Iterable, List, Optional, Union
+
+import torch
+from transformers import PreTrainedTokenizer
+
+from vllm.sampling_params import LogitsProcessor
+
+
+class AllowedTokenIdsLogitsProcessor:
+    """Logits processor for constraining generated tokens to a
+    specific set of token ids."""
+
+    def __init__(self, allowed_ids: Iterable[int]):
+        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
+        self.mask: Optional[torch.Tensor] = None
+
+    def __call__(self, token_ids: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        if self.mask is None:
+            self.mask = torch.ones((logits.shape[-1], ),
+                                   dtype=torch.bool,
+                                   device=logits.device)
+            self.mask[self.allowed_ids] = False
+            self.allowed_ids = None
+        logits.masked_fill_(self.mask, float("-inf"))
+        return logits
+
+
+@lru_cache(maxsize=32)
+def _get_allowed_token_ids_logits_processor(
+    allowed_token_ids: FrozenSet[int],
+    vocab_size: int,
+) -> LogitsProcessor:
+    if not allowed_token_ids:
+        raise ValueError("Empty allowed_token_ids provided")
+    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
+        raise ValueError("allowed_token_ids contains "
+                         "out-of-vocab token id")
+    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
+
+
+def get_logits_processors(
+        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
+        allowed_token_ids: Optional[List[int]],
+        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
+    logits_processors = []
+    if logit_bias:
+        try:
+            # Convert token_id to integer
+            # Clamp the bias between -100 and 100 per OpenAI API spec
+            clamped_logit_bias: Dict[int, float] = {
+                int(token_id): min(100.0, max(-100.0, bias))
+                for token_id, bias in logit_bias.items()
+            }
+        except ValueError as exc:
+            raise ValueError(
+                "Found token_id in logit_bias that is not "
+                "an integer or string representing an integer") from exc
+
+        def logit_bias_logits_processor(token_ids: List[int],
+                                        logits: torch.Tensor) -> torch.Tensor:
+            for token_id, bias in clamped_logit_bias.items():
+                logits[token_id] += bias
+            return logits
+
+        logits_processors.append(logit_bias_logits_processor)
+
+    if allowed_token_ids is not None:
+        logits_processors.append(
+            _get_allowed_token_ids_logits_processor(
+                frozenset(allowed_token_ids), tokenizer.vocab_size))
+
+    return logits_processors
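A small standalone sketch of how the new processor behaves, using a made-up six-token vocabulary (the tensor values are illustrative only): every logit outside the allowed set is pushed to -inf, so only allowed ids can be sampled.

# Toy illustration of AllowedTokenIdsLogitsProcessor on a fake 6-token vocab.
import torch

from vllm.entrypoints.openai.logits_processors import (
    AllowedTokenIdsLogitsProcessor)

processor = AllowedTokenIdsLogitsProcessor([1, 3])
logits = torch.zeros(6)  # pretend vocab_size == 6
masked = processor(token_ids=[], logits=logits)
print(masked)  # tensor([-inf, 0., -inf, 0., -inf, -inf])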
@@ -5,9 +5,11 @@ from typing import Any, Dict, List, Literal, Optional, Union

 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
+from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated

 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
@@ -213,30 +215,15 @@ class ChatCompletionRequest(OpenAIBaseModel):

     # doc: end-chat-completion-extra-params

-    def to_sampling_params(self) -> SamplingParams:
+    def to_sampling_params(self,
+                           tokenizer: PreTrainedTokenizer) -> SamplingParams:
         # We now allow logprobs being true without top_logrobs.

-        logits_processors = None
-        if self.logit_bias:
-            logit_bias: Dict[int, float] = {}
-            try:
-                for token_id, bias in self.logit_bias.items():
-                    # Convert token_id to integer before we add to LLMEngine
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    logit_bias[int(token_id)] = min(100, max(-100, bias))
-            except ValueError as exc:
-                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
-                                 f"but token_id must be an integer or string "
-                                 f"representing an integer") from exc
-
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in logit_bias.items():
-                    logits[token_id] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=None,
+            tokenizer=tokenizer,
+        )

         return SamplingParams(
             n=self.n,
@@ -358,6 +345,7 @@ class CompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    allowed_token_ids: Optional[List[int]] = None
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params
@@ -407,30 +395,14 @@ class CompletionRequest(OpenAIBaseModel):

     # doc: end-completion-extra-params

-    def to_sampling_params(self):
+    def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
         echo_without_generation = self.echo and self.max_tokens == 0

-        logits_processors = None
-        if self.logit_bias:
-            logit_bias: Dict[int, float] = {}
-            try:
-                for token_id, bias in self.logit_bias.items():
-                    # Convert token_id to integer
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    logit_bias[int(token_id)] = min(100, max(-100, bias))
-            except ValueError as exc:
-                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
-                                 f"but token_id must be an integer or string "
-                                 f"representing an integer") from exc
-
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in logit_bias.items():
-                    logits[token_id] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=self.allowed_token_ids,
+            tokenizer=tokenizer,
+        )

         return SamplingParams(
             n=self.n,
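Both request classes now delegate to the shared get_logits_processors helper rather than rebuilding the logit_bias closure inline. A rough sketch of the equivalent standalone call; the tokenizer, model name, and concrete values below are illustrative assumptions, not part of this change:

# Sketch: building the same processors outside the server path.
from transformers import AutoTokenizer

from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.sampling_params import SamplingParams

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder
processors = get_logits_processors(
    logit_bias={"50256": -100},              # OpenAI-style string keys are accepted
    allowed_token_ids=[21555, 21557, 21558],
    tokenizer=tokenizer,
)
params = SamplingParams(max_tokens=1,
                        temperature=0.0,
                        logits_processors=processors)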
@@ -134,7 +134,7 @@ class OpenAIServingChat(OpenAIServing):

         request_id = f"chat-{random_uuid()}"
         try:
-            sampling_params = request.to_sampling_params()
+            sampling_params = request.to_sampling_params(tokenizer)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
@@ -95,7 +95,7 @@ class OpenAIServingCompletion(OpenAIServing):

             tokenizer = await self.engine.get_tokenizer(lora_request)

-            sampling_params = request.to_sampling_params()
+            sampling_params = request.to_sampling_params(tokenizer)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend