[V1] Support bad_words in sampler (#13376)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent 9513290032 · commit eb8b5eb183
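In short, V1 now honors SamplingParams.bad_words instead of rejecting it in the processor. A minimal usage sketch of the feature this commit enables (model name and prompt are illustrative placeholders, not taken from the diff):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # hypothetical model choice
params = SamplingParams(
    temperature=0,
    bad_words=["Hello", "bad phrase"],  # phrases whose completing token is masked during sampling
)
output = llm.generate("Tell me a greeting:", params)
print(output[0].outputs[0].text)  # should not contain "Hello" or "bad phrase"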
@@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
                         PlaceholderModule, StoreBoolean, bind_kv_cache,
                         deprecate_kwargs, get_open_port, memory_profiling,
-                        merge_async_iterators, supports_kw)
+                        merge_async_iterators, supports_kw, swap_dict_values)

 from .utils import error_on_warning, fork_new_process_for_each_test

@@ -449,3 +449,26 @@ def test_placeholder_module_error_handling():
     with build_ctx():
         # Test conflict with internal __module attribute
         _ = placeholder_attr.module
+
+
+@pytest.mark.parametrize(
+    "obj,key1,key2",
+    [
+        # Tests for both keys exist
+        ({1: "a", 2: "b"}, 1, 2),
+        # Tests for one key does not exist
+        ({1: "a", 2: "b"}, 1, 3),
+        # Tests for both keys do not exist
+        ({1: "a", 2: "b"}, 3, 4),
+    ])
+def test_swap_dict_values(obj, key1, key2):
+    original_obj = obj.copy()
+    swap_dict_values(obj, key1, key2)
+    if key1 in original_obj:
+        assert obj[key2] == original_obj[key1]
+    else:
+        assert key2 not in obj
+    if key2 in original_obj:
+        assert obj[key1] == original_obj[key2]
+    else:
+        assert key1 not in obj
@@ -42,6 +42,7 @@ def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
         min_tokens={},
         logit_bias=[None] * batch_size,
         allowed_token_ids_mask=None,
+        bad_words_token_ids={},
     )

@@ -77,6 +77,49 @@ def _create_allowed_token_ids(
     return mask


+def _create_bad_words_token_ids(
+        batch_size: int, vocab_size: int,
+        bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
+    bad_words_token_ids = {}
+    for batch_idx in range(batch_size):
+        token_ids_single_batch = []
+        for bad_words_length in bad_words_lengths:
+            token_ids = np.random.choice(vocab_size,
+                                         size=bad_words_length,
+                                         replace=True).tolist()
+            token_ids_single_batch.append(token_ids)
+        bad_words_token_ids[batch_idx] = token_ids_single_batch
+    if batch_size >= 2:
+        # Test no bad_words for some batch
+        no_bad_words_batch_idx = np.random.choice(batch_size)
+        bad_words_token_ids.pop(no_bad_words_batch_idx, None)
+    return bad_words_token_ids
+
+
+def _update_output_token_ids_for_bad_words(
+        metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
+    bad_words_last_tokens = {}
+    for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
+        output_token_ids = metadata.output_token_ids[batch_idx]
+        bad_words_last_token: list[int] = []
+        for i, bad_word_token_ids in enumerate(bad_words_token_ids):
+            if len(bad_word_token_ids) == 1:
+                # Single token id always affects logits
+                bad_words_last_token.append(bad_word_token_ids[0])
+            else:
+                prefix_length = len(bad_word_token_ids) - 1
+                has_bad_words = np.random.choice([True, False])
+                if has_bad_words:
+                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
+                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    break  # Maximum one update to output_token_ids
+                else:  # Make sure no accidental match to bad words
+                    output_token_ids[-1] = (bad_word_token_ids[-2] +
+                                            1) % vocab_size
+        bad_words_last_tokens[batch_idx] = bad_words_last_token
+    return bad_words_last_tokens
+
+
 def _create_default_sampling_metadata(
         num_output_tokens: int,
         batch_size: int,
@@ -112,6 +155,7 @@ def _create_default_sampling_metadata(
         min_tokens={},
         logit_bias=[None] * batch_size,
         allowed_token_ids_mask=None,
+        bad_words_token_ids={},
     )
     return fake_sampling_metadata

@@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
                     "inf"), f"{batch_idx}, {token_id}"
             else:
                 assert logits_for_req[token_id] != -float("inf")
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
+def test_sampler_bad_words(device: str, batch_size: int,
+                           bad_words_lengths: list[tuple[int]]):
+    """
+    Test to verify that when the bad words restriction is present, tokens
+    are penalized based on their match with the bad words.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
+        batch_size, VOCAB_SIZE, bad_words_lengths)
+    bad_words_last_tokens = _update_output_token_ids_for_bad_words(
+        sampling_metadata, VOCAB_SIZE)
+    sampler = Sampler()
+    logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        for token_id in range(VOCAB_SIZE):
+            if (batch_idx in bad_words_last_tokens
+                    and token_id in bad_words_last_tokens[batch_idx]):
+                assert logits_for_req[token_id] == -float("inf")
+            else:
+                assert logits_for_req[token_id] != -float("inf")
@@ -120,8 +120,22 @@ def test_detokenize_false(model):
 def test_bad_words(model):
     """Check that we respect bad words."""

-    with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))
+    output = model.generate(PROMPT, SamplingParams(temperature=0))
+    split_text = output[0].outputs[0].text.split()
+
+    bad_words_1 = " ".join(split_text[:2])
+    params = SamplingParams(temperature=0, bad_words=[bad_words_1])
+    output = model.generate(PROMPT, params)
+    new_text = output[0].outputs[0].text
+    assert bad_words_1 not in new_text
+
+    bad_words_2 = new_text.split()[-1]
+    params = SamplingParams(temperature=0,
+                            bad_words=[bad_words_1, bad_words_2])
+    output = model.generate(PROMPT, params)
+    new_text = output[0].outputs[0].text
+    assert bad_words_1 not in new_text
+    assert bad_words_2 not in new_text


 def test_logits_processor(model):
@@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata(
                                         VOCAB_SIZE,
                                         dtype=torch.bool,
                                         device=device)
+    bad_words_token_ids = {}
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -123,6 +124,8 @@ def _construct_expected_sampling_metadata(
         if req.sampling_params.allowed_token_ids:
             allowed_token_ids_mask[index_in_input_batch][
                 req.sampling_params.allowed_token_ids] = True
+        bad_words_token_ids[
+            index_in_input_batch] = req.sampling_params.bad_words_token_ids

     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
@@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata(
                         and all(x == 1 for x in repetition_penalties)),
         logit_bias=logit_bias,
         allowed_token_ids_mask=allowed_token_ids_mask,
+        bad_words_token_ids=bad_words_token_ids,
     )


@@ -284,6 +288,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         assert torch.allclose(
             expected_sampling_metadata.allowed_token_ids_mask,
             sampling_metadata.allowed_token_ids_mask)
+    assert expected_sampling_metadata.bad_words_token_ids == \
+        sampling_metadata.bad_words_token_ids


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -11,6 +11,8 @@ from pydantic import BaseModel

 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

 logger = init_logger(__name__)

@@ -202,7 +204,6 @@ class SamplingParams(
     seed: Optional[int] = None
     stop: Optional[Union[str, list[str]]] = None
     stop_token_ids: Optional[list[int]] = None
-    bad_words: Optional[list[str]] = None
     ignore_eos: bool = False
     max_tokens: Optional[int] = 16
     min_tokens: int = 0
@@ -232,6 +233,10 @@ class SamplingParams(
     allowed_token_ids: Optional[list[int]] = None
     extra_args: Optional[dict[str, Any]] = None

+    # Fields used for bad words
+    bad_words: Optional[list[str]] = None
+    _bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
@@ -464,6 +469,46 @@ class SamplingParams(
             eos_ids.update(self.stop_token_ids)
             self.stop_token_ids = list(eos_ids)

+    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
+        if self.bad_words is None:
+            return
+        for bad_word in self.bad_words:
+            # To prohibit words both at the beginning
+            # and in the middle of text
+            # (related to add_prefix_space tokenizer parameter)
+            for add_prefix_space in [False, True]:
+                prefix = " " if add_prefix_space else ""
+                prompt = prefix + bad_word.lstrip()
+
+                if isinstance(tokenizer, MistralTokenizer):
+                    # Mistral tokenizers should not add special tokens
+                    prompt_token_ids = tokenizer.encode(text=prompt)
+                else:
+                    prompt_token_ids = tokenizer.encode(
+                        text=prompt, add_special_tokens=False)
+
+                # If no space at the beginning
+                # or if prefix space produces a new word token
+                if (not add_prefix_space) or (
+                        add_prefix_space and prompt_token_ids[0]
+                        != self._bad_words_token_ids[-1][0]
+                        and len(prompt_token_ids) == len(
+                            self._bad_words_token_ids[-1])):
+                    self._bad_words_token_ids.append(prompt_token_ids)
+
+        invalid_token_ids = [
+            token_id for bad_words_token_ids in self._bad_words_token_ids
+            for token_id in bad_words_token_ids
+            if token_id < 0 or token_id > tokenizer.max_token_id
+        ]
+        if len(invalid_token_ids) > 0:
+            raise ValueError(
+                f"The model vocabulary size is {tokenizer.max_token_id+1},"
+                f" but the following tokens"
+                f" were specified as bad: {invalid_token_ids}."
+                f" All token id values should be integers satisfying:"
+                f" 0 <= token_id <= {tokenizer.max_token_id}.")
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.temperature < _SAMPLING_EPS:
@@ -476,6 +521,11 @@ class SamplingParams(
     def all_stop_token_ids(self) -> set[int]:
         return self._all_stop_token_ids

+    @property
+    def bad_words_token_ids(self) -> list[list[int]]:
+        # For internal use only. Backward compatibility not guaranteed
+        return self._bad_words_token_ids
+
     def clone(self) -> "SamplingParams":
         """Deep copy, but maybe not the LogitsProcessor objects.

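As an illustration of what update_from_tokenizer produces, the sketch below runs the same step the V1 Processor performs. It assumes vLLM's get_tokenizer helper, whose cached tokenizer exposes max_token_id (an assumption about the surrounding codebase, not shown in this diff); the printed ids are tokenizer-specific.

from vllm import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")  # hypothetical model choice
params = SamplingParams(temperature=0, bad_words=["bad phrase"])
params.update_from_tokenizer(tokenizer)

# Each bad word is encoded both without and with a leading space, so it is
# blocked at the start of the text and mid-sentence alike.
print(params.bad_words_token_ids)  # e.g. one id sequence per accepted variant of "bad phrase"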
@@ -2361,3 +2361,19 @@ class LazyLoader(types.ModuleType):
         if self._module is None:
             self._module = self._load()
         return dir(self._module)
+
+
+def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
+    """
+    Helper function to swap values for two keys
+    """
+    v1 = obj.get(key1)
+    v2 = obj.get(key2)
+    if v1 is not None:
+        obj[key2] = v1
+    else:
+        obj.pop(key2, None)
+    if v2 is not None:
+        obj[key1] = v2
+    else:
+        obj.pop(key1, None)
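A quick sketch of the new helper's semantics (the dictionaries are arbitrary examples): whichever of the two values exist are moved to the opposite key, and a key whose counterpart is missing is dropped.

from vllm.utils import swap_dict_values

d = {1: "a", 2: "b"}
swap_dict_values(d, 1, 2)
assert d == {1: "b", 2: "a"}   # both keys present: values swapped

d = {1: "a"}
swap_dict_values(d, 1, 3)      # key 3 is absent
assert d == {3: "a"}           # value moved to key 3; key 1 removed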
@@ -94,9 +94,6 @@ class Processor:
         # Best of not yet supported.
         if params.best_of is not None and params.best_of > 1:
             raise ValueError("VLLM V1 does not yet support best_of.")
-        # Bad words not yet supported.
-        if params.bad_words:
-            raise ValueError("VLLM V1 does not yet support bad_words.")
         # Logits processors not supported.
         if params.logits_processors:
             raise ValueError("VLLM V1 does not support per request "
@@ -203,6 +200,8 @@ class Processor:
         sampling_params = params.clone()
         sampling_params.update_from_generation_config(
             self.generation_config_fields, eos_token_id)
+        sampling_params.update_from_tokenizer(
+            self.tokenizer.get_lora_tokenizer(lora_request))

         # Multimodal related.
         # Compute MM hashes (if enabled)
@@ -38,3 +38,6 @@ class SamplingMetadata:
     # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
     # vocab size).
     allowed_token_ids_mask: Optional[torch.Tensor]
+
+    # req_index -> bad_words_token_ids
+    bad_words_token_ids: dict[int, list[list[int]]]
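For reference, the mapping this field carries looks roughly like the following (request indices and token ids are made up for illustration):

# req_index -> list of bad-word token-id sequences for that request
bad_words_token_ids = {
    0: [[1234], [56, 78, 90]],  # request 0: one single-token and one three-token bad word
    2: [[11, 22]],              # request 2: one two-token bad word; request 1 has none
}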
vllm/v1/sample/ops/bad_words.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+_SMALLEST_LOGIT = float("-inf")
+
+
+def _apply_bad_words_single_batch(
+    logits: torch.Tensor,
+    bad_words_token_ids: list[list[int]],
+    past_tokens_ids: list[int],
+) -> None:
+    for bad_word_ids in bad_words_token_ids:
+        if len(bad_word_ids) > len(past_tokens_ids) + 1:
+            continue
+
+        prefix_length = len(bad_word_ids) - 1
+        last_token_id = bad_word_ids[-1]
+        if prefix_length > 0:
+            actual_prefix = past_tokens_ids[-prefix_length:]
+        else:
+            actual_prefix = []
+        expected_prefix = bad_word_ids[:prefix_length]
+
+        assert len(actual_prefix) == len(expected_prefix)
+
+        if actual_prefix == expected_prefix:
+            logits[last_token_id] = _SMALLEST_LOGIT
+
+
+def apply_bad_words(
+    logits: torch.Tensor,
+    bad_words_token_ids: dict[int, list[list[int]]],
+    past_tokens_ids: list[list[int]],
+) -> None:
+    for i, bad_words_ids in bad_words_token_ids.items():
+        _apply_bad_words_single_batch(logits[i], bad_words_ids,
+                                      past_tokens_ids[i])
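To make the masking behaviour concrete, here is a small self-contained sketch that exercises the new op directly (batch size, vocabulary size, and token ids are arbitrary):

import torch

from vllm.v1.sample.ops.bad_words import apply_bad_words

vocab_size = 8
logits = torch.zeros(2, vocab_size)

# Request 0 bans single token 3 and the sequence [5, 6]; request 1 has no bad words.
bad_words_token_ids = {0: [[3], [5, 6]]}
past_tokens = [[1, 5], [2, 4]]  # request 0 just generated token 5, the prefix of [5, 6]

apply_bad_words(logits, bad_words_token_ids, past_tokens)

assert logits[0, 3] == float("-inf")    # a single-token bad word is always masked
assert logits[0, 6] == float("-inf")    # prefix [5] matched, so the final token 6 is masked
assert torch.isfinite(logits[1]).all()  # request without bad words keeps finite logits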
@@ -6,6 +6,7 @@ import torch.nn as nn

 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.bad_words import apply_bad_words
 from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                           apply_min_token_penalties)
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
@@ -38,6 +39,8 @@ class Sampler(nn.Module):
         logits = logits.to(torch.float32)
         # Apply allowed token ids.
         logits = self.apply_allowed_token_ids(logits, sampling_metadata)
+        # Apply bad words exclusion.
+        logits = self.apply_bad_words(logits, sampling_metadata)
         # Apply logits bias.
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
@@ -237,3 +240,16 @@ class Sampler(nn.Module):
             logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                 float("-inf"))
         return logits
+
+    def apply_bad_words(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        if sampling_metadata.bad_words_token_ids:
+            apply_bad_words(
+                logits,
+                sampling_metadata.bad_words_token_ids,
+                sampling_metadata.output_token_ids,
+            )
+        return logits
@@ -10,6 +10,7 @@ import torch
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.utils import swap_dict_values
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import BlockTable
@@ -204,6 +205,9 @@ class InputBatch:
         self.allowed_token_ids_mask: Optional[torch.Tensor] = None
         self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

+        # req_index -> bad_words_token_ids
+        self.bad_words_token_ids: dict[int, list[list[int]]] = {}
+
         self.req_output_token_ids: list[Optional[list[int]]] = []

         # This is updated each time the batch constituents change.
@@ -320,6 +324,9 @@ class InputBatch:
             self.allowed_token_ids_mask_cpu_tensor[req_index][
                 sampling_params.allowed_token_ids] = False

+        self.bad_words_token_ids[
+            req_index] = sampling_params.bad_words_token_ids
+
         # Add request lora ID
         if request.lora_request:
             lora_id = request.lora_request.lora_int_id
@@ -369,6 +376,7 @@ class InputBatch:
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             # False means we don't fill with -inf.
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
+        self.bad_words_token_ids.pop(req_index, None)
         return req_index

     def swap_states(self, i1: int, i2: int) -> None:
@@ -413,27 +421,9 @@ class InputBatch:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp

-        g1 = self.generators.get(i1)
-        g2 = self.generators.get(i2)
-        if g1 is not None:
-            self.generators[i2] = g1
-        else:
-            self.generators.pop(i2, None)
-        if g2 is not None:
-            self.generators[i1] = g2
-        else:
-            self.generators.pop(i1, None)
-
-        t1 = self.min_tokens.get(i1)
-        t2 = self.min_tokens.get(i2)
-        if t1 is not None:
-            self.min_tokens[i2] = t1
-        else:
-            self.min_tokens.pop(i2, None)
-        if t2 is not None:
-            self.min_tokens[i1] = t2
-        else:
-            self.min_tokens.pop(i1, None)
+        swap_dict_values(self.generators, i1, i2)
+        swap_dict_values(self.min_tokens, i1, i2)
+        swap_dict_values(self.bad_words_token_ids, i1, i2)

         self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
             self.request_lora_mapping[i2], self.request_lora_mapping[i1]
@@ -518,6 +508,10 @@ class InputBatch:
                     empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                         last_req_index]

+            bad_words_token_ids = self.bad_words_token_ids.pop(
+                last_req_index, None)
+            if bad_words_token_ids is not None:
+                self.bad_words_token_ids[empty_index] = bad_words_token_ids
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1

@@ -585,6 +579,7 @@ class InputBatch:
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
             allowed_token_ids_mask=allowed_token_ids_mask,
+            bad_words_token_ids=self.bad_words_token_ids,
         )

     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
@@ -1268,6 +1268,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             min_tokens={},
             logit_bias=[None for _ in range(num_reqs)],
             allowed_token_ids_mask=None,
+            bad_words_token_ids={},
         )
         sampler_output = self.model.sample(logits=logits,
                                            sampling_metadata=dummy_metadata)