[V1] Adding min tokens/repetition/presence/frequency penalties to V1 sampler (#10681)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
sroy745 2024-12-26 02:02:58 -08:00 committed by GitHub
parent 7492a36207
commit dcb1a944d4
11 changed files with 879 additions and 49 deletions
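For context, here is a minimal usage sketch (not part of this diff) showing how the sampling parameters wired up by this PR surface to end users of the V1 engine. The model name and prompt are arbitrary placeholders, and VLLM_USE_V1 is the opt-in flag the tests below also use.

# Hedged usage sketch: exercising the new V1 sampling knobs end to end.
import os
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine, as the tests below do

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model; any supported model works
params = SamplingParams(
    min_tokens=4,            # suppress stop tokens until 4 output tokens exist
    presence_penalty=1.0,    # penalize tokens already present in the output
    frequency_penalty=1.0,   # penalize tokens in proportion to their count
    repetition_penalty=1.2,  # >1.0 discourages repeating prompt/output tokens
)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)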

View File

@@ -139,3 +139,41 @@ def test_engine_core(monkeypatch):
engine_core.abort_requests([req2.request_id, req0.request_id])
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
def test_engine_core_advanced_sampling(monkeypatch):
"""
A basic end-to-end test to verify that the engine functions correctly
when additional sampling parameters, such as min_tokens and
presence_penalty, are set.
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
usage_context=UsageContext.UNKNOWN_CONTEXT)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()
request.sampling_params = SamplingParams(
min_tokens=4,
presence_penalty=1.0,
frequency_penalty=1.0,
repetition_penalty=0.1,
stop_token_ids=[1001, 1002],
)
engine_core.add_request(request)
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done.
while len(engine_core.step()) > 0:
pass
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0

View File

View File

@@ -0,0 +1,331 @@
from typing import List, Set, Tuple
import numpy as np
import pytest
import torch
from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64
def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
return fake_logits
def _create_penalty_tensor(batch_size: int, penalty_value: float,
device: torch.device) -> torch.Tensor:
return torch.full((batch_size, ),
fill_value=penalty_value,
dtype=torch.float,
device=device)
def _create_prompt_tokens_tensor(
prompt_token_ids: List[List[int]],
vocab_size: int,
device: torch.device,
) -> torch.Tensor:
return make_tensor_with_pad(
prompt_token_ids,
pad=vocab_size,
device=device,
dtype=torch.int64,
pin_memory=False,
)
def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
vocab_size: int,
device: torch.device,
) -> SamplingMetadata:
output_token_ids: List[List[int]] = []
prompt_token_ids: List[List[int]] = []
for _ in range(batch_size):
output_token_ids.append(
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
prompt_token_ids.append(
np.random.randint(0,
vocab_size,
size=np.random.randint(
1, MAX_NUM_PROMPT_TOKENS)).tolist())
fake_sampling_metadata = SamplingMetadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
all_random=False,
top_p=torch.empty(batch_size, ),
top_k=torch.empty(batch_size, ),
no_top_p=True,
no_top_k=True,
generators={},
max_num_logprobs=VOCAB_SIZE,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
)
return fake_sampling_metadata
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
request in the batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens.append(
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens))
stop_token_ids.append(
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))
else:
min_tokens.append(np.random.randint(0, num_output_tokens))
stop_token_ids.append(set())
return (min_tokens, stop_token_ids)
def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
"""
Creates an output token list where each token occurs a distinct
number of times.
For each batch, a random subset of token IDs is selected from the
vocabulary. The selected tokens are then added to the output token
list, each with a different frequency.
Returns:
Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
- The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output
list.
"""
output_token_ids: List[List[int]] = []
sorted_token_ids_in_output: List[List[int]] = []
for _ in range(batch_size):
distinct_token_ids = np.random.choice(vocab_size,
size=np.random.randint(1, 10),
replace=False).tolist()
sorted_token_ids_in_output.append(distinct_token_ids)
output_token_ids_for_batch = []
for index, token_id in enumerate(distinct_token_ids):
output_token_ids_for_batch.extend(
[token_id for _ in range(index + 1)])
output_token_ids.append(output_token_ids_for_batch)
return (output_token_ids, sorted_token_ids_in_output)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens, the logits for the stop token ids
are set to -inf.
"""
torch.set_default_device(device)
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
batch_indices_for_min_token_penalty)
sampling_metadata.min_tokens = min_tokens
sampling_metadata.stop_token_ids = stop_token_ids
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
for batch_idx in range(batch_size):
for vocab in range(VOCAB_SIZE):
# Verify that the logprobs for stop token ids are set
# to -inf.
logprob_index = torch.where(
sampler_output.logprob_token_ids[batch_idx] ==
vocab)[0].item()
if vocab in stop_token_ids[batch_idx]:
assert sampler_output.logprobs[batch_idx][
logprob_index] == -float("inf")
else:
assert sampler_output.logprobs[batch_idx][
logprob_index] != -float("inf")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
def test_sampler_presence_penalty(device: str, batch_size: int,
presence_penalty: float):
"""
Test to verify that if presence penalty is enabled then tokens
are penalized as per their presence in the existing output.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
output_token_ids = sampling_metadata.output_token_ids
sampling_metadata.presence_penalties = _create_penalty_tensor(
batch_size, presence_penalty, torch.device(device))
sampling_metadata.no_penalties = False
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
for batch_idx in range(batch_size):
# The logprobs in the SamplerOutput are arranged in descending order.
# Since all tokens initially have the same logprobs, the non-penalized
# tokens will appear at the beginning, while the penalized tokens
# will appear at the end of the list.
penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
VOCAB_SIZE - 1]
penalized_log_prob = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
non_penalized_log_prob = sampler_output.logprobs[batch_idx][0]
assert non_penalized_log_prob > penalized_log_prob
if presence_penalty > 0:
# If `presence_penalty` is set to a value greater than 0, it
# indicates a preference for new tokens over those already
# present in the output.
# Verify that the penalized token ID exists in the output, while the
# non-penalized token ID does not.
assert penalized_token_id in output_token_ids[batch_idx]
assert non_penalized_token_id not in output_token_ids[batch_idx]
elif presence_penalty < 0:
# If `presence_penalty` is set to a value less than 0, it indicates
# a preference for existing tokens over new ones. Verify that the
# non-penalized token ID exists in the output, while the penalized
# token ID does not.
assert non_penalized_token_id in output_token_ids[batch_idx]
assert penalized_token_id not in output_token_ids[batch_idx]
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
def test_sampler_frequency_penalty(device: str, batch_size: int,
frequency_penalty: float):
"""
Test to verify that if frequency penalty is enabled then tokens are
penalized as per their frequency of occurrence.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.frequency_penalties = _create_penalty_tensor(
batch_size, frequency_penalty, torch.device(device))
output_token_ids, sorted_token_ids_in_output = \
_create_weighted_output_token_list(batch_size, VOCAB_SIZE)
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
for batch_idx in range(batch_size):
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
non_penalized_token_id = logprobs_token_ids[0]
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
distinct_sorted_token_ids_in_output = \
sorted_token_ids_in_output[batch_idx]
most_frequent_token_id = distinct_sorted_token_ids_in_output[
len(distinct_sorted_token_ids_in_output) - 1]
if frequency_penalty > 0:
# If `frequency_penalty` is set to > 0, it indicates
# a preference for new tokens over existing ones. Verify that the
# non-penalized token ID is not present in the output, while the
# most penalized token is the one that occurs most frequently in
# the output.
assert non_penalized_token_id \
not in distinct_sorted_token_ids_in_output
assert penalized_token_id == most_frequent_token_id
elif frequency_penalty < 0:
# If `frequency_penalty` is set to < 0, it indicates
# a preference for existing tokens over new ones. Verify that the
# non-penalized token ID is the one that occurs most frequently
# in the output, while the penalized token ID is one that has not
# yet appeared.
assert non_penalized_token_id == most_frequent_token_id
assert penalized_token_id \
not in distinct_sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
def test_sampler_repetition_penalty(device: str, batch_size: int,
repetition_penalty: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.repetition_penalties = _create_penalty_tensor(
batch_size, repetition_penalty, torch.device(device))
sampling_metadata.no_penalties = False
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
for batch_idx in range(batch_size):
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
non_penalized_token_id = logprobs_token_ids[0]
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
prompt_tokens = sampling_metadata.prompt_token_ids[
batch_idx][:].tolist()
output_tokens = sampling_metadata.output_token_ids[batch_idx]
if repetition_penalty > 1.0:
# If `repetition_penalty` > 1.0, verify that the non-penalized
# token ID has not been seen before, while the penalized token ID
# exists either in the prompt or the output.
assert (non_penalized_token_id not in prompt_tokens and \
non_penalized_token_id not in output_tokens)
assert (penalized_token_id in prompt_tokens or \
penalized_token_id in output_tokens)
elif repetition_penalty < 1.0:
# If `repetition_penalty` < 1.0, verify that the penalized
# token ID has not been seen before, while the non-penalized
# token ID exists either in the prompt or the output.
assert (penalized_token_id not in prompt_tokens and \
penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens)
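The assertions above read the penalized token from position VOCAB_SIZE - 1 of the sorted logprobs. This works because the fake logits start out identical, so any penalty necessarily pushes the affected token to the tail of the descending ordering. A standalone sketch of that assumption (plain torch, not code from this PR):

# Why a penalized token sorts last when all initial logits are equal.
import torch

vocab_size = 8
logits = torch.full((1, vocab_size), 1e-2)  # every token starts with the same logit
seen_token = 3
logits[0, seen_token] -= 2.0                # e.g. a presence penalty of 2.0

logprobs = torch.log_softmax(logits, dim=-1)
topk = torch.topk(logprobs, k=vocab_size)   # descending, like logprob_token_ids
assert topk.indices[0, -1].item() == seen_token  # penalized token is at the tail
assert topk.indices[0, 0].item() != seen_token   # every other token ranks ahead of it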

View File

View File

@@ -0,0 +1,224 @@
from typing import Dict, List, Set, Tuple
import numpy as np
import pytest
import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64
def _remove_requests(
input_batch: InputBatch, batch_size: int,
reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]:
"""
Randomly removes some requests from the batch and returns a tuple of:
1) the set of request IDs removed, and 2) the indices of the removed
requests, ordered in descending order.
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove: Set[int] = set()
for _ in range(num_reqs_to_remove):
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_indices_to_remove_list = list(req_indices_to_remove)
req_indices_to_remove_list.sort(reverse=True)
req_ids_to_remove: Set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return (req_ids_to_remove, req_indices_to_remove_list)
def _construct_expected_sampling_metadata(
reqs: List[CachedRequestState], req_ids_retained: Set[str],
req_id_index_in_input_batch: Dict[str, int],
device: torch.device) -> SamplingMetadata:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
"""
num_reqs = len(req_ids_retained)
output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
presence_penalties = [0.0 for _ in range(num_reqs)]
frequency_penalties = [0.0 for _ in range(num_reqs)]
repetition_penalties = [1.0 for _ in range(num_reqs)]
top_k = [0 for _ in range(num_reqs)]
top_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)]
for req in reqs:
if req.req_id not in req_ids_retained:
continue
index_in_input_batch = req_id_index_in_input_batch[req.req_id]
output_token_ids[index_in_input_batch] = req.output_token_ids
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[
index_in_input_batch] = req.sampling_params.frequency_penalty
repetition_penalties[
index_in_input_batch] = req.sampling_params.repetition_penalty
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float, device=device),
all_greedy=False,
all_random=True,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
no_top_p=all(x == 1.0 for x in top_p),
no_top_k=all(x == 0 for x in top_k),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
prompt_token_ids,
pad=VOCAB_SIZE,
device=torch.device(device),
dtype=torch.int64,
),
frequency_penalties=torch.tensor(
frequency_penalties, dtype=torch.float,
device=device),
presence_penalties=torch.tensor(
presence_penalties, dtype=torch.float,
device=device),
repetition_penalties=torch.tensor(
repetition_penalties, dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties) and \
all(x == 0 for x in frequency_penalties) and \
all(x == 1 for x in repetition_penalties))
)
def _create_sampling_params():
return SamplingParams(top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0),
frequency_penalty=np.random.uniform(-2.0, 2.0),
min_tokens=np.random.randint(1, 10),
stop_token_ids=[
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10))
])
def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
]
output_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
prompt=None,
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[],
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
"""
Tests the logic for managing sampling metadata in the InputBatch.
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch: InputBatch = InputBatch(max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024)
reqs: List[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
# Remove some requests
req_ids_to_remove, req_indices_to_remove = _remove_requests(
input_batch, batch_size, reqs)
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
# Compact the input batch
input_batch.condense(req_indices_to_remove)
# Generate the sampling metadata
sampling_metadata = input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy=False)
# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(
reqs,
req_ids_retained,
input_batch.req_id_to_index,
device=torch.device(device))
# Assert the actual and expected output.
assert torch.allclose(expected_sampling_metadata.temperature,
sampling_metadata.temperature)
assert torch.allclose(expected_sampling_metadata.top_p,
sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k)
assert torch.allclose(expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties)
assert torch.allclose(expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties)
assert torch.allclose(expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert (
expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens)
assert (expected_sampling_metadata.stop_token_ids ==
sampling_metadata.stop_token_ids)
assert (expected_sampling_metadata.no_penalties ==
sampling_metadata.no_penalties)
assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p)
assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k)

View File

@@ -11,6 +11,7 @@ import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
@@ -258,11 +259,11 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties.
if do_penalties:
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
@@ -336,23 +337,6 @@ class Sampler(nn.Module):
return self.should_modify_greedy_probs_inplace
def _get_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=tokens.device)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0
return bin_counts, mask
def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
@@ -400,29 +384,6 @@ def _apply_min_tokens_penalty(
return logits
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor) -> torch.Tensor:
num_seqs, vocab_size = logits.shape
_, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
num_seqs)
output_bin_counts, output_mask = _get_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
logits = torch.where(logits > 0, logits / repetition_penalties,
logits * repetition_penalties)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
return logits
def _apply_top_k_top_p(
logits: torch.Tensor,
p: torch.Tensor,

View File

@@ -0,0 +1,57 @@
"""Utility methods for model layers."""
from typing import Tuple
import torch
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=tokens.device)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0
return bin_counts, mask
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor) -> torch.Tensor:
"""
Applies penalties in place to the logits tensor.
logits: The input logits tensor of shape [num_seqs, vocab_size]
prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
are padded to the maximum prompt length within the batch using
`vocab_size` as the padding value. The value `vocab_size` is used
for padding because it does not correspond to any valid token ID
in the vocabulary.
output_tokens_tensor: The output tokens tensor.
presence_penalties: The presence penalties of shape (num_seqs, )
frequency_penalties: The frequency penalties of shape (num_seqs, )
repetition_penalties: The repetition penalties of shape (num_seqs, )
"""
num_seqs, vocab_size = logits.shape
_, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0]
logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits <= 0]
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
return logits
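A worked example of the formulas above, as a small sketch against the helper this PR adds (values are arbitrary; prompt and output rows are padded with vocab_size as the docstring requires):

# Hedged sketch: apply_penalties on tiny CPU tensors, result computed by hand.
import torch
from vllm.model_executor.layers.utils import apply_penalties

vocab_size = 4
logits = torch.ones(1, vocab_size)           # [1.0, 1.0, 1.0, 1.0]
prompt = torch.tensor([[0, vocab_size]])     # token 0 in the prompt; vocab_size is padding
output = torch.tensor([[1, 1, vocab_size]])  # token 1 appears twice in the output

out = apply_penalties(
    logits,
    prompt_tokens_tensor=prompt,
    output_tokens_tensor=output,
    presence_penalties=torch.tensor([0.5]),
    frequency_penalties=torch.tensor([0.2]),
    repetition_penalties=torch.tensor([2.0]),
)
# token 0: 1.0 / 2.0 = 0.5                   (repetition only, seen in the prompt)
# token 1: 1.0 / 2.0 - 0.2 * 2 - 0.5 = -0.4  (repetition + frequency + presence)
# tokens 2, 3: unchanged
assert torch.allclose(out, torch.tensor([[0.5, -0.4, 1.0, 1.0]]))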

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
import torch
@@ -19,3 +19,13 @@ class SamplingMetadata:
generators: Dict[int, torch.Generator]
max_num_logprobs: int
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]

View File

@@ -1,9 +1,11 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Set, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -17,9 +19,18 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
_apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
sampling_metadata.stop_token_ids,
sampling_metadata.min_tokens)
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
_apply_penalties(logits, sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids)
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_top_k_top_p(logits, sampling_metadata)
probs = self.get_probs(logits)
sampled = self.sample(probs, sampling_metadata)
# Use int32 to reduce the tensor size.
@@ -157,3 +168,53 @@ def _apply_top_k_top_p(
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
def _apply_min_token_penalties(logits: torch.Tensor,
output_token_ids: List[List[int]],
stop_token_ids: List[Set[int]],
min_tokens: List[int]):
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
for index, min_token in enumerate(min_tokens):
if (len(output_token_ids[index]) < min_token):
for stop_token_id in stop_token_ids[index]:
min_tokens_logits_to_penalize.append((index, stop_token_id))
if min_tokens_logits_to_penalize:
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor,
output_token_ids: List[List[int]]):
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_, vocab_size = logits.shape
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
logits.device)
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
presence_penalties, frequency_penalties,
repetition_penalties)
def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
device: torch.device) -> torch.Tensor:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor = make_tensor_with_pad(
output_token_ids,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad=vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=is_pin_memory_available(),
)
return output_tokens_tensor.to(device, non_blocking=True)
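_apply_min_token_penalties above relies on a compact advanced-indexing idiom: the accumulated (batch_index, stop_token_id) pairs are transposed with zip(*...) into row and column index tuples so one assignment masks every affected logit. A standalone illustration (plain torch, not code from this PR):

# Masking several (row, column) positions with a single indexed assignment.
import torch

logits = torch.zeros(3, 5)
# Say requests 0 and 2 are still below min_tokens, with stop tokens {1, 4} and {2}.
pairs = [(0, 1), (0, 4), (2, 2)]
logits[tuple(zip(*pairs))] = -float("inf")  # same as logits[(0, 0, 2), (1, 4, 2)]

assert logits[0, 1] == -float("inf")
assert logits[2, 2] == -float("inf")
assert logits[1].isfinite().all()           # request 1 is untouched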

View File

@@ -43,12 +43,14 @@ class InputBatch:
max_num_blocks_per_req: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
self.req_id_to_index: Dict[str, int] = {}
@@ -63,6 +65,7 @@ class InputBatch:
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
# Attention-related.
self.block_table = torch.zeros(
@@ -110,6 +113,50 @@ class InputBatch:
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: Set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.presence_penalties_cpu = \
self.presence_penalties_cpu_tensor.numpy()
self.presence_penalties_reqs: Set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: Set[str] = set()
self.min_tokens: List[int] = [0] * max_num_reqs
self.stop_token_ids: List[Set[int]] = [
set() for _ in range(max_num_reqs)
]
self.prompt_token_ids: Optional[torch.Tensor] = None
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
@@ -133,6 +180,7 @@ class InputBatch:
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
@@ -157,6 +205,20 @@ class InputBatch:
self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id)
self.frequency_penalties_cpu[req_index] = \
sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = \
sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = \
sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
self.min_tokens[req_index] = sampling_params.min_tokens
self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
@@ -179,6 +241,9 @@ class InputBatch:
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
@@ -191,6 +256,9 @@ class InputBatch:
self.random_reqs.clear()
self.top_p_reqs.clear()
self.top_k_reqs.clear()
self.frequency_penalties_reqs.clear()
self.presence_penalties_reqs.clear()
self.repetition_penalties_reqs.clear()
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
@@ -224,6 +292,8 @@ class InputBatch:
# block_table_cpu.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_prompt_tokens[empty_index] = \
self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[
@@ -232,6 +302,15 @@ class InputBatch:
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = \
self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[empty_index] = \
self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[empty_index] = \
self.repetition_penalties_cpu[last_req_index]
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
self.stop_token_ids[empty_index] = \
self.stop_token_ids[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
@@ -241,6 +320,7 @@ class InputBatch:
def make_sampling_metadata(
self,
req_id_output_token_ids: Dict[str, List[int]],
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
@@ -250,6 +330,37 @@ class InputBatch:
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
if not self.no_penalties:
# Since syncing these tensors is expensive, only copy them
# if necessary, i.e. if there are requests that require
# penalties to be applied during sampling.
self.frequency_penalties[:self.num_reqs].copy_(
self.frequency_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
self.presence_penalties[:self.num_reqs].copy_(
self.presence_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
self.repetition_penalties[:self.num_reqs].copy_(
self.repetition_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
self.prompt_token_ids = self._make_prompt_token_ids_tensor()
output_token_ids: List[List[int]] = []
for req_id in self.req_ids[:self.num_reqs]:
assert req_id is not None
# Currently we create a tensor for output_token_ids from scratch
# at each step. However, the penalties computation only needs
# statistics about the token ids present in the output, and these
# statistics can be maintained incrementally instead of being
# recomputed from scratch at each step.
# TODO - Replace this with an incremental update of the output token
# statistics.
output_token_ids.append(req_id_output_token_ids[req_id])
return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy,
@@ -260,8 +371,33 @@ class InputBatch:
no_top_k=self.no_top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=self.prompt_token_ids,
frequency_penalties=self.frequency_penalties[:self.num_reqs],
presence_penalties=self.presence_penalties[:self.num_reqs],
repetition_penalties=self.repetition_penalties[:self.num_reqs],
output_token_ids=output_token_ids,
min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = (
self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@@ -282,6 +418,12 @@ class InputBatch:
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> int:
return max(self.num_logprobs.values()) if self.num_logprobs else 0
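_make_prompt_token_ids_tensor pads each row out to the longest prompt in the batch using vocab_size, as the comment above notes. That pad value is safe because the bin-count helper added in layers/utils.py allocates one extra bin for it and slices it off, so padded slots never contribute to the penalty masks. A small sketch, assuming the import path introduced by this PR:

# Padded positions (vocab_size) fall into the extra bin and are discarded.
import torch
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask

vocab_size = 5
tokens = torch.tensor([[2, vocab_size, vocab_size]])  # token 2 plus two pad slots
bin_counts, mask = get_token_bin_counts_and_mask(tokens, vocab_size, num_seqs=1)

assert bin_counts.tolist() == [[0, 0, 1, 0, 0]]  # the pad bin has been sliced away
assert mask.tolist() == [[False, False, True, False, False]]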

View File

@@ -105,6 +105,7 @@ class GPUModelRunner:
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
)
self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -383,7 +384,12 @@ class GPUModelRunner:
or scheduler_output.scheduled_resumed_reqs):
skip_copy = False
# Create the sampling metadata.
req_id_output_token_ids: Dict[str, List[int]] = \
{req_id: req.output_token_ids \
for req_id, req in self.requests.items()}
sampling_metadata = self.input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy)
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):