[V1][Spec Decode] Enable spec decode for top-p & top-k sampling (#15063)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent f533b5837f
commit ebcebeeb6b
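
What the change enables, in user-facing terms: with the V1 engine, requests that use top-p or top-k sampling can now keep speculative decoding active instead of falling back to ordinary decoding. A minimal usage sketch follows (not part of the diff). SamplingParams and its temperature/top_p/top_k fields are the real vLLM API; the speculative_config argument is an assumption about how a draft method is configured and its exact shape varies between releases, so treat the constructor call as illustrative only.

from vllm import LLM, SamplingParams

# Hypothetical engine setup: the speculative-decoding arguments below are an
# assumed configuration shape, not a guaranteed stable API.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={"method": "ngram", "num_speculative_tokens": 3},
)

# After this commit, top-p / top-k requests no longer disable spec decode.
params = SamplingParams(temperature=0.8, top_p=0.95, top_k=20, max_tokens=64)
outputs = llm.generate(["Speculative decoding with top-p now works:"], params)
print(outputs[0].outputs[0].text)
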
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
 def create_sampling_metadata(
     all_greedy: bool,
     temperature: Optional[torch.Tensor] = None,
+    top_k: Optional[torch.Tensor] = None,
+    top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
     """Create a v1 sampling metadata object with all_greedy set
@@ -52,8 +54,8 @@ def create_sampling_metadata(
         temperature=temperature,
         all_greedy=all_greedy,
         all_random=not all_greedy,
-        top_p=None,
-        top_k=None,
+        top_p=top_p,
+        top_k=top_k,
         min_p=torch.empty(1, ),
         generators=generators,
         max_num_logprobs=0,
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
                              density=True)

     return hist.hist
+
+
+def _test_masked_logits(
+    rejection_sampler,
+    batch_size: int,
+    num_draft_tokens: int,
+    vocab_size: int,
+    target_logits: torch.Tensor,
+    unmasked_indices: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+):
+    # Set up test parameters.
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random draft probabilities.
+    draft_probs = torch.rand((num_tokens, vocab_size),
+                             dtype=torch.float32,
+                             device=DEVICE)
+    draft_probs = F.softmax(draft_probs, dim=-1)
+
+    # Randomly sample draft token ids from the draft probs.
+    draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
+    draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
+    draft_token_ids = draft_token_ids.tolist()
+
+    # Bonus tokens are not used here but are required by the sampler.
+    bonus_token_ids = torch.zeros((batch_size, 1),
+                                  dtype=torch.int64,
+                                  device=DEVICE)
+
+    # Create spec decode metadata.
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        draft_token_ids,
+        device=DEVICE,
+    )
+
+    # Run rejection sampling.
+    output_token_ids = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=draft_probs,
+        target_logits=target_logits,
+        bonus_token_ids=bonus_token_ids,
+        sampling_metadata=sampling_metadata,
+    )
+
+    # Remove bonus tokens and reshape.
+    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
+
+    # Check that all sampled tokens are within the unmasked indices.
+    for i in range(num_tokens):
+        token_id = output_token_ids[i]
+        if token_id == PLACEHOLDER_TOKEN_ID:
+            continue
+        assert token_id in unmasked_indices[i]
+
+
+@pytest.mark.parametrize("top_k", [1, 5, 99])
+def test_top_k(rejection_sampler, top_k):
+    """Test rejection sampling with top-k sampling."""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Randomly create top-k indices.
+    top_k_indices = [
+        torch.randperm(vocab_size, device=DEVICE)[:top_k]
+        for _ in range(num_tokens)
+    ]
+    top_k_indices = torch.stack(top_k_indices)
+
+    # Create logits with the uniform distribution.
+    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
+    # Increment the logits for the top-k indices a little bit more than the
+    # others. If the masking is effective, the non-top-k indices will never
+    # be sampled despite the small difference in logits.
+    for i in range(num_tokens):
+        target_logits[i, top_k_indices[i]] += 0.1
+
+    # Create sampling metadata.
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_k=torch.tensor([top_k] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.int64),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_k_indices,
+        sampling_metadata=sampling_metadata,
+    )
+
+
+@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
+def test_top_p(rejection_sampler, top_p):
+    """Test rejection sampling with top-p sampling."""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random target logits.
+    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    rescaled_logits = target_logits / temperature
+
+    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - top_p
+    # Keep at least one token unmasked.
+    top_p_mask[:, -1] = False
+
+    # Get the top-p indices.
+    top_p_indices = []
+    for i in range(num_tokens):
+        top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())
+
+    # Create sampling metadata.
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_p=torch.tensor([top_p] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.float32),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_p_indices,
+        sampling_metadata=sampling_metadata,
+    )
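
Aside (not part of the diff): the invariant that _test_masked_logits asserts, namely that masked-out token ids never appear in the output, can be demonstrated with plain PyTorch and no vLLM imports. The standalone sketch below is only an illustration of that property; all names are local to the example. It masks everything outside a randomly chosen top-k set to -inf and checks that multinomial sampling never selects a masked index, even though the unmasked logits are only 0.1 higher.

import torch

torch.manual_seed(0)
vocab_size, top_k, num_samples = 100, 5, 10_000

# Nearly uniform logits with a tiny bump on the allowed (top-k) indices.
logits = torch.zeros(vocab_size)
allowed = torch.randperm(vocab_size)[:top_k]
logits[allowed] += 0.1

# Mask everything outside the allowed set, as top-k filtering does.
masked = torch.full_like(logits, float("-inf"))
masked[allowed] = logits[allowed]

probs = torch.softmax(masked, dim=-1)
samples = torch.multinomial(probs, num_samples, replacement=True)

# Every sample must come from the allowed set, no matter how small the logit
# gap is, because masked entries carry exactly zero probability mass.
assert set(samples.tolist()) <= set(allowed.tolist())
print("all", num_samples, "samples fell inside the top-k set")
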
@@ -8,6 +8,7 @@ import triton.language as tl

 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.ops.utils import compiled_softmax
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

@@ -245,23 +246,79 @@ def compute_probs(
         return logits

     num_tokens = logits.shape[0]
-    batch_size = cu_num_draft_tokens.shape[0]
-    expanded_temperature = torch.empty(
-        (num_tokens, 1),
-        dtype=torch.float32,
-        device=logits.device,
-    )
-    expand_kernel[(batch_size, )](
-        expanded_temperature,
+    temperature = expand_batch_to_tokens(
         sampling_metadata.temperature,
         cu_num_draft_tokens,
-        GREEDY_TEMPERATURE,  # replace_from
-        1,  # replace_to
-        MAX_NUM_TOKENS=MAX_SPEC_LEN,
+        num_tokens,
+        replace_from=GREEDY_TEMPERATURE,
+        replace_to=1,
+    )
+    # TODO(woosuk): Consider using in-place op to reduce memory usage.
+    logits = logits / temperature.unsqueeze(-1)
+
+    # Get expanded top_k and top_p tensors.
+    top_k = None
+    if sampling_metadata.top_k is not None:
+        top_k = expand_batch_to_tokens(
+            sampling_metadata.top_k,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+    top_p = None
+    if sampling_metadata.top_p is not None:
+        top_p = expand_batch_to_tokens(
+            sampling_metadata.top_p,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+
+    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+    # which is slow for large vocab sizes. This may cause performance issues.
+    logits = apply_top_k_top_p(logits, top_k, top_p)
+
+    output_prob = compiled_softmax(logits)
+    return output_prob
+
+
+def expand_batch_to_tokens(
+    x: torch.Tensor,  # [batch_size]
+    cu_num_tokens: torch.Tensor,  # [batch_size]
+    num_tokens: int,
+    replace_from: int = 0,
+    replace_to: int = 0,
+) -> torch.Tensor:
+    """Expand a [batch_size] tensor to a [num_tokens] tensor based on the
+    number of tokens per batch in cu_num_tokens.
+
+    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
+    num_tokens = 6 and expanded_x = [a, a, b, b, b, c].
+
+    Args:
+        x: [batch_size] tensor to expand.
+        cu_num_tokens: [batch_size] tensor containing the cumulative number of
+            tokens per batch. Each element represents the total number of
+            tokens up to and including that batch.
+        num_tokens: Total number of tokens.
+        replace_from: Value to be replaced if it is found in x.
+        replace_to: Value to replace with when replace_from is found.
+
+    Returns:
+        expanded_x: [num_tokens] tensor.
+    """
+    batch_size = x.shape[0]
+    assert cu_num_tokens.shape[0] == batch_size
+    expanded_x = x.new_empty(num_tokens)
+    expand_kernel[(batch_size, )](
+        expanded_x,
+        x,
+        cu_num_tokens,
+        replace_from,
+        replace_to,
+        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
         num_warps=1,
     )
-    output_prob = compiled_softmax(logits, expanded_temperature)
-    return output_prob
+    return expanded_x


 def generate_uniform_probs(
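
Aside (not part of the diff): the expansion that expand_batch_to_tokens performs with a Triton kernel can be reproduced, for illustration only, with torch.repeat_interleave. The sketch below is a reference implementation of the same semantics under the assumption that cu_num_tokens is an inclusive cumulative sum, as the docstring example implies; the function and variable names are local to the example, and the real kernel additionally avoids materializing per-request counts on the host.

import torch

def expand_batch_to_tokens_ref(x, cu_num_tokens, replace_from=0, replace_to=0):
    # Recover per-request token counts from the inclusive cumulative sum.
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    expanded = torch.repeat_interleave(x, counts)
    # Optional value substitution, mirroring how the greedy sentinel
    # temperature is rewritten to 1 before the logits are divided by it.
    return torch.where(expanded == replace_from,
                       expanded.new_tensor(replace_to), expanded)

x = torch.tensor([0.7, -1.0, 0.2])   # e.g. per-request temperatures
cu = torch.tensor([2, 5, 6])         # cumulative draft-token counts
print(expand_batch_to_tokens_ref(x, cu, replace_from=-1.0, replace_to=1.0))
# tensor([0.7000, 0.7000, 1.0000, 1.0000, 1.0000, 0.2000])
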
@@ -3,10 +3,7 @@ from vllm.v1.worker.gpu_input_batch import InputBatch


 def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
-    if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs:
-        # Spec decode doesn't support top_p/top_k sampling.
-        return False
-    elif req_id in input_batch.min_p_reqs:
+    if req_id in input_batch.min_p_reqs:
         # Spec decode doesn't support min_p sampling.
         return False
     elif (req_id in input_batch.frequency_penalties_reqs