[Speculative Decoding 1/2 ] Add typical acceptance sampling as one of the sampling techniques in the verifier (#5131)
parent 26e1188e51
commit fa9e385229

tests/samplers/test_typical_acceptance_sampler.py (new file, 464 lines)
@@ -0,0 +1,464 @@
"""Tests for typical acceptance sampling."""

import pytest
import torch

from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.model_executor.utils import set_random_seed

CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]


def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
    """
    Generates a fake temperature zero probability distribution.
    Returns:
        1. A fake temperature zero probability distribution of shape
           [batch_size, k, vocab_size]
        2. Tensor of shape [batch_size, k] containing the token ids
           of the probability 1.0 tokens at each position.
    """
    # Simulate temperature 0 probability distribution for target probabilities
    # and create target probabilities such that only 1 token id has
    # probability 1.0
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    probs = torch.rand(batch_size, k, vocab_size)
    _, zero_temperature_token_ids = torch.max(probs, dim=-1)
    # set the probability of the tokens with ids in zero_temperature_token_ids
    # to 1 and the rest to 0.
    target_probs = torch.zeros_like(probs).scatter_(
        -1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
    return target_probs, zero_temperature_token_ids


def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
                        token_ids_to_exclude: torch.Tensor):
    """
    Returns a tensor of shape [batch_size, k] of fake draft token ids
    drawn randomly from a vocab of size vocab_size. We however ensure
    that token_ids from token_ids_to_exclude are excluded at the
    corresponding positions.
    """
    draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
    for i in range(batch_size):
        for j in range(k):
            # Generate a random token ID excluding token_ids_to_exclude[i, j]
            while True:
                token_id = torch.randint(0, vocab_size, (1, )).item()
                if token_id != token_ids_to_exclude[i, j]:
                    draft_token_ids[i, j] = token_id
                    break
    return draft_token_ids


@pytest.mark.parametrize("k", list(range(1, 6)))
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
device: str):
|
||||
"""
|
||||
Tests that the TypicalAcceptancSampler forward succeeds for
|
||||
different combinations of k, vocab_size, batch_size and num devices.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler()
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
# Verify that sampling succeeds for all cases.
|
||||
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
|
||||
@pytest.mark.parametrize("which_token_ids",
|
||||
["bonus_token_ids", "draft_token_ids"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
which_token_ids: str, device: str):
|
||||
"""
|
||||
Tests that we throw an exception of the token ids fall outside
|
||||
the bound of the provided vocabulary.
|
||||
"""
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
# Verify that appropriate exceptions are thrown for out
|
||||
# of bound vocabs.
|
||||
oob_token_ids = None
|
||||
if which_token_ids == "bonus_token_ids":
|
||||
oob_token_ids = bonus_token_ids
|
||||
elif which_token_ids == "draft_token_ids":
|
||||
oob_token_ids = draft_token_ids
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
if above_or_below_vocab_range == "above":
|
||||
rogue_token_id = vocab_size + 1
|
||||
elif above_or_below_vocab_range == "below":
|
||||
rogue_token_id = -1
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
oob_token_ids[0][0] = rogue_token_id
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
typical_acceptance_sampler(target_probs, bonus_token_ids,
|
||||
draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_uniform_target_distribution_accepts_all_tokens(
|
||||
seed: int, disable_bonus_tokens: bool, device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with a uniform target probability
|
||||
distribution.
|
||||
|
||||
This test verifies that when provided with a uniform target probability
|
||||
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
|
||||
entropy of the uniform target distribution being high should lead to all
|
||||
draft tokens being accepted. The test also ensures that the behavior
|
||||
regarding bonus tokens is consistent with the `disable_bonus_tokens`
|
||||
flag.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
# We are using a uniform target probability distribution.
|
||||
# For a uniform distribution the entropy is very high and it
|
||||
# should lead to all draft tokens being accepted. Verify that.
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
if disable_bonus_tokens:
|
||||
assert torch.all(output_token_ids[:, -1] == -1)
|
||||
else:
|
||||
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
|
||||
|
||||
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_temperature_zero_target_distribution(seed: int,
|
||||
disable_bonus_tokens: bool,
|
||||
device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with a zero-temperature target
|
||||
probability distribution.
|
||||
|
||||
This test verifies that when using a zero-temperature target probability
|
||||
distribution, where only one token has a probability of 1.0, the
|
||||
TypicalAcceptanceSampler correctly rejects all draft tokens that do not
|
||||
match this probability. Additionally, it ensures that when all draft
|
||||
tokens are rejected, the sampler falls back to greedy sampling to select a
|
||||
single token from the target distribution.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# Simulate temperature 0 probability distribution for target probabilities
|
||||
# and create target probabilities such that only 1 token id has
|
||||
# probability 1.0
|
||||
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size)
|
||||
# Populate draft_token_ids such that they exclude the token_ids
|
||||
# with probability = 1.0
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
# The target probaility distribution is a temperature zero distribution
|
||||
# with zero entroy. Since our draft token ids don't match the probability
|
||||
# 1.0 tokens in the target distribution we will reject all of them and
|
||||
# fallback to the greedy sampling for selecting 1 token for each sequence.
|
||||
# Verify the same.
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, -1] == -1)
|
||||
assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:,
|
||||
0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
|
||||
device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with a mixed target probability
|
||||
distribution.
|
||||
|
||||
This test ensures that the TypicalAcceptanceSampler handles a mixed
|
||||
target probability distribution correctly. Specifically, it uses a
|
||||
zero-temperature distribution for some sequences and a uniform
|
||||
distribution for others. The test verifies that:
|
||||
|
||||
- For sequences with a zero-temperature distribution, only the token
|
||||
with a probability of 1.0 is accepted, and all other tokens are rejected.
|
||||
- For sequences with a uniform distribution, all draft tokens are
|
||||
accepted.
|
||||
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
|
||||
for sequences with a uniform distribution.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 3
|
||||
batch_size = 4
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# For sequences 0 and 2 set the distribution to a temperature
|
||||
# zero distribution. For sequences 1 and 3 set it to a uniform
|
||||
# distribution.
|
||||
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size))
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
|
||||
target_probs[[1, 3]] = uniform_probs
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
# verify the shape of output_token_ids
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
# For sequences 0 and 2 verify that only 1 token is accepted
|
||||
# which is the token with probability 1.0 in the target distribution
|
||||
# at position 0.
|
||||
assert torch.all(output_token_ids[[0, 2], 1:] == -1)
|
||||
assert (torch.all(output_token_ids[[0, 2],
|
||||
0] == zero_temperature_token_ids[[0, 2],
|
||||
0]))
|
||||
# For sequences 1 and 3 verify that all tokens are accepted since the
|
||||
# target probability distribution is uniform. In addition verify that
|
||||
# if disable_bonus_tokens is false then we also accept the bonus tokens.
|
||||
assert torch.all(
|
||||
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
|
||||
if disable_bonus_tokens:
|
||||
assert torch.all(output_token_ids[[1, 3], -1] == -1)
|
||||
else:
|
||||
assert torch.all(output_token_ids[[1, 3], -1] != -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
|
||||
device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
|
||||
tokens should be accepted.
|
||||
|
||||
This test verifies that the TypicalAcceptanceSampler correctly accepts or
|
||||
rejects draft tokens based on a zero-temperature target probability
|
||||
distribution. Specifically, it ensures that:
|
||||
|
||||
- When all draft tokens match tokens with a probability of 1.0 in the
|
||||
target distribution, all draft tokens are accepted.
|
||||
- When only some draft tokens match tokens with a probability of 1.0 in
|
||||
the target distribution, only those matching tokens are accepted, and the
|
||||
rest are rejected.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 5
|
||||
batch_size = 1
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# Create a temperature zero target probability distribution and ensure
|
||||
# all draft token ids correspond to the tokens with 1.0 probability.
|
||||
# Verify that all of them are accepted.
|
||||
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size))
|
||||
draft_token_ids = zero_temperature_token_ids
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
|
||||
if disable_bonus_tokens:
|
||||
assert torch.all(output_token_ids[:, -1] == -1)
|
||||
else:
|
||||
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
|
||||
# Next only keep the first 2 draft tokens same as the zero temperature
|
||||
# tokens. For the remaining 3 choose some other tokens. In the
|
||||
# response we will expect the first 2 tokens to be the same as the
|
||||
# draft tokens and the rest as -1
|
||||
draft_token_ids_to_replace = get_draft_token_ids(
|
||||
batch_size, k, vocab_size, zero_temperature_token_ids)
|
||||
draft_token_ids = torch.cat(
|
||||
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
|
||||
assert torch.all(output_token_ids[:, -3:] == -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(1)))
|
||||
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_accept_tokens_set_non_default_posteriors(seed: int,
|
||||
disable_bonus_tokens: bool,
|
||||
device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler with custom posterior thresholds and
|
||||
alpha values. This test verifies that by modifying the posterior
|
||||
thresholds and alpha values we can change the acceptance behavior of the
|
||||
sampler.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 5
|
||||
batch_size = 1
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# Simulate temperature 0 probability distribution for target
|
||||
# probabilities and create target probabilities such that only 1 token
|
||||
# id has probability 1.0 and others have a very low probability of
|
||||
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
|
||||
# with probability = 1.0. Without any changes to the posterior thresholds
|
||||
# none of the draft tokens are accepted.
|
||||
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size))
|
||||
target_probs[target_probs == 0] = 0.00001
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 1:-1] == -1)
|
||||
|
||||
# Change the posterior threshold values to 0.0 so that we will
|
||||
# now accept even draft tokens with very low probability in the
|
||||
# target distribution. Simulate and verify the same.
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True,
|
||||
disable_bonus_tokens=disable_bonus_tokens,
|
||||
posterior_threshold=0.0,
|
||||
posterior_alpha=0.0)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
|
||||
if disable_bonus_tokens:
|
||||
assert torch.all(output_token_ids[:, -1] == -1)
|
||||
else:
|
||||
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
|
||||
device: str):
|
||||
"""
|
||||
Test the TypicalAcceptanceSampler's method for generating
|
||||
replacement token IDs.
|
||||
|
||||
This test verifies that the `_replacement_token_ids` method of the
|
||||
TypicalAcceptanceSampler correctly identifies the token IDs to be used
|
||||
as replacements based on the target probability distribution.
|
||||
Specifically, it ensures that the method correctly identifies the
|
||||
tokens with the highest probability for each sequence in the batch.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
k = 10
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
expected_replacement_tokens = -torch.ones(
|
||||
(batch_size, k), dtype=torch.long)
|
||||
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
|
||||
dim=1)
|
||||
actual_replacement_tokens = (
|
||||
typical_acceptance_sampler._replacement_token_ids(target_probs))
|
||||
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
|
||||
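
Assuming a CUDA-capable machine (the CUDA_DEVICES list above expects cuda:0 to exist), the new tests can be run directly with, for example, `pytest tests/samplers/test_typical_acceptance_sampler.py`.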
vllm/model_executor/layers/rejection_sampler.py
@@ -1,12 +1,15 @@
 from functools import cached_property
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 import torch.jit
 import torch.nn as nn
 
+from vllm.model_executor.layers.spec_decode_base_sampler import (
+    SpecDecodeBaseSampler)
+
 
-class RejectionSampler(nn.Module):
+class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
     """Apply modified rejection sampling as described in "Accelerating Large
     Language Model Decoding with Speculative Sampling"
     https://arxiv.org/pdf/2302.01318.pdf.
@@ -22,39 +25,11 @@ class RejectionSampler(nn.Module):
             Require when bonus tokens will cause corrupt KV cache for
             proposal methods that require KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
-                during sampling. This catches correctness issues but adds
-                nontrivial latency.
+            during sampling. This catches correctness issues but adds
+            nontrivial latency.
         """
-        super().__init__()
-        self._disable_bonus_tokens = disable_bonus_tokens
-        self._strict_mode = strict_mode
-
-        # NOTE: A "bonus token" is accepted iff all proposal tokens are
-        # accepted. There is always only one possible bonus token. We store this
-        # value in a variable for readability.
-        self._num_bonus_tokens = 1
-
-        self.num_accepted_tokens: Optional[torch.Tensor] = None
-        self.num_emitted_tokens: Optional[torch.Tensor] = None
-        self.num_draft_tokens: int = 0
-
-    def init_gpu_tensors(self, rank: int) -> None:
-        assert self.num_accepted_tokens is None
-        device = f"cuda:{rank}"
-        self.num_accepted_tokens = torch.tensor(0,
-                                                dtype=torch.long,
-                                                device=device)
-        self.num_emitted_tokens = torch.tensor(0,
-                                               dtype=torch.long,
-                                               device=device)
-
-    @property
-    def probs_dtype(self):
-        return torch.float32
-
-    @property
-    def token_id_dtype(self):
-        return torch.int64
+        SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode)
+        nn.Module.__init__(self)
 
     def forward(
         self,
@@ -100,15 +75,8 @@
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
-            self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
-            self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
-                                               draft_probs, draft_token_ids)
-            self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
-                                               bonus_token_ids,
-                                               draft_token_ids)
+            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
+                                           draft_probs, draft_token_ids)
 
         accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
             target_probs,
@@ -272,128 +240,6 @@ class RejectionSampler(nn.Module):
         """
         return torch.finfo(self.probs_dtype).tiny
 
-    def _create_output(
-        self,
-        accepted: torch.Tensor,  # [batch_size, k]
-        recovered_token_ids: torch.Tensor,  # [batch_size, k]
-        draft_token_ids: torch.Tensor,  # [batch_size, k]
-        bonus_token_ids: torch.Tensor,  # [batch_size]
-    ) -> torch.Tensor:
-        """Format output. Returns a matrix of token ids. When
-        a token is rejected via rejection sampling, all subsequent
-        token ids are set to -1 for the sequence.
-
-        shape = [batch_size, k + num_bonus_tokens]
-        """
-        bonus_token_ids = bonus_token_ids.squeeze()
-        batch_size, k = recovered_token_ids.shape
-
-        # Determine the index of the first False value for each row.
-        limits = (accepted == 0).max(1).indices
-        limits[~(accepted == 0).any(1)] = k
-
-        # Create masks using the indices.
-        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
-        accepted_mask = indices < limits.unsqueeze(1)
-        after_false_mask = indices == limits.unsqueeze(1)
-
-        # Create an extended output tensor
-        output_with_bonus_tokens = -torch.ones(
-            (batch_size, k + self._num_bonus_tokens),
-            dtype=self.token_id_dtype,
-            device=accepted.device)
-        output = output_with_bonus_tokens[:, :k]
-
-        # Fill in the first k columns of the output tensor using masks and data
-        # tensors.
-        torch.where(accepted_mask,
-                    draft_token_ids,
-                    -torch.ones_like(draft_token_ids),
-                    out=output)
-
-        # Fill the last column.
-        # We check output directly as accepted may have True values inconsistent
-        # with causal acceptance.
-        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
-                                                      bonus_token_ids, -1)
-
-        # We disable bonus tokens because it causes corrupt KV cache for
-        # proposal methods that require KV cache. We can fix it by "prefilling"
-        # the bonus token in the proposer. The following issue tracks the fix.
-        # https://github.com/vllm-project/vllm/issues/4212
-        if self._disable_bonus_tokens:
-            output_with_bonus_tokens[:, -1] = -1
-
-        # Fill the recovered token ids.
-        output.mul_(~after_false_mask).add_(
-            recovered_token_ids.mul(after_false_mask))
-
-        self.num_accepted_tokens += accepted.sum()
-        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
-        self.num_draft_tokens += batch_size * k
-
-        return output_with_bonus_tokens
-
-    def _raise_if_incorrect_shape(
-        self,
-        target_probs: torch.Tensor,
-        bonus_token_ids: torch.Tensor,
-        draft_probs: torch.Tensor,
-        draft_token_ids: torch.Tensor,
-    ) -> None:
-        (target_batch_size, num_target_probs,
-         target_vocab_size) = target_probs.shape
-        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
-        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
-        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
-
-        assert draft_batch_size == target_batch_size
-        assert num_draft_probs == num_target_probs
-        assert (draft_vocab_size == target_vocab_size
-                ), f"{draft_vocab_size=} {target_vocab_size=}"
-
-        assert draft_token_ids_batch_size == draft_batch_size
-        assert num_draft_token_ids == num_draft_probs
-
-        assert bonus_batch_size == target_batch_size
-        assert num_bonus_tokens == self._num_bonus_tokens
-
-    def _raise_if_incorrect_dtype(
-        self,
-        target_probs: torch.Tensor,
-        bonus_token_ids: torch.Tensor,
-        draft_probs: torch.Tensor,
-        draft_token_ids: torch.Tensor,
-    ) -> None:
-        assert all(probs.dtype == self.probs_dtype
-                   for probs in [target_probs, draft_probs])
-        assert all(token_ids.dtype == self.token_id_dtype
-                   for token_ids in [bonus_token_ids, draft_token_ids])
-
-    def _raise_if_inconsistent_device(
-        self,
-        target_probs: torch.Tensor,
-        bonus_token_ids: torch.Tensor,
-        draft_probs: torch.Tensor,
-        draft_token_ids: torch.Tensor,
-    ) -> None:
-        devices = [
-            t.device for t in
-            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
-        ]
-        assert all([devices[0] == device for device in devices])
-
-    def _raise_if_out_of_bounds_vocab(
-        self,
-        vocab_size: int,
-        bonus_token_ids: torch.Tensor,
-        draft_token_ids: torch.Tensor,
-    ) -> None:
-        assert torch.all(bonus_token_ids < vocab_size)
-        assert torch.all(bonus_token_ids >= 0)
-        assert torch.all(draft_token_ids < vocab_size)
-        assert torch.all(draft_token_ids >= 0)
-
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
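
The shape/dtype/device/vocab checks and the _create_output logic deleted from RejectionSampler above are not dropped: they reappear, essentially unchanged apart from the renamed substitute_token_ids argument and an optional draft_probs, in the new SpecDecodeBaseSampler below, which both RejectionSampler and the new TypicalAcceptanceSampler now inherit from.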
vllm/model_executor/layers/spec_decode_base_sampler.py (new file, 206 lines)
@@ -0,0 +1,206 @@
from typing import Optional

import torch


class SpecDecodeBaseSampler():
    """Base class for samplers used for Speculative Decoding verification
    step.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Base class constructor.
        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
            Require when bonus tokens will cause corrupt KV cache for
            proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
            during sampling. This catches correctness issues but adds
            nontrivial latency.
        """
        super().__init__()
        self._disable_bonus_tokens = disable_bonus_tokens
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store this
        # value in a variable for readability.
        self._num_bonus_tokens = 1

        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: int) -> None:
        assert self.num_accepted_tokens is None
        device = f"cuda:{rank}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    @property
    def probs_dtype(self):
        return torch.float32

    @property
    def token_id_dtype(self):
        return torch.int64

    def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        substitute_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via sampling, all subsequent token ids are
        set to -1 for the sequence.

        Args:
            accepted: A boolean tensor indicating if the corresponding
            draft token in draft_token_ids should be accepted or not.
            substitute_token_ids: A tensor of token_ids that can be used
            as substitutes for the draft token ids if the proposed token
            is rejected.
            draft_token_ids: A tensor of token ids speculated by the
            draft model.
            bonus_token_ids: Token ids to use as the bonus token if
            all the draft tokens are accepted.
        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens]
        """
        batch_size, k = substitute_token_ids.shape
        bonus_token_ids = bonus_token_ids.squeeze()
        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # We disable bonus tokens because it causes corrupt KV cache for
        # proposal methods that require KV cache. We can fix it by "prefilling"
        # the bonus token in the proposer. The following issue tracks the fix.
        # https://github.com/vllm-project/vllm/issues/4212
        if self._disable_bonus_tokens:
            output_with_bonus_tokens[:, -1] = -1

        # Fill the recovered token ids.
        output.mul_(~after_false_mask).add_(
            substitute_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_input(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        self._raise_if_incorrect_shape(target_probs, draft_token_ids,
                                       bonus_token_ids, draft_probs)
        self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
                                       bonus_token_ids, draft_probs)
        self._raise_if_inconsistent_device(target_probs, draft_token_ids,
                                           bonus_token_ids, draft_probs)
        self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape

        # validate the shape of draft token ids.
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
        assert draft_token_ids_batch_size == target_batch_size
        assert num_draft_token_ids == num_target_probs

        # validate the shape of bonus token ids
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

        # validate the shape of draft probs if it is set
        if draft_probs is not None:
            (draft_batch_size, num_draft_probs,
             draft_vocab_size) = draft_probs.shape
            assert draft_batch_size == target_batch_size
            assert num_draft_probs == num_target_probs
            assert (draft_vocab_size == target_vocab_size
                    ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        assert target_probs.dtype == self.probs_dtype
        assert draft_token_ids.dtype == self.token_id_dtype
        assert bonus_token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
            if t is not None
        ]
        assert all([devices[0] == device for device in devices])

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)
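
As a quick illustration of how _create_output above truncates a sequence at the first rejection, here is a minimal standalone sketch (not part of the commit; the token values 11/22/33/44 and 101-104 are made up for the example) that reproduces the limits/mask computation on a toy batch:

import torch

accepted = torch.tensor([[True, True, False, True]])  # first rejection at position 2
k = accepted.shape[1]

# Index of the first False per row; rows with no False get k.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k

indices = torch.arange(k).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)      # [[ True,  True, False, False]]
after_false_mask = indices == limits.unsqueeze(1)  # [[False, False,  True, False]]

draft_token_ids = torch.tensor([[11, 22, 33, 44]])           # hypothetical draft tokens
substitute_token_ids = torch.tensor([[101, 102, 103, 104]])  # hypothetical recovered tokens

output = torch.where(accepted_mask, draft_token_ids,
                     -torch.ones_like(draft_token_ids))
output = output * ~after_false_mask + substitute_token_ids * after_false_mask
print(output)  # tensor([[ 11,  22, 103,  -1]])

With a rejection at position 2, the first two draft tokens survive, position 2 receives the substitute token, and everything after it is -1; in _create_output the bonus column would also be -1 because the last draft position was not accepted.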
vllm/model_executor/layers/typical_acceptance_sampler.py (new file, 186 lines)
@@ -0,0 +1,186 @@
import torch
import torch.jit
import torch.nn as nn

from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler)


class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
    """Apply typical acceptance sampling as described in section 3.3.1 in
    "MEDUSA: Simple LLM Inference Acceleration Framework with
    Multiple Decoding Heads"
    https://arxiv.org/pdf/2401.10774
    """

    def __init__(
        self,
        disable_bonus_tokens: bool = False,
        strict_mode: bool = False,
        posterior_threshold: float = 0.09,
        posterior_alpha: float = 0.3,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
            Require when bonus tokens will cause corrupt KV cache for
            proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
            during sampling. This catches correctness issues but adds
            nontrivial latency.
            posterior_threshold : A threshold value that sets a lower bound
            on the posterior probability of a token in target model for it
            to be accepted. Default is 0.09
            posterior_alpha : A scaling factor for the entropy-based
            threshold in typical acceptance sampling. Typically defaults to
            sqrt of posterior_threshold and is set to 0.3.
        """
        SpecDecodeBaseSampler.__init__(
            self,
            disable_bonus_tokens=disable_bonus_tokens,
            strict_mode=strict_mode)
        nn.Module.__init__(self)
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling. This accepts
        or rejects tokens proposed by the draft model using the probability
        of each token according to the target model.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one token will be emitted.

        In the case where all draft tokens are accepted, the bonus token will be
        accepted conditioned on self._disable_bonus_tokens being false.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
            shape = [batch_size, num_speculative_tokens, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
            shape = [batch_size, num_bonus_tokens]

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
            shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via typical acceptance
                sampling, or -1 if unable to sample a token because the
                previous token was rejected.
            shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, draft_token_ids,
                                           bonus_token_ids)
        accepted = self._evaluate_accepted_tokens(target_probs,
                                                  draft_token_ids)
        recovered_token_ids = self._replacement_token_ids(target_probs)
        output_token_ids = self._create_output(accepted, recovered_token_ids,
                                               draft_token_ids,
                                               bonus_token_ids)
        return output_token_ids

    def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
        r"""
        Evaluates and returns a mask of accepted tokens based on the
        posterior probabilities.

        Parameters:
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) representing
            the probabilities of each token in the vocabulary for each
            position in the proposed sequence. This is the distribution
            generated by the target model.
        draft_token_ids : torch.Tensor
            A tensor of shape (batch_size, k) representing the proposed
            token ids.

        A draft token_id x_{n+k} is accepted if it satisfies the
        following condition

        .. math::
            p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
            \min \left( \epsilon, \delta * \exp \left(
                -H(p_{\text{original}}(
                    \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

        where :math:`p_{\text{original}}` corresponds to target_probs
        and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
        specified using self._posterior_threshold and self._posterior_alpha

        This method computes the posterior probabilities for the given
        draft token ids based on the provided target probabilities. It
        calculates the entropy of the posterior distribution and determines
        a dynamic threshold for each token position using the provided
        posterior_threshold and posterior_alpha values. The method then
        returns a boolean mask indicating which tokens can be accepted.

        Returns:
        -------
        torch.Tensor
            A boolean tensor of shape (batch_size, k) where each element
            indicates whether the corresponding draft token has been accepted
            or rejected. True indicates acceptance and false indicates
            rejection.

        """
        device = target_probs.device
        candidates_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # A small constant added to prevent computing the logarithm of zero,
        # which can lead to undefined values.
        epsilon = 1e-5
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + epsilon), dim=-1)
        threshold = torch.minimum(
            torch.ones_like(posterior_entropy, device=device) *
            self._posterior_threshold,
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
        accepted_mask = candidates_prob > threshold
        return accepted_mask

    def _replacement_token_ids(self, target_probs):
        """
        Generate one replacement token ID for each sequence based on target
        probabilities. The replacement token is used as the fallback option
        if typical acceptance sampling does not accept any draft tokens for
        that particular sequence.

        This method computes the replacement token IDs by selecting the
        token with the highest probability for each sequence in the first
        position. The rest of the output is filled with -1.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) containing
            the target probability distribution

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, k) with the replacement
            token IDs. Only the first column is set, and the rest of the
            columns are filled with -1.
        """
        max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
        output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
                             dtype=self.token_id_dtype,
                             device=target_probs.device)
        output[:, 0] = max_indices
        return output
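
To make the acceptance rule in _evaluate_accepted_tokens concrete, here is a minimal standalone sketch (not part of the commit) using a made-up 4-token vocabulary and the default posterior_threshold=0.09 and posterior_alpha=0.3. A confident (low-entropy) target distribution rejects an unlikely draft token, while a near-uniform (high-entropy) one accepts it:

import torch

# Toy target distribution: one sequence, k=2 positions, vocab of 4 tokens.
target_probs = torch.tensor([[[0.94, 0.02, 0.02, 0.02],    # low entropy
                              [0.25, 0.25, 0.25, 0.25]]])  # high entropy
draft_token_ids = torch.tensor([[1, 1]])  # draft proposes token 1 at both positions

posterior_threshold, posterior_alpha, epsilon = 0.09, 0.3, 1e-5

candidates_prob = torch.gather(
    target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
posterior_entropy = -torch.sum(
    target_probs * torch.log(target_probs + epsilon), dim=-1)
threshold = torch.minimum(
    torch.full_like(posterior_entropy, posterior_threshold),
    torch.exp(-posterior_entropy) * posterior_alpha)

print(candidates_prob > threshold)
# tensor([[False,  True]]): position 0 rejects the 0.02-probability draft token,
# position 1 accepts it because the uniform distribution lowers the threshold
# to roughly 0.075, which is below 0.25.

This mirrors, at toy scale, what test_uniform_target_distribution_accepts_all_tokens and test_temperature_zero_target_distribution above check for whole batches.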