[V1][Spec Decode] Change Spec Decode Rejection Sampling API (#13729)

commit 5629f26df7 (parent 9ba28043b5)
Author: Lily Liu, 2025-02-25 18:14:48 -08:00 (committed by GitHub)
8 changed files with 102 additions and 109 deletions
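The core of the change: the rejection sampler no longer pulls draft tokens out of `SamplingMetadata.spec_token_ids`. Callers now pass the per-request draft token IDs and the target-model probabilities explicitly, i.e. `rejection_sampler(draft_token_ids, target_probs, sampling_metadata)` instead of `rejection_sampler(logits, sampling_metadata)`. Below is a minimal standalone sketch of the greedy accept/reject semantics behind that API (an illustrative re-implementation, not the vLLM code; `INVALID_TOKEN_ID = -1` is assumed here as the padding value):

```python
from typing import List

import torch

INVALID_TOKEN_ID = -1  # padding placeholder, assumed for this sketch


def greedy_rejection_sample(draft_token_ids: List[List[int]],
                            target_probs: torch.Tensor) -> torch.Tensor:
    """Accept draft tokens while they match the target argmax; at the first
    mismatch emit the target token and stop. The extra trailing position is
    the 'bonus' token when the whole draft is accepted."""
    # One target distribution per draft token, plus one bonus position.
    sample_lens = [len(d) + 1 for d in draft_token_ids]
    target_ids = target_probs.argmax(dim=-1).split(sample_lens)
    out = torch.full((len(draft_token_ids), max(sample_lens)),
                     INVALID_TOKEN_ID, dtype=torch.long)
    for i, (draft, target) in enumerate(zip(draft_token_ids, target_ids)):
        n = 0
        while n < len(draft) and draft[n] == target[n].item():
            n += 1
        # Accepted prefix plus the correcting (or bonus) target token.
        out[i, :n + 1] = target[:n + 1]
    return out


# Example: draft [1, 2, 3] vs. target argmax [1, 5, 3, 4] -> [1, 5, -1, -1]
```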

View File

@@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
temperature=torch.tensor([]),
all_greedy=True,
all_random=False,
spec_token_ids=spec_tokens,
top_p=None,
top_k=None,
min_p=torch.empty(batch_size, ),
@@ -55,7 +54,7 @@ def test_perfect_match(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int,
device=logits.device)
@@ -70,7 +69,7 @@ def test_early_mismatch(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
@@ -85,7 +84,7 @@ def test_multiple_sequences(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
@@ -100,7 +99,7 @@ def test_single_token_sequence(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
@@ -113,7 +112,7 @@ def test_empty_sequence(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
@@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
@@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected_tensor = torch.tensor(expected,
dtype=torch.int,
device=logits.device)
@@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens, vocab_size)
output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert logits.shape[-1] == vocab_size

View File

@@ -105,7 +105,6 @@ def _create_default_sampling_metadata(
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
spec_token_ids=None,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),

View File

@@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
spec_token_ids=None,
min_tokens=min_tokens,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)

View File

@@ -13,9 +13,6 @@ class SamplingMetadata:
all_greedy: bool
all_random: bool
# None when there are no speculated tokens.
spec_token_ids: Optional[List[List[int]]]
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor]
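For reference, `SamplingMetadata` after this hunk no longer carries speculative tokens at all. An abridged sketch of the resulting dataclass (only the fields visible in this diff are shown; the rest are elided):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingMetadata:
    temperature: torch.Tensor
    all_greedy: bool
    all_random: bool
    # spec_token_ids is gone; draft tokens are now an explicit argument
    # to the rejection sampler.
    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]
    min_p: Optional[torch.Tensor]
    # ...remaining fields (penalties, logprobs settings, etc.) unchanged...
```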

View File

@@ -1,4 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
@@ -52,62 +54,62 @@ class RejectionSampler(nn.Module):
else:
self.forward_method = self.forward_native
def forward(self, logits: torch.Tensor,
def forward(self, draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
return self.forward_method(logits, sampling_metadata)
return self.forward_method(draft_token_ids, target_probs,
sampling_metadata)
def flashinfer_sample(
self,
logits: torch.Tensor,
draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
# NOTE: The following input preparation can be moved
# to the model runner in a persistent manner for better
# performance.
assert sampling_metadata.spec_token_ids is not None
spec_token_ids = sampling_metadata.spec_token_ids
max_spec_len = max(len(s) for s in spec_token_ids)
batch_size = len(spec_token_ids)
draft_token_ids = torch.full((batch_size, max_spec_len),
INVALID_TOKEN_ID,
device="cpu",
dtype=torch.long)
sample_lens = [len(x) + 1 for x in draft_token_ids]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids = [
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
]
draft_token_ids_tensor = pad_sequence(draft_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
target_token_ids = torch.full((batch_size, max_spec_len + 1),
fill_value=INVALID_TOKEN_ID,
device=logits.device,
dtype=torch.long)
if sampling_metadata.all_greedy:
target_token_ids = target_probs.argmax(dim=-1).view(-1)
target_token_ids = target_token_ids.split(sample_lens)
target_token_ids = pad_sequence(target_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# TODO: Vectorize the following loop for better performance.
start_loc = 0
for i in range(batch_size):
num_spec_tokens = len(spec_token_ids[i])
draft_token_ids[i, :num_spec_tokens] = torch.tensor(
spec_token_ids[i], device="cpu", dtype=torch.long)
end_loc = start_loc + num_spec_tokens + 1
# Assume greedy sampling.
target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
logits[start_loc:end_loc], dim=-1)
start_loc = end_loc
vocab_size = logits.size(-1)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids = draft_token_ids.to(logits.device)
draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
logits.device)
target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
logits.device)
uniform_samples = torch.zeros(batch_size,
max_spec_len + 1,
device=logits.device)
vocab_size = target_probs.size(-1)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids_tensor = draft_token_ids_tensor.to(
target_probs.device)
draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
vocab_size,
target_probs.device)
target_probs = _create_greedy_token_probs(target_token_ids,
vocab_size,
target_probs.device)
uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
draft_token_ids_tensor.size(1) + 1,
device=target_probs.device)
else:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
sampled_token_ids, _, _ = fs.chain_speculative_sampling(
draft_probs,
draft_token_ids,
draft_token_ids_tensor,
uniform_samples,
target_probs,
)
@@ -117,35 +119,35 @@ class RejectionSampler(nn.Module):
# TODO: The following method can be optimized for better performance.
def forward_native(
self,
logits: torch.Tensor,
draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
assert sampling_metadata.spec_token_ids is not None
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
# Add 1 to include the 'bonus' token.
sample_lens = [x + 1 for x in spec_lens]
output_token_ids = logits.argmax(dim=-1).view(-1)
output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Convert spec token IDs to a tensor, split by sample_lens, then pad.
spec_token_ids = [
torch.tensor(x,
dtype=output_token_ids.dtype,
device=output_token_ids.device)
for x in sampling_metadata.spec_token_ids
sample_lens = [len(x) + 1 for x in draft_token_ids]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids = [
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
]
spec_token_ids = pad_sequence(spec_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
dim=1)
draft_token_ids_tensor = pad_sequence(draft_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
# Add 1 to include the 'bonus' token.
if sampling_metadata.all_greedy:
output_token_ids = target_probs.argmax(dim=-1).view(-1)
output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask = (
output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
dim=1)
else:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
# Identify valid positions (non-padding).
valid_mask = output_token_ids != INVALID_TOKEN_ID
# Generate mask with bonus token.
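The cumprod trick noted in `forward_native` above is the key vectorized step: comparing the padded draft tokens against the target argmax and taking a running cumulative product yields a mask that is 1 over the accepted prefix and flips to 0 at (and after) the first mismatch. A tiny self-contained demonstration with made-up values:

```python
import torch

INVALID_TOKEN_ID = -1  # padding placeholder, as in the sampler

# Target argmax per position (last column is the bonus position) and the
# padded draft tokens for two requests.
output_token_ids = torch.tensor([[1, 2, 7, 9],
                                 [4, 8, 6, 5]])
draft_token_ids = torch.tensor([[1, 2, 3],
                                [4, 1, INVALID_TOKEN_ID]])

# 1 while draft == target, 0 from the first mismatch onward
# (cast to int before the cumulative product).
matches = (output_token_ids[:, :-1] == draft_token_ids).int()
accept_mask = matches.cumprod(dim=1)
print(accept_mask)
# tensor([[1, 1, 0],
#         [1, 0, 0]])
```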

View File

@@ -9,7 +9,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.rejection_sampler import RejectionSampler
_SAMPLING_EPS = 1e-5
@@ -19,22 +18,12 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
self.rejection_sampler = RejectionSampler()
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
if sampling_metadata.spec_token_ids:
if sampling_metadata.max_num_logprobs:
raise NotImplementedError(
"Rejection sampling does not support logprobs.")
return self.rejection_sampler(
logits,
sampling_metadata,
)
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
@@ -127,6 +116,14 @@ class Sampler(nn.Module):
)
return sampled
def compute_probs(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
if sampling_metadata.all_greedy:
return logits
# Apply temperature. This is an in-place op changing logits.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
return logits.softmax(dim=-1, dtype=torch.float32)
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
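`compute_probs` is the new hook the model runner uses to obtain target probabilities for rejection sampling: under all-greedy sampling the raw logits are returned unchanged (argmax is unaffected by softmax), otherwise logits are temperature-scaled and normalized. A simplified standalone version, with the per-row temperature passed directly instead of through `SamplingMetadata` (an assumption made for brevity):

```python
import torch


def compute_probs(logits: torch.Tensor, temperature: torch.Tensor,
                  all_greedy: bool) -> torch.Tensor:
    """Return a tensor the greedy rejection sampler can argmax over."""
    if all_greedy:
        # Greedy rejection sampling only compares argmax positions, so the
        # unnormalized logits are sufficient.
        return logits
    # Scale each row by its temperature (in-place in the real sampler),
    # then normalize into probabilities.
    logits = logits / temperature.unsqueeze(dim=1)
    return logits.softmax(dim=-1, dtype=torch.float32)
```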

View File

@@ -490,23 +490,12 @@ class InputBatch:
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(List[List[int]], self.req_output_token_ids),
spec_token_ids=None,
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask,
)
def get_sampling_metadata(
self,
req_id_to_spec_token_ids: Dict[str, List[int]],
) -> SamplingMetadata:
# Set the new spec token ids in the cached sampling metadata.
self.sampling_metadata.spec_token_ids = [
req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
] if req_id_to_spec_token_ids else None
return self.sampling_metadata
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(

View File

@@ -32,7 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -122,7 +122,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
self.rejection_sampler = RejectionSampler()
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1."
@@ -951,12 +951,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.get_sampling_metadata(
scheduler_output.scheduled_spec_decode_tokens)
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
sampling_metadata = self.input_batch.sampling_metadata
if not self.use_spec_decode:
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
target_probs = self.model.sampler.compute_probs(
logits, sampling_metadata)
scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
)
draft_token_ids = [
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
for req_id in scheduled_request_ids
]
sampler_output = self.rejection_sampler(draft_token_ids,
target_probs,
sampling_metadata)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
@@ -1293,7 +1305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
temperature=dummy_tensors(0.5),
all_greedy=False,
all_random=False,
spec_token_ids=None,
top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1),
min_p=None,
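Taken together, the runner now branches on `use_spec_decode`: the regular path still calls `model.sample`, while the speculative path derives target probabilities with `compute_probs` and hands them, together with the per-request draft tokens from the scheduler output, to the rejection sampler. A condensed sketch of that flow (the sampler, rejection sampler, and scheduler arguments are stand-ins, not the real vLLM classes):

```python
from typing import Dict, List

import torch


def sample_next_tokens(logits: torch.Tensor,
                       sampling_metadata,
                       sampler,
                       rejection_sampler,
                       use_spec_decode: bool,
                       num_scheduled_tokens: Dict[str, int],
                       scheduled_spec_decode_tokens: Dict[str, List[int]]):
    if not use_spec_decode:
        # Non-speculative path: unchanged.
        return sampler(logits=logits, sampling_metadata=sampling_metadata)
    # Target probabilities for the rejection sampler (raw logits when greedy).
    target_probs = sampler.compute_probs(logits, sampling_metadata)
    # Draft tokens in scheduled-request order; requests without speculative
    # tokens contribute an empty list.
    draft_token_ids = [
        scheduled_spec_decode_tokens.get(req_id, [])
        for req_id in num_scheduled_tokens
    ]
    return rejection_sampler(draft_token_ids, target_probs, sampling_metadata)
```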