[V1] [Spec Decode] Support random sampling for spec decode (#13933)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Lily Liu 2025-03-16 22:00:20 -07:00 committed by GitHub
parent 583a9778e0
commit 8d6cf89526
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 568 additions and 194 deletions

View File

@ -1,37 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
import pytest
import torch
import torch.nn.functional as F
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
DEVICE = "cpu"
@pytest.fixture
def sampler():
return RejectionSampler()
def create_logits_tensor(token_ids: list[int],
def create_logits_tensor(token_ids: list[list[int]],
vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
logits = torch.full((len(token_ids), vocab_size), -100.0).cuda()
for i, token_id in enumerate(token_ids):
logits[i, token_id] = 100.0
num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
start_loc = 0
for tokens in token_ids:
for j, token_id in enumerate(tokens):
logits[start_loc + j, token_id] = 100.0
start_loc += len(tokens)
return logits
def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
batch_size = len(spec_tokens)
def create_sampling_metadata(
all_greedy: bool,
generators: Optional[dict[int, Any]] = None) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
is used.
"""
generators = generators or {}
return SamplingMetadata(
temperature=torch.tensor([]),
all_greedy=True,
all_random=False,
all_greedy=all_greedy,
all_random=not all_greedy,
top_p=None,
top_k=None,
min_p=torch.empty(batch_size, ),
generators={},
min_p=torch.empty(1, ),
generators=generators,
max_num_logprobs=0,
no_penalties=False,
prompt_token_ids=None,
@ -40,129 +54,310 @@ def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
repetition_penalties=torch.tensor([]),
output_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
logit_bias=[None],
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
########################### Tests for Greedy Sampling ###################
def test_perfect_match(sampler):
"""Test when output tokens perfectly match speculated tokens"""
spec_tokens = [[1, 2, 3]]
output_tokens = [1, 2, 3, 4] # 4 is the bonus token
output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token
metadata = create_sampling_metadata(spec_tokens)
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
output = sampler(spec_tokens, logits, metadata)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert torch.equal(output, expected)
def test_early_mismatch(sampler):
"""Test when there's an early mismatch in tokens"""
spec_tokens = [[1, 2, 3]]
output_tokens = [1, 5, 3, 4] # Mismatch at position 1
output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1
metadata = create_sampling_metadata(spec_tokens)
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
output = sampler(spec_tokens, logits, metadata)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert torch.equal(output, expected)
def test_multiple_sequences(sampler):
"""Test handling multiple sequences of speculated tokens"""
spec_tokens = [[1, 2], [3]]
output_tokens = [1, 2, 5, 3, 4] # Two sequences with bonus tokens 5 and 4
output_tokens = [[1, 2, 5], [3,
4]] # Two sequences with bonus tokens 5 and 4
metadata = create_sampling_metadata(spec_tokens)
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
output = sampler(spec_tokens, logits, metadata)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert torch.equal(output, expected)
def test_single_token_sequence(sampler):
"""Test handling sequences with single token"""
spec_tokens = [[1]]
output_tokens = [1, 2] # Single token with bonus token 2
output_tokens = [[1, 2]] # Single token with bonus token 2
metadata = create_sampling_metadata(spec_tokens)
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
output = sampler(spec_tokens, logits, metadata)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert torch.equal(output, expected)
def test_empty_sequence(sampler):
"""Test handling empty sequence of speculated tokens"""
spec_tokens: list[list[int]] = [[]]
output_tokens = [5] # Just the bonus token
output_tokens = [[5]] # Just the bonus token
metadata = create_sampling_metadata(spec_tokens)
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
output = sampler(spec_tokens, logits, metadata)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert torch.equal(output, expected)
def test_multiple_mismatches(sampler):
"""Test handling multiple sequences with mismatches"""
spec_tokens = [[1, 2, 3], [4, 5, 6]]
output_tokens = [1, 2, 7, 6, 4, 8, 6, 9] # Mismatches in both sequences
output_tokens = [[1, 2, 7, 6], [4, 8, 6,
9]] # Mismatches in both sequences
metadata = create_sampling_metadata(spec_tokens)
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
output = sampler(spec_tokens, logits, metadata)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert torch.equal(output, expected)
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2]], [1, 2, 3], [[1, 2, 3]]), # Perfect match with bonus
([[1]], [2, 3], [[2, INVALID_TOKEN_ID]]), # First mismatch
([[1, 2], [3, 4]], [1, 5, 6, 3, 4, 7], [[1, 5, INVALID_TOKEN_ID],
[3, 4, 7]]), # Mixed matches
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]), # First mismatch
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
[[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]), # Mixed matches
])
def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
"""Parametrized test for various matching scenarios"""
metadata = create_sampling_metadata(spec_tokens)
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
device=logits.device)
output = sampler(spec_tokens, logits, metadata)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
expected_tensor = torch.tensor(expected,
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected_tensor)
assert torch.equal(output, expected_tensor)
def test_logits_shape_handling(sampler):
"""Test handling of different logits tensor shapes"""
spec_tokens = [[1, 2]]
output_tokens = [1, 2, 3]
vocab_size = 1000
########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(sampler, k: int, vocab_size: int,
batch_size: int, frac_seeded: float,
n_rep: int):
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size * (k + 1),
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens, vocab_size)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert logits.shape[-1] == vocab_size
results = []
for _ in range(n_rep):
seeded_seqs = {
i: torch.Generator(device=DEVICE).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
}
sampling_metadata = create_sampling_metadata(all_greedy=False,
generators=seeded_seqs)
rep_result = sampler(draft_token_ids.tolist(), draft_probs,
bonus_token_ids, target_probs, sampling_metadata)
results.append(rep_result)
for i in range(batch_size):
if seeded_mask[i]:
for j in range(1, n_rep):
assert torch.equal(results[j][i], results[0][i])
def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.
    """
    torch.set_default_device(DEVICE)
    vocab_size = 10
    k = 2
    num_reference_probs = 100

    # Prepare draft, target, and reference probability distributions.
    draft_probs, target_probs = (F.softmax(
        torch.rand(vocab_size, dtype=torch.float32),
        dim=-1,
    ) for _ in range(2))
    reference_probs = F.softmax(
        torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
        dim=-1,
    )

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: list[float] = []
    distance_wrt_target: list[float] = []

    for num_samples in sample_sizes:
        # Sample using rejection sampling.
        rej_sample_probs = estimate_rejection_sampling_pdf(
            draft_probs, target_probs, k, vocab_size, num_samples)
        rej_sample_probs = rej_sample_probs.to(DEVICE)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()
        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        # Ratios computed here are for progress logging only; the final
        # values used by the assertion are recomputed after the loop.
        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # The distance to the target should shrink at least 20x more than the
    # distance to the random reference distributions does.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)
def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return the ratio of the first element of *elements* to its last."""
    first, last = elements[0], elements[-1]
    return first / last
def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_probs: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    Args:
        draft_probs: Draft probability distribution, shape [vocab_size];
            it is tiled to [num_samples, k, vocab_size] below.
        target_probs: Target probability distribution, shape [vocab_size];
            tiled to [num_samples * (k + 1), vocab_size] below.
        k: Number of speculative (draft) tokens per sample.
        vocab_size: Size of the vocabulary.
        num_samples: Number of samples to draw.

    Returns:
        Estimated probability distribution of the output tokens,
        shape [vocab_size].
    """
    sampler = RejectionSampler()
    # Repeat draft probs num_samples times.
    draft_probs = draft_probs.reshape(1, 1,
                                      vocab_size).repeat(num_samples, k, 1)
    # Repeat target probs num_samples * (k + 1) times.
    target_probs = target_probs.reshape(1, 1, vocab_size).repeat(
        num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size)
    # Randomly sample draft token ids from draft probs.
    draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                        num_samples=k,
                                        replacement=True).reshape(
                                            num_samples, k)
    # Bonus tokens not used but required.
    bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
                                  device=DEVICE).repeat(num_samples, 1)
    sampling_metadata = create_sampling_metadata(all_greedy=False)
    output_token_ids = sampler(draft_token_ids.tolist(), draft_probs,
                               bonus_token_ids, target_probs,
                               sampling_metadata)
    # Drop the bonus-token column; the pdf is estimated only over the
    # k speculative positions.
    output_token_ids = output_token_ids[:, :-1].flatten()
    # NOTE(review): rejected positions hold INVALID_TOKEN_ID (-1), which
    # presumably falls outside the [0, vocab_size) histogram range and is
    # therefore excluded from the estimate — confirm torch.histogram
    # out-of-range handling.
    hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                               device="cpu"),
                           bins=vocab_size,
                           range=(0, vocab_size),
                           density=True)
    return hist.hist

View File

@ -1,87 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
try:
import flashinfer.sampling as fs
is_flashinfer_available = True
except ImportError:
is_flashinfer_available = False
from vllm.v1.spec_decode.utils import random_sample
logger = init_logger(__name__)
INVALID_TOKEN_ID = -1
class RejectionSampler(nn.Module):
"""
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
probabilities.
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
end of the sequence. The bonus token is only sampled from the target
probabilities. We pass in the bonus tokens instead of sampling them
in the rejection sampler to allow for more flexibility in the
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
strategies.
output tokens:
Tokens are finally generated with the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
    def __init__(self):
        """Select the sampling backend once at construction time.

        Every branch currently assigns the PyTorch-native implementation:
        the FlashInfer path is disabled because of a bug (see FIXME below),
        and the branching is retained for logging and future re-enablement.
        """
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # FIXME(woosuk): Currently, we have errors when using
                    # FlashInfer for rejection sampling. As a workaround, we
                    # disable FlashInfer for rejection sampling by default.
                    logger.info("Currently, FlashInfer rejection sampler is "
                                "disabled because of a bug. Falling back to "
                                "the PyTorch-native implementation of "
                                "rejection sampling.")
                    self.forward_method = self.forward_native
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    # logger.info("Using FlashInfer for rejection sampling.")
                    # self.forward_method = self.flashinfer_sample
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "rejection sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward_method = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of rejection sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward_method = self.forward_native
        else:
            self.forward_method = self.forward_native
def forward(self, draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
return self.forward_method(draft_token_ids, target_probs,
sampling_metadata)
def flashinfer_sample(
def forward(
self,
draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
draft_probs: Optional[torch.Tensor],
bonus_token_ids_tensor: torch.Tensor, # [batch_size, 1]
target_probs: torch.Tensor, # [num_total_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
) -> torch.Tensor:
'''
Args:
draft_token_ids (List[List[int]]):
A 2D list of token IDs for each request in the batch.
Each request might have different number of draft tokens.
It may also contain empty lists for requests that have
no draft tokens.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[batch_size, max_spec_len, vocab_size]. Can be None if
probabilities are not provided, which is the case for
ngram spec decode.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
target_probs (torch.Tensor):
Target model probability distribution.
Shape is [num_total_tokens, vocab_size]. num_total_tokens
is the total number of tokens from all requests. Here,
probabilities from different requests are flattened into
a single tensor because this is the shape of the output
logits.
sampling_metadata (SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
# NOTE: The following input preparation can be moved
# to the model runner in a persistent manner for better
# performance.
sample_lens = [len(x) + 1 for x in draft_token_ids]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids = [
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
@ -90,90 +92,171 @@ class RejectionSampler(nn.Module):
batch_first=True,
padding_value=INVALID_TOKEN_ID)
if sampling_metadata.all_greedy:
target_token_ids = target_probs.argmax(dim=-1).view(-1)
target_token_ids = target_token_ids.split(sample_lens)
target_token_ids = pad_sequence(target_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
# Create one-hot tensor for draft token ids.
# This is used for ngram where we don't have draft_probs.
if draft_probs is None and not sampling_metadata.all_greedy:
vocab_size = target_probs.size(-1)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids_tensor = draft_token_ids_tensor.to(
target_probs.device)
draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
vocab_size,
target_probs.device)
target_probs = _create_greedy_token_probs(target_token_ids,
vocab_size,
target_probs.device)
uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
draft_token_ids_tensor.size(1) + 1,
device=target_probs.device)
else:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
sample_lens = [len(x) + 1 for x in draft_token_ids]
target_probs = _convert_2d_probs(target_probs, sample_lens)
sampled_token_ids, _, _ = fs.chain_speculative_sampling(
draft_probs,
draft_token_ids_tensor,
uniform_samples,
target_probs,
)
return SamplerOutput(sampled_token_ids=sampled_token_ids,
logprobs_tensors=None)
return self.forward_native(draft_token_ids_tensor, draft_probs,
bonus_token_ids_tensor, target_probs,
sampling_metadata)
# TODO: The following method can be optimized for better performance.
def forward_native(
self,
draft_token_ids: list[list[int]],
draft_token_ids_tensor: torch.Tensor,
# [batch_size, max_spec_len, vocab_size]
draft_probs: Optional[torch.Tensor],
bonus_token_ids_tensor: torch.Tensor,
# [batch_size, max_spec_len + 1, vocab_size]
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
sample_lens = [len(x) + 1 for x in draft_token_ids]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids = [
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
]
draft_token_ids_tensor = pad_sequence(draft_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
) -> torch.Tensor:
# Add 1 to include the 'bonus' token.
if sampling_metadata.all_greedy:
output_token_ids = target_probs.argmax(dim=-1).view(-1)
output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask = (
output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
dim=1)
else:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
# Identify valid positions (non-padding).
valid_mask = output_token_ids != INVALID_TOKEN_ID
# Generate mask with bonus token.
generate_mask = torch.cat([
accept_mask,
torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
],
dim=1).to(torch.bool) & valid_mask
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
# Figure out which rows actually contain at least one zero.
rows_with_zero = zeros_mask.any(dim=1)
# Use indexing to set the first zero in each of those rows to 1.
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1
target_token_ids_tensor = target_probs.argmax(dim=-1)
accept_mask = (target_token_ids_tensor[:, :-1] ==
draft_token_ids_tensor).cumprod(dim=1)
output_token_ids[~generate_mask] = INVALID_TOKEN_ID
return SamplerOutput(sampled_token_ids=output_token_ids,
logprobs_tensors=None)
# Identify valid positions (non-padding).
valid_mask = target_token_ids_tensor != INVALID_TOKEN_ID
# Generate mask with bonus token.
generate_mask = torch.cat([
accept_mask,
torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
],
dim=1).to(torch.bool) & valid_mask
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
# Figure out which rows actually contain at least one zero.
rows_with_zero = zeros_mask.any(dim=1)
# Use indexing to set the first zero in each of those rows to 1.
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1
output_token_ids = target_token_ids_tensor
output_token_ids[~generate_mask] = INVALID_TOKEN_ID
else:
# Reference: https://arxiv.org/pdf/2211.17192
# 1. Extract the probabilities of the draft tokens.
# [batch_size, max_spec_len]
batch_size = draft_token_ids_tensor.size(0)
max_spec_len = draft_token_ids_tensor.size(1)
invalid_idx = draft_token_ids_tensor == INVALID_TOKEN_ID
draft_token_ids_tensor[invalid_idx] = 0
assert draft_probs is not None
draft_token_probs = draft_probs.gather(
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
target_token_probs = target_probs.gather(
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
# Force the probabilities of invalid tokens to inf
# so that they are not accepted.
draft_token_probs[invalid_idx] = float('inf')
# 2. Generate uniform samples.
# [batch_size, max_spec_len + 1]
uniform_samples = _create_uniform_samples(
sampling_metadata.generators, batch_size, max_spec_len,
target_probs.device)
# 3. Accept or reject the samples.
# [batch_size, max_spec_len]
# If the draft token probabilities are 0, set them to the smallest
# positive normal value representable by float32.
safe_draft_probs = torch.where(draft_token_probs > 0,
draft_token_probs,
torch.finfo(torch.float32).tiny)
accepted = uniform_samples <= target_token_probs / safe_draft_probs
accept_mask = accepted.cumprod(dim=1)
# Set the token ids to the draft token ids if accepted, otherwise
# set them to INVALID_TOKEN_ID.
accepted_token_ids = (draft_token_ids_tensor * accept_mask +
INVALID_TOKEN_ID * (1 - accept_mask))
# 4. Adjust the distribution for the recovered tokens.
# Clamp the bonus probabilities to the smallest positive normal
# value representable by float32.
bonus_prob = torch.clamp(target_probs[:, :-1, :] - draft_probs,
min=torch.finfo(torch.float32).tiny)
normalized_bonus_prob = bonus_prob / bonus_prob.sum(dim=-1,
keepdim=True)
# 5. Sample recovered token ids.
recovered_token_ids = random_sample(
normalized_bonus_prob,
sampling_metadata.generators).reshape(batch_size, max_spec_len)
# 6. Get the final output token ids.
# output_token_ids = accepted_token_ids +
# recovered_token_ids +
# bonus_token_id
recovered_bonus_token_ids = torch.cat(
[recovered_token_ids, bonus_token_ids_tensor], dim=1)
# Generate mask with bonus tokens.
generate_mask = torch.cat([
accept_mask,
torch.zeros(batch_size, 1, device=accept_mask.device)
],
dim=1).to(torch.bool)
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
output_token_ids = torch.cat([
accepted_token_ids,
torch.full((batch_size, 1),
fill_value=INVALID_TOKEN_ID,
device=accept_mask.device)
],
dim=1)
output_token_ids[torch.arange(batch_size),
first_zero_idx] = recovered_bonus_token_ids[
torch.arange(batch_size), first_zero_idx]
return output_token_ids
    def compute_probs(self, logits: torch.Tensor,
                      sampling_metadata: SamplingMetadata,
                      sample_lens: list[int]) -> torch.Tensor:
        """
        Compute probability distribution from logits based on sampling metadata.

        This function applies temperature scaling to the logits and converts
        them to probabilities using softmax. Note that division by
        temperature is not performed inplace to preserve the original logits
        tensor, which will be used by the original sampler to get bonus tokens.

        Args:
            logits: Input logits tensor to be converted to probabilities;
                rows are flattened across requests (one row per token).
            sampling_metadata: Metadata containing sampling parameters such
                as temperature and whether greedy sampling is used.
            sample_lens: List of sample lengths used for repeating
                temperature values (one temperature per request, expanded
                to one per token).

        Returns:
            torch.Tensor: Probability distribution (softmax of scaled logits)
                if non-greedy sampling is used, otherwise returns the
                original logits unchanged.
        """
        # Greedy decoding takes argmax over raw logits, so no scaling needed.
        if sampling_metadata.all_greedy:
            return logits
        assert sampling_metadata.temperature is not None
        # Expand per-request temperature to per-token.
        # We should optimize the following code as
        # it will cause CPU -> GPU synchronization.
        temperature = torch.repeat_interleave(
            sampling_metadata.temperature,
            torch.tensor(sample_lens,
                         device=sampling_metadata.temperature.device))
        temperature = temperature.unsqueeze(dim=1)
        # Not in-place: the caller reuses the original logits tensor.
        logits = logits / temperature
        return logits.softmax(dim=-1, dtype=torch.float32)
def _create_greedy_token_probs(
@ -199,3 +282,66 @@ def _create_greedy_token_probs(
src=valid_mask.unsqueeze(-1).float())
return token_probs
def _convert_2d_probs(
probs: torch.Tensor, # [num_total_tokens, vocab_size]
sample_lens: list[int]) -> torch.Tensor:
"""
Converts a 2D tensor of probabilities to a 3D tensor with padding.
[num_total_tokens, vocab_size] ->
[batch_size, max_spec_len + 1, vocab_size]
"""
cumulative_lens = torch.cumsum(torch.tensor(sample_lens,
device=probs.device),
dim=0)
split_indices = cumulative_lens[:-1].tolist() # Exclude last index
# Split into chunks without loops
chunks = torch.tensor_split(probs, split_indices, dim=0)
# Pad all sequences to maximum length
padded_probs = pad_sequence(chunks, batch_first=True, padding_value=0.0)
return padded_probs
def _create_uniform_samples(seeded_seqs: dict[int, torch.Generator],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k)` containing uniform
random values in the range [0, 1).
"""
uniform_rand = torch.rand(batch_size,
k,
dtype=torch.float32,
device=device)
# Apply seeded generators only where needed
if seeded_seqs:
for idx, generator in seeded_seqs.items():
uniform_rand[idx].uniform_(0, 1, generator=generator)
return uniform_rand

View File

@ -119,14 +119,6 @@ class Sampler(nn.Module):
)
return sampled
def compute_probs(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
if sampling_metadata.all_greedy:
return logits
# Apply temperature. This is an in-place op changing logits.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
return logits.softmax(dim=-1, dtype=torch.float32)
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)

View File

@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa
from vllm.v1.worker.gpu_input_batch import InputBatch
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
    """Return True if speculative decoding can be applied to this request.

    Spec decode currently works only with plain greedy/random sampling, so
    a request is excluded when it uses top-k/top-p, min-p, any of the
    penalties, or asks for logprobs.
    """
    unsupported_feature_sets = (
        input_batch.top_k_reqs,                 # top-k sampling
        input_batch.top_p_reqs,                 # top-p sampling
        input_batch.min_p_reqs,                 # min-p sampling
        input_batch.frequency_penalties_reqs,   # frequency penalty
        input_batch.presence_penalties_reqs,    # presence penalty
        input_batch.repetition_penalties_reqs,  # repetition penalty
        input_batch.num_logprobs,               # logprobs requested
    )
    return all(req_id not in reqs for reqs in unsupported_feature_sets)

View File

@ -37,6 +37,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@ -1020,15 +1021,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
)
else:
target_probs = self.model.sampler.compute_probs(
logits, sampling_metadata)
draft_token_ids = [
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
for req_id in self.input_batch.req_ids
]
sampler_output = self.rejection_sampler(draft_token_ids,
target_probs,
sampling_metadata)
sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
recover_logits_idx = np.cumsum(sample_lens) - 1
target_probs = self.rejection_sampler.compute_probs(
logits, sampling_metadata, sample_lens)
sampler_output = self.model.sample(
logits=logits[recover_logits_idx, :],
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
output_token_ids = self.rejection_sampler(
draft_token_ids,
None, # draft_probs
bonus_token_ids,
target_probs,
sampling_metadata)
sampler_output.sampled_token_ids = output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
@ -1075,7 +1087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = None
else:
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids)
valid_sampled_token_ids, sampling_metadata)
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@ -1089,6 +1101,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
@ -1099,6 +1112,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append([])
continue
# Skip requests that require top-p, top-k, etc.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
draft_token_ids.append([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids