[Speculative Decoding 1/2 ] Add typical acceptance sampling as one of the sampling techniques in the verifier (#5131)

This commit is contained in:
sroy745 2024-06-17 19:29:09 -07:00 committed by GitHub
parent 26e1188e51
commit fa9e385229
4 changed files with 866 additions and 164 deletions

View File

@ -0,0 +1,464 @@
"""Tests for rejection sampling."""
import pytest
import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.model_executor.utils import set_random_seed
CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
"""
Generates a fake temperature zero probability distribution.
Returns:
1. A fake temperature zero probability distribution of shape
[batch_size, k, vocab_size]
2. Tensor of shape [batch_size, k] containing the token ids
of the probability 1.0 tokens at each position.
"""
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
_, zero_temperature_token_ids = torch.max(probs, dim=-1)
# set the probability of the tokens with ids in zero_temperature_token_ids
# to 1 and the rest to 0.
target_probs = torch.zeros_like(probs).scatter_(
-1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
return target_probs, zero_temperature_token_ids
def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
token_ids_to_exclude: torch.Tensor):
"""
Returns a tensor of shape [batch_size, k] of fake draft token ids
drawn randomly from a vocab of size vocab_size, while ensuring that
the token ids in token_ids_to_exclude are excluded at the
corresponding positions.
"""
draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
for i in range(batch_size):
for j in range(k):
# Generate a random token ID excluding token_ids_to_exclude[i, j]
while True:
token_id = torch.randint(0, vocab_size, (1, )).item()
if token_id != token_ids_to_exclude[i, j]:
draft_token_ids[i, j] = token_id
break
return draft_token_ids
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
"""
Tests that the TypicalAcceptanceSampler forward pass succeeds for
different combinations of k, vocab_size, batch_size and devices.
"""
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler()
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str):
"""
Tests that an exception is raised if the token ids fall outside
the bounds of the provided vocabulary.
"""
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that appropriate exceptions are thrown for out
# of bound vocabs.
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
oob_token_ids = bonus_token_ids
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
raise AssertionError()
if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs, bonus_token_ids,
draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, disable_bonus_tokens: bool, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens.
Because the entropy of a uniform target distribution is high, all draft
tokens should be accepted. The test also ensures that the behavior
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int,
disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
This test verifies that when using a zero-temperature target probability
distribution, where only one token has a probability of 1.0, the
TypicalAcceptanceSampler correctly rejects all draft tokens that do not
match this probability. Additionally, it ensures that when all draft
tokens are rejected, the sampler falls back to greedy sampling to select a
single token from the target distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k, vocab_size)
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
# The target probability distribution is a temperature zero distribution
# with zero entropy. Since our draft token ids don't match the probability
# 1.0 tokens in the target distribution, we will reject all of them and
# fall back to greedy sampling to select 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == -1)
assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:,
0])
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
This test ensures that the TypicalAcceptanceSampler handles a mixed
target probability distribution correctly. Specifically, it uses a
zero-temperature distribution for some sequences and a uniform
distribution for others. The test verifies that:
- For sequences with a zero-temperature distribution, only the token
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
target_probs[[1, 3]] = uniform_probs
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
# verify the shape of output_token_ids
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
# For sequences 0 and 2 verify that only 1 token is accepted
# which is the token with probability 1.0 in the target distribution
# at position 0.
assert torch.all(output_token_ids[[0, 2], 1:] == -1)
assert (torch.all(output_token_ids[[0, 2],
0] == zero_temperature_token_ids[[0, 2],
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# if disable_bonus_tokens is false then we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
if disable_bonus_tokens:
assert torch.all(output_token_ids[[1, 3], -1] == -1)
else:
assert torch.all(output_token_ids[[1, 3], -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
This test verifies that the TypicalAcceptanceSampler correctly accepts or
rejects draft tokens based on a zero-temperature target probability
distribution. Specifically, it ensures that:
- When all draft tokens match tokens with a probability of 1.0 in the
target distribution, all draft tokens are accepted.
- When only some draft tokens match tokens with a probability of 1.0 in
the target distribution, only those matching tokens are accepted, and the
rest are rejected.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next, keep only the first 2 draft tokens the same as the zero temperature
# tokens. For the remaining 3, choose some other tokens. In the
# output we expect the first 2 tokens to match the draft tokens
# and the rest to be -1.
draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(output_token_ids[:, -3:] == -1)
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int,
disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
thresholds and alpha values we can change the acceptance behavior of the
sampler.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)
# Change the posterior threshold values to 0.0 so that we will
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True,
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0,
posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
This test verifies that the `_replacement_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as replacements based on the target probability distribution.
Specifically, it ensures that the method picks the highest-probability
token at the first position for each sequence in the batch.
"""
set_random_seed(seed)
k = 10
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(
(batch_size, k), dtype=torch.long)
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
dim=1)
actual_replacement_tokens = (
typical_acceptance_sampler._replacement_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
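For orientation, here is a minimal usage sketch of the new sampler that mirrors test_no_crash_with_varying_dims above; it is not part of the commit, assumes a CUDA device is available, and uses made-up shapes.

# Minimal sketch (not part of the commit) mirroring the test above: the
# verifier accepts/rejects k draft tokens per sequence and emits a bonus
# token only when every draft token is accepted.
import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

batch_size, k, vocab_size = 2, 3, 32_000
torch.set_default_device("cuda:0")
sampler = TypicalAcceptanceSampler()
sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)
# Shape [batch_size, k + 1]; -1 marks positions after the first rejection.
output_token_ids = sampler(target_probs, bonus_token_ids, draft_token_ids)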

View File

@ -1,12 +1,15 @@
from functools import cached_property
from typing import Optional, Tuple
from typing import Tuple
import torch
import torch.jit
import torch.nn as nn
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
class RejectionSampler(nn.Module):
class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
@ -22,39 +25,11 @@ class RejectionSampler(nn.Module):
Required when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
@property
def probs_dtype(self):
return torch.float32
@property
def token_id_dtype(self):
return torch.int64
SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode)
nn.Module.__init__(self)
def forward(
self,
@ -100,15 +75,8 @@ class RejectionSampler(nn.Module):
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
self._raise_if_incorrect_input(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
bonus_token_ids,
draft_token_ids)
accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
target_probs,
@ -272,128 +240,6 @@ class RejectionSampler(nn.Module):
"""
return torch.finfo(self.probs_dtype).tiny
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
recovered_token_ids: torch.Tensor, # [batch_size, k]
draft_token_ids: torch.Tensor, # [batch_size, k]
bonus_token_ids: torch.Tensor, # [batch_size]
) -> torch.Tensor:
"""Format output. Returns a matrix of token ids. When
a token is rejected via rejection sampling, all subsequent
token ids are set to -1 for the sequence.
shape = [batch_size, k + num_bonus_tokens]
"""
bonus_token_ids = bonus_token_ids.squeeze()
batch_size, k = recovered_token_ids.shape
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
after_false_mask = indices == limits.unsqueeze(1)
# Create an extended output tensor
output_with_bonus_tokens = -torch.ones(
(batch_size, k + self._num_bonus_tokens),
dtype=self.token_id_dtype,
device=accepted.device)
output = output_with_bonus_tokens[:, :k]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
torch.where(accepted_mask,
draft_token_ids,
-torch.ones_like(draft_token_ids),
out=output)
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))
self.num_accepted_tokens += accepted.sum()
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
self.num_draft_tokens += batch_size * k
return output_with_bonus_tokens
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
assert draft_batch_size == target_batch_size
assert num_draft_probs == num_target_probs
assert (draft_vocab_size == target_vocab_size
), f"{draft_vocab_size=} {target_vocab_size=}"
assert draft_token_ids_batch_size == draft_batch_size
assert num_draft_token_ids == num_draft_probs
assert bonus_batch_size == target_batch_size
assert num_bonus_tokens == self._num_bonus_tokens
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert all(probs.dtype == self.probs_dtype
for probs in [target_probs, draft_probs])
assert all(token_ids.dtype == self.token_id_dtype
for token_ids in [bonus_token_ids, draft_token_ids])
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
]
assert all([devices[0] == device for device in devices])
def _raise_if_out_of_bounds_vocab(
self,
vocab_size: int,
bonus_token_ids: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert torch.all(bonus_token_ids < vocab_size)
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
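Because RejectionSampler now shares the SpecDecodeBaseSampler surface with the new TypicalAcceptanceSampler (probs_dtype, token_id_dtype, init_gpu_tensors, and the acceptance counters), verifier code can treat the two interchangeably. A hedged sketch follows; the select_verifier helper is hypothetical and not part of this commit, and init_gpu_tensors assumes a CUDA device.

# Hypothetical helper (not in this commit) showing how the shared base class
# lets the verification step pick either sampler behind one interface.
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)


def select_verifier(method: str) -> SpecDecodeBaseSampler:
    if method == "rejection_sampler":
        return RejectionSampler(disable_bonus_tokens=True)
    if method == "typical_acceptance_sampler":
        return TypicalAcceptanceSampler(disable_bonus_tokens=True)
    raise ValueError(f"Unknown verification method: {method}")


verifier = select_verifier("typical_acceptance_sampler")
verifier.init_gpu_tensors(rank=0)  # both subclasses share this setup step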

View File

@ -0,0 +1,206 @@
from typing import Optional
import torch
class SpecDecodeBaseSampler():
"""Base class for samplers used for Speculative Decoding verification
step.
"""
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
"""Base class constructor.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Required when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
@property
def probs_dtype(self):
return torch.float32
@property
def token_id_dtype(self):
return torch.int64
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
substitute_token_ids: torch.Tensor, # [batch_size, k]
draft_token_ids: torch.Tensor, # [batch_size, k]
bonus_token_ids: torch.Tensor, # [batch_size]
) -> torch.Tensor:
"""Format output. Returns a matrix of token ids. When
a token is rejected via sampling, all subsequent token ids are
set to -1 for the sequence.
Args:
accepted: A boolean tensor indicating if the corresponding
draft token in draft_token_ids should be accepted or not.
substitute_token_ids: A tensor of token_ids that can be used
as substitutes for the draft token ids if the proposed token
is rejected.
draft_token_ids: A tensor of token ids speculated by the
draft model.
bonus_token_ids: Token ids to use as the bonus token if
all the draft tokens are accepted.
Returns:
A tensor containing the accepted token ids. The shape of the
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
after_false_mask = indices == limits.unsqueeze(1)
# Create an extended output tensor
output_with_bonus_tokens = -torch.ones(
(batch_size, k + self._num_bonus_tokens),
dtype=self.token_id_dtype,
device=accepted.device)
output = output_with_bonus_tokens[:, :k]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask))
self.num_accepted_tokens += accepted.sum()
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
self.num_draft_tokens += batch_size * k
return output_with_bonus_tokens
def _raise_if_incorrect_input(
self,
target_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
self._raise_if_incorrect_shape(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_inconsistent_device(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
draft_token_ids, bonus_token_ids)
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
# validate the shape of draft token ids.
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
assert draft_token_ids_batch_size == target_batch_size
assert num_draft_token_ids == num_target_probs
# validate the shape of bonus token ids
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
assert bonus_batch_size == target_batch_size
assert num_bonus_tokens == self._num_bonus_tokens
# validate the shape of draft probs if it is set
if draft_probs is not None:
(draft_batch_size, num_draft_probs,
draft_vocab_size) = draft_probs.shape
assert draft_batch_size == target_batch_size
assert num_draft_probs == num_target_probs
assert (draft_vocab_size == target_vocab_size
), f"{draft_vocab_size=} {target_vocab_size=}"
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
assert target_probs.dtype == self.probs_dtype
assert draft_token_ids.dtype == self.token_id_dtype
assert bonus_token_ids.dtype == self.token_id_dtype
if draft_probs is not None:
assert draft_probs.dtype == self.probs_dtype
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
if t is not None
]
assert all([devices[0] == device for device in devices])
def _raise_if_out_of_bounds_vocab(
self,
vocab_size: int,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
) -> None:
assert torch.all(bonus_token_ids < vocab_size)
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
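To make the masking logic in _create_output concrete, here is a small standalone illustration (not from the commit) of how the first rejected position truncates each row; the accepted values below are made up.

# Standalone illustration of the truncate-at-first-rejection masking used in
# _create_output above; the accepted matrix is made up for the example.
import torch

accepted = torch.tensor([[True, True, False, True],
                         [True, True, True, True]])
k = accepted.shape[1]

# Index of the first False per row; rows with no False keep all k positions.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k                # tensor([2, 4])

indices = torch.arange(k).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)      # keep the draft token here
after_false_mask = indices == limits.unsqueeze(1)  # place the substitute here
# Row 0 keeps positions 0-1, substitutes position 2, and pads the rest with -1;
# row 1 keeps all k draft tokens, so its bonus token can be emitted.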

View File

@ -0,0 +1,186 @@
import torch
import torch.jit
import torch.nn as nn
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
"""Apply typical acceptance sampling as described in section 3.3.1 in
"MEDUSA: Simple LLM Inference Acceleration Framework with
Multiple Decoding Heads"
https://arxiv.org/pdf/2401.10774
"""
def __init__(
self,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
posterior_threshold: float = 0.09,
posterior_alpha: float = 0.3,
):
"""Create a Typical Acceptance Sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Required when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
posterior_threshold : A threshold value that sets a lower bound
on the posterior probability of a token in the target model for it
to be accepted. Defaults to 0.09.
posterior_alpha : A scaling factor for the entropy-based
threshold in typical acceptance sampling. Typically set to the
square root of posterior_threshold; defaults to 0.3.
"""
SpecDecodeBaseSampler.__init__(
self,
disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
nn.Module.__init__(self)
self._posterior_threshold = posterior_threshold
self._posterior_alpha = posterior_alpha
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
"""Sample token ids using typical acceptance sampling. This accepts
or rejects tokens proposed by the draft model using the probability
of each token according to the draft and target models.
In the worst case where all draft tokens are rejected, it is guaranteed
that one token will be emitted.
In the case where all draft tokens are accepted, the bonus token is
emitted only if self._disable_bonus_tokens is False.
Args:
target_probs: The probability distribution over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
Returns:
output_token_ids: The token ids selected via typical acceptance sampling,
or -1 if unable to sample a token because the previous token
was rejected.
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_input(target_probs, draft_token_ids,
bonus_token_ids)
accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs)
output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids,
bonus_token_ids)
return output_token_ids
def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
r"""
Evaluates and returns a mask of accepted tokens based on the
posterior probabilities.
Parameters:
----------
target_probs : torch.Tensor
A tensor of shape (batch_size, k, vocab_size) representing
the probabilities of each token in the vocabulary for each
position in the proposed sequence. This is the distribution
generated by the target model.
draft_token_ids : torch.Tensor
A tensor of shape (batch_size, k) representing the proposed
token ids.
A draft token_id x_{n+k} is accepted if it satisfies the
following condition
.. math::
p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
\min \left( \epsilon, \delta * \exp \left(
-H(p_{\text{original}}(
\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
where :math:`p_{\text{original}}` corresponds to target_probs
and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
specified using self._posterior_threshold and self._posterior_alpha
This method computes the posterior probabilities for the given
draft token ids based on the provided target probabilities. It
calculates the entropy of the posterior distribution and determines
a dynamic threshold for each token position using the provided
posterior_threshold and posterior_alpha values. The method then
returns a boolean mask indicating which tokens can be accepted.
Returns:
-------
torch.Tensor
A boolean tensor of shape (batch_size, k) where each element
indicates whether the corresponding draft token has been accepted
or rejected. True indicates acceptance and false indicates
rejection.
"""
device = target_probs.device
candidates_prob = torch.gather(
target_probs, dim=-1,
index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
# A small constant added to prevent computing the logarithm of zero,
# which can lead to undefined values.
epsilon = 1e-5
posterior_entropy = -torch.sum(
target_probs * torch.log(target_probs + epsilon), dim=-1)
threshold = torch.minimum(
torch.ones_like(posterior_entropy, device=device) *
self._posterior_threshold,
torch.exp(-posterior_entropy) * self._posterior_alpha,
)
accepted_mask = candidates_prob > threshold
return accepted_mask
def _replacement_token_ids(self, target_probs):
"""
Generate one replacement token ID for each sequence based on target
probabilities. The replacement token is used as the fallback option
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the replacement token IDs by selecting, for each
sequence, the token with the highest probability at the first position.
The rest of the output is filled with -1.
Parameters
----------
target_probs : torch.Tensor
A tensor of shape (batch_size, k, vocab_size) containing
the target probability distribution
Returns
-------
torch.Tensor
A tensor of shape (batch_size, k) with the replacement
token IDs. Only the first column is set, and the rest of the
columns are filled with -1.
"""
max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
dtype=self.token_id_dtype,
device=target_probs.device)
output[:, 0] = max_indices
return output
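As a quick, hand-checked illustration of the acceptance rule in _evaluate_accepted_tokens, the sketch below recomputes the threshold for a toy single-position batch with the default hyperparameters (posterior_threshold=0.09, posterior_alpha=0.3); the probabilities are made up and the sketch is not part of the commit.

# Toy check of the typical acceptance rule; numbers are illustrative only.
import torch

target_probs = torch.tensor([[[0.7, 0.2, 0.05, 0.05]]])  # [batch=1, k=1, vocab=4]
draft_token_ids = torch.tensor([[1]])                     # draft proposes the 0.2 token

epsilon = 1e-5
posterior_entropy = -(target_probs * torch.log(target_probs + epsilon)).sum(dim=-1)
# Entropy is roughly 0.87 nats, so exp(-H) is about 0.42 and
# posterior_alpha * exp(-H) is about 0.13.
threshold = torch.minimum(
    torch.full_like(posterior_entropy, 0.09),      # posterior_threshold
    0.3 * torch.exp(-posterior_entropy))           # posterior_alpha * exp(-H)
candidate_prob = target_probs.gather(
    -1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
accepted = candidate_prob > threshold  # 0.2 > 0.09 -> tensor([[True]])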