vllm/vllm/logits_process.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to Python source files, as
    recommended to the project by the Linux Foundation. These headers provide
    a concise, human- and machine-readable way to communicate license
    information for each source file. They help avoid any ambiguity about the
    license of the code and can also be easily used by tools to help manage
    license compliance.
    
    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool do
    its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

122 lines
4.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Tuple, Union
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
LogitsProcessor = Union[
    # (generated_token_ids, logits) -> logits
    Callable[[List[int], torch.Tensor], torch.Tensor],
    # (prompt_token_ids, generated_token_ids, logits) -> logits
    Callable[[List[int], List[int], torch.Tensor], torch.Tensor],
]
"""Signature accepted for a logits processor.

A processor receives the token ids generated so far and the logits tensor
for the next position, and returns the (possibly modified) logits tensor to
sample from. The three-argument variant additionally receives the prompt
token ids as its first argument."""
def get_bad_words_logits_processors(
        bad_words: List[str],
        tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
    """Build logits processors that forbid generating any of ``bad_words``.

    Each word has its leading whitespace stripped. It is then tokenized
    twice: without a leading space and with one. This prohibits the word
    both at the beginning and in the middle of text, which matters for
    tokenizers whose behavior depends on ``add_prefix_space``.

    The unprefixed tokenization is always banned. The space-prefixed one is
    banned only when it produces a different first token and the same
    number of tokens as the unprefixed one.

    Words that are empty after stripping, or that encode to no tokens, are
    skipped. There is nothing to ban for them, and an empty token sequence
    would otherwise cause an IndexError.

    Args:
        bad_words: Words to forbid.
        tokenizer: Tokenizer used to turn each word into token ids.

    Returns:
        A single-element list containing a ``NoBadWordsLogitsProcessor``.
    """

    def _encode(text: str) -> List[int]:
        # Encode ``text`` without special tokens, dispatching on tokenizer type.
        if isinstance(tokenizer, MistralTokenizer):
            # Mistral tokenizers should not add special tokens
            return tokenizer.encode(prompt=text)
        return tokenizer.encode(text=text, add_special_tokens=False)

    bad_words_ids: List[List[int]] = list()

    for bad_word in bad_words:
        word = bad_word.lstrip()
        if not word:
            # Empty / whitespace-only entry: nothing to ban.
            continue

        plain_ids = _encode(word)
        if not plain_ids:
            continue
        bad_words_ids.append(plain_ids)

        # Only keep the space-prefixed variant if the prefix space produces
        # a new word token (different first token, same token count).
        spaced_ids = _encode(" " + word)
        if (spaced_ids and spaced_ids[0] != plain_ids[0]
                and len(spaced_ids) == len(plain_ids)):
            bad_words_ids.append(spaced_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]
class NoBadWordsLogitsProcessor:
    """Logits processor that forbids generating given token-id sequences.

    Single-token bad words are banned unconditionally through a bias vector
    that is built on the first call. For a multi-token bad word, only its
    last token is banned, and only when the most recently generated tokens
    equal the word's preceding tokens. Empty sequences are ignored.
    """

    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: List[List[int]]):
        # Token-id sequences to ban; each inner list is one bad word.
        self.bad_words_ids = bad_words_ids
        # Bias for single-token bad words. It is built lazily because the
        # vocab size and device are only known once logits are seen.
        self.word_bias: torch.FloatTensor = None

    def __call__(
        self,
        past_tokens_ids: Union[List[int], Tuple[int]],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        """Return a new logits tensor with banned tokens set to -inf.

        Args:
            past_tokens_ids: Previously generated token ids, oldest first.
            logits: Next-token logits; the last dimension is the vocabulary.

        Returns:
            ``logits`` plus the banning bias. The input tensor is not
            modified.

        Raises:
            ValueError: On the first call, if any bad token id lies outside
                ``[0, vocab_size)``.
        """
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        # Last tokens of multi-token bad words whose prefix was just generated.
        banned_last_tokens = [
            bad_word_ids[-1] for bad_word_ids in self.bad_words_ids
            if self._prefix_matches(bad_word_ids, past_tokens_ids)
        ]

        result = logits + self.word_bias
        # Only pay for a vocab-sized bias tensor when something matched.
        if banned_last_tokens:
            last_token_bias = torch.full_like(logits, self._NEUTRAL_LOGIT)
            last_token_bias[banned_last_tokens] = self._SMALLEST_LOGIT
            result = result + last_token_bias
        return result

    @staticmethod
    def _prefix_matches(
        bad_word_ids: List[int],
        past_tokens_ids: Union[List[int], Tuple[int, ...]],
    ) -> bool:
        """Whether the past tokens end with all but the last bad-word token."""
        # Empty words have nothing to ban. Single-token words are handled
        # by ``word_bias``.
        if len(bad_word_ids) <= 1:
            return False
        prefix_length = len(bad_word_ids) - 1
        if prefix_length > len(past_tokens_ids):
            return False
        actual_prefix = past_tokens_ids[-prefix_length:]
        return tuple(actual_prefix) == tuple(bad_word_ids[:prefix_length])

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        """Validate token ids and build the single-token-word bias vector."""
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

        vocab_size = logits.shape[-1]
        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros((vocab_size, ),
                                     dtype=torch.float,
                                     device=logits.device)

        single_token_ids = [
            bad_word_ids[-1] for bad_word_ids in self.bad_words_ids
            if len(bad_word_ids) == 1
        ]
        if single_token_ids:
            self.word_bias[single_token_ids] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        """Raise ValueError if any bad token id is outside [0, vocab_size)."""
        invalid_token_ids = [
            token_id for bad_word_ids in self.bad_words_ids
            for token_id in bad_word_ids
            if token_id < 0 or token_id >= vocab_size
        ]

        if invalid_token_ids:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                " but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                " All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}.")