From 99abb8b650c66664cdc84d815b7f306f33bd9881 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 14:31:54 -0700 Subject: [PATCH] [V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels (#14930) Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 217 ++++-- vllm/envs.py | 1 - vllm/v1/outputs.py | 2 +- vllm/v1/sample/ops/utils.py | 30 + vllm/v1/sample/rejection_sampler.py | 770 ++++++++++++++-------- vllm/v1/spec_decode/metadata.py | 61 ++ vllm/v1/spec_decode/utils.py | 1 - vllm/v1/worker/gpu_model_runner.py | 201 ++++-- 8 files changed, 875 insertions(+), 408 deletions(-) create mode 100644 vllm/v1/sample/ops/utils.py create mode 100644 vllm/v1/spec_decode/metadata.py diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 84139a40b544a..8c423e367ef56 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -6,20 +6,23 @@ import torch import torch.nn.functional as F from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler +from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, + RejectionSampler) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -DEVICE = "cpu" +DEVICE = "cuda" @pytest.fixture -def sampler(): +def rejection_sampler(): return RejectionSampler() -def create_logits_tensor(token_ids: list[list[int]], +def create_logits_tensor(output_token_ids: list[list[int]], vocab_size: int = 100) -> torch.Tensor: """Helper function to create logits tensor that will produce desired token ids on argmax""" + token_ids = [tokens[:-1] for tokens in output_token_ids] num_total_tokens = sum(len(tokens) for tokens in token_ids) logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) start_loc = 0 @@ -31,15 +34,22 @@ def create_logits_tensor(token_ids: list[list[int]], def create_sampling_metadata( - all_greedy: bool, - generators: Optional[dict[int, Any]] = None) -> SamplingMetadata: + all_greedy: bool, + temperature: Optional[torch.Tensor] = None, + generators: Optional[dict[int, Any]] = None, +) -> SamplingMetadata: """Create a v1 sampling metadata object with all_greedy set to the given value. Either all greedy or all random sampling is used. 
""" generators = generators or {} + if all_greedy: + temperature = None + else: + assert temperature is not None + return SamplingMetadata( - temperature=torch.tensor([]), + temperature=temperature, all_greedy=all_greedy, all_random=not all_greedy, top_p=None, @@ -61,7 +71,7 @@ def create_sampling_metadata( ########################### Tests for Greedy Sampling ################### -def test_perfect_match(sampler): +def test_perfect_match(rejection_sampler): """Test when output tokens perfectly match speculated tokens""" spec_tokens = [[1, 2, 3]] output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token @@ -70,15 +80,23 @@ def test_perfect_match(sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_early_mismatch(sampler): +def test_early_mismatch(rejection_sampler): """Test when there's an early mismatch in tokens""" spec_tokens = [[1, 2, 3]] output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1 @@ -87,15 +105,25 @@ def test_early_mismatch(sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) - expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], - dtype=torch.int, - device=logits.device) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor( + [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + dtype=torch.int, + device=logits.device, + ) assert torch.equal(output, expected) -def test_multiple_sequences(sampler): +def test_multiple_sequences(rejection_sampler): """Test handling multiple sequences of speculated tokens""" spec_tokens = [[1, 2], [3]] output_tokens = [[1, 2, 5], [3, @@ -105,15 +133,23 @@ def test_multiple_sequences(sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) - expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]], + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_single_token_sequence(sampler): +def test_single_token_sequence(rejection_sampler): """Test handling sequences with single token""" spec_tokens = [[1]] output_tokens = [[1, 2]] # Single token with bonus token 2 @@ -122,13 +158,21 @@ def test_single_token_sequence(sampler): logits = create_logits_tensor(output_tokens) 
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_empty_sequence(sampler): +def test_empty_sequence(rejection_sampler): """Test handling empty sequence of speculated tokens""" spec_tokens: list[list[int]] = [[]] output_tokens = [[5]] # Just the bonus token @@ -137,13 +181,21 @@ def test_empty_sequence(sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_multiple_mismatches(sampler): +def test_multiple_mismatches(rejection_sampler): """Test handling multiple sequences with mismatches""" spec_tokens = [[1, 2, 3], [4, 5, 6]] output_tokens = [[1, 2, 7, 6], [4, 8, 6, @@ -153,12 +205,22 @@ def test_multiple_mismatches(sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) - expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID], - [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], - dtype=torch.int, - device=logits.device) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor( + [[1, 2, 7, PLACEHOLDER_TOKEN_ID], + [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + dtype=torch.int, + device=logits.device, + ) assert torch.equal(output, expected) @@ -166,18 +228,27 @@ def test_multiple_mismatches(sampler): "spec_tokens,output_tokens,expected", [ ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus - ([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]), # First mismatch + ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], - [[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]), # Mixed matches + [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches ]) -def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): +def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, + expected): """Parametrized test for various matching scenarios""" metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = 
rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) @@ -190,21 +261,31 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): @pytest.mark.parametrize("batch_size", [1, 4, 8]) @pytest.mark.parametrize("frac_seeded", [0.0, 0.5]) @pytest.mark.parametrize("n_rep", [20]) -def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, - batch_size: int, frac_seeded: float, - n_rep: int): - draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size * (k + 1), - vocab_size, - dtype=torch.float32) +def test_deterministic_when_seeded( + rejection_sampler, + k: int, + vocab_size: int, + batch_size: int, + frac_seeded: float, + n_rep: int, +): + num_tokens = batch_size * k + draft_probs = torch.rand(num_tokens, + vocab_size, + dtype=torch.float32, + device=DEVICE) + draft_probs = F.softmax(draft_probs, dim=-1) + target_logits = torch.rand_like(draft_probs) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), - dtype=torch.int64) + dtype=torch.int64, + device=DEVICE) draft_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), - dtype=torch.int64) + dtype=torch.int64, + device=DEVICE) seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded @@ -215,10 +296,21 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, for i in range(batch_size) if seeded_mask[i] } + temperature = torch.ones(batch_size, + dtype=torch.float32, + device=DEVICE) sampling_metadata = create_sampling_metadata(all_greedy=False, + temperature=temperature, generators=seeded_seqs) - rep_result = sampler(draft_token_ids.tolist(), draft_probs, - bonus_token_ids, target_probs, sampling_metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids.tolist(), device=DEVICE) + rep_result = rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) results.append(rep_result) @@ -257,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution(): num_reference_probs = 100 # Prepare draft, target, and reference probability distributions - draft_probs, target_probs = (F.softmax( - torch.rand(vocab_size, dtype=torch.float32), - dim=-1, - ) for _ in range(2)) + draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), + dim=-1) + target_logits = torch.rand(vocab_size, dtype=torch.float32) + target_probs = F.softmax(target_logits, dim=-1) reference_probs = F.softmax( torch.rand(num_reference_probs, vocab_size, dtype=torch.float32), dim=-1, @@ -273,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution(): for num_samples in sample_sizes: # Sample using rejection sampling. rej_sample_probs = estimate_rejection_sampling_pdf( - draft_probs, target_probs, k, vocab_size, num_samples) + draft_probs, target_logits, k, vocab_size, num_samples) rej_sample_probs = rej_sample_probs.to(DEVICE) # Average distance from reference probs. 
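The convergence test above exercises the standard speculative-decoding accept/recover rule from https://arxiv.org/abs/2211.17192, which the Triton kernels later in this patch implement in a flattened, batched form. A minimal per-token sketch in plain PyTorch, assuming a single draft token and dense probability vectors (illustrative only, not the kernel path):

from typing import Optional

import torch


def accept_or_recover(
    draft_token_id: int,
    draft_probs: torch.Tensor,   # [vocab_size]
    target_probs: torch.Tensor,  # [vocab_size]
    generator: Optional[torch.Generator] = None,
) -> int:
    """Illustrative single-token accept/recover step (assumed helper)."""
    u = torch.rand((), generator=generator)
    p_draft = draft_probs[draft_token_id]
    p_target = target_probs[draft_token_id]
    # Accept the draft token with probability min(1, p_target / p_draft).
    # A zero draft probability is treated as a rejection, as in the kernel.
    if p_draft > 0 and u <= p_target / p_draft:
        return draft_token_id
    # Otherwise sample a "recovered" token from the residual distribution
    # max(p_target - p_draft, 0), renormalized over the vocabulary.
    residual = torch.clamp(target_probs - draft_probs, min=0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1, generator=generator).item())
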
@@ -313,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float: def estimate_rejection_sampling_pdf( draft_probs: torch.Tensor, - target_probs: torch.Tensor, + target_logits: torch.Tensor, k: int, vocab_size: int, num_samples: int, @@ -323,35 +415,44 @@ def estimate_rejection_sampling_pdf( Args: draft_probs: Draft probability distribution. - target_probs: Target probability distribution. + target_logits: Target logits. num_samples: Number of samples to draw. Returns: Estimated probability distribution of the output tokens. """ - sampler = RejectionSampler() - # Repeat draft probs num_samples times. + rejection_sampler = RejectionSampler() + num_tokens = num_samples * k + # Repeat draft probs num_samples * k times. draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1) - # Repeat target probs num_samples * (k + 1) times. - target_probs = target_probs.reshape(1, 1, vocab_size).repeat( - num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size) + # Repeat target probs num_tokens times. + target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1) # Randomly sample draft token ids from draft probs. draft_token_ids = torch.multinomial(draft_probs[:, 0, :], num_samples=k, replacement=True).reshape( num_samples, k) + draft_probs = draft_probs.view(num_tokens, vocab_size) # Bonus tokens not used but required. bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(num_samples, 1) - sampling_metadata = create_sampling_metadata(all_greedy=False) - output_token_ids = sampler(draft_token_ids.tolist(), draft_probs, - bonus_token_ids, target_probs, - sampling_metadata) + temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata(all_greedy=False, + temperature=temperature) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids.tolist(), device=bonus_token_ids.device) + output_token_ids = rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) output_token_ids = output_token_ids[:, :-1].flatten() hist = torch.histogram(output_token_ids.to(dtype=torch.float, diff --git a/vllm/envs.py b/vllm/envs.py index bf214f314c458..b2937462ad36a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -35,7 +35,6 @@ if TYPE_CHECKING: VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None - VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index edae654b5d339..6f46417170f6e 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -46,7 +46,7 @@ class SamplerOutput: # [num_reqs, max_num_generated_tokens] # Different requests can have different number of generated tokens. # All requests are padded to max_num_generated_tokens. - # INVALID_TOKEN_ID (-1 by default) is used for padding. + # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. 
sampled_token_ids: torch.Tensor logprobs_tensors: Optional[LogprobsTensors] diff --git a/vllm/v1/sample/ops/utils.py b/vllm/v1/sample/ops/utils.py new file mode 100644 index 0000000000000..a54e20603064f --- /dev/null +++ b/vllm/v1/sample/ops/utils.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Union + +import torch + + +def compiled_softmax( + logits: torch.Tensor, + temperature: Union[float, torch.Tensor] = 1.0, +) -> torch.Tensor: + """Faster softmax kernel generated by torch.compile. + + Args: + logits: [n, vocab_size] + temperature: [n] or float + """ + # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic. + torch._dynamo.mark_dynamic(logits, index=0) + if isinstance(temperature, torch.Tensor): + torch._dynamo.mark_dynamic(temperature, index=0) + return _softmax(logits, temperature) + + +@torch.compile +def _softmax( + logits: torch.Tensor, + temperature: Union[float, torch.Tensor], +) -> torch.Tensor: + logits = logits / temperature + return torch.softmax(logits, dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 5601c62e91fc0..6284ae4b490a6 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -3,25 +3,32 @@ from typing import Optional import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence +import triton +import triton.language as tl from vllm.logger import init_logger from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import random_sample +from vllm.v1.sample.ops.utils import compiled_softmax +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata logger = init_logger(__name__) -INVALID_TOKEN_ID = -1 + +PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 +GREEDY_TEMPERATURE: tl.constexpr = -1 +# Maximum number of speculative draft tokens allowed per request in a single +# step. This value is chosen to be large enough to handle typical use cases. +MAX_SPEC_LEN = 32 class RejectionSampler(nn.Module): """ - The implementation strictly follows the algorithm described in + The implementation strictly follows the algorithm described in https://arxiv.org/abs/2211.17192. However, we want to clarify the terminology used in the implementation: - accepted tokens: tokens that are accepted based on the relationship + accepted tokens: tokens that are accepted based on the relationship between the "raw" draft and target probabilities. recovered tokens: tokens that are sampled based on the adjusted probability - distribution, which is derived from both the draft and target + distribution, which is derived from both the draft and target probabilities. bonus tokens: If all proposed tokens are accepted, the bonus token is added to the @@ -31,48 +38,42 @@ class RejectionSampler(nn.Module): sampling process. For example, we can use top_p, top_k sampling for bonus tokens, while spec decode does not support these sampling strategies. - output tokens: - Tokens are finally generated with the rejection sampler. + output tokens: + Tokens are finally generated with the rejection sampler. 
output tokens = accepted tokens + recovered tokens + bonus tokens """ - def __init__(self): - super().__init__() - def forward( self, - draft_token_ids: list[list[int]], + metadata: SpecDecodeMetadata, + # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], - bonus_token_ids_tensor: torch.Tensor, # [batch_size, 1] - target_probs: torch.Tensor, # [num_total_tokens, vocab_size] + # [num_tokens, vocab_size] + target_logits: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: ''' Args: - draft_token_ids (List[List[int]]): - A 2D list of token IDs for each request in the batch. - Each request might have different number of draft tokens. - It may also contain empty lists for requests that have - no draft tokens. + metadata: + Metadata for spec decoding. draft_probs (Optional[torch.Tensor]): Probability distribution for the draft tokens. Shape is - [batch_size, max_spec_len, vocab_size]. Can be None if - probabilities are not provided, which is the case for - ngram spec decode. + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + target_logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. bonus_token_ids_tensor (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling + A tensor containing bonus tokens. Shape is [batch_size, 1]. + Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling process such as top_p, top_k sampling. - target_probs (torch.Tensor): - Target model probability distribution. - Shape is [num_total_tokens, vocab_size]. num_total_tokens - is the total number of tokens from all requests. Here, - probabilities from different requests are flattened into - a single tensor because this is the shape of the output - logits. sampling_metadata (SamplingMetadata): Additional metadata needed for sampling, such as temperature, top-k/top-p parameters, or other relevant information. @@ -80,268 +81,481 @@ class RejectionSampler(nn.Module): output_token_ids (torch.Tensor): A tensor containing the final output token IDs. ''' + assert metadata.max_spec_len <= MAX_SPEC_LEN + # [num_tokens, vocab_size] + target_probs = compute_probs( + target_logits, + metadata.cu_num_draft_tokens, + sampling_metadata, + ) - # NOTE: The following input preparationg can be moved - # to the model runner with a persistent manner for better - # performance. - # Convert draft token IDs to a tensor, split by sample_lens, then pad. - draft_token_ids = [ - torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids - ] - draft_token_ids_tensor = pad_sequence(draft_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) - - # NOTE: CPU <-> GPU synchronization happens here. - draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device) - - # Create one-hot tensor for draft token ids. 
- # This is used for ngram where we don't have draft_probs. - if draft_probs is None and not sampling_metadata.all_greedy: - vocab_size = target_probs.size(-1) - draft_probs = _create_greedy_token_probs(draft_token_ids_tensor, - vocab_size, - target_probs.device) - sample_lens = [len(x) + 1 for x in draft_token_ids] - target_probs = _convert_2d_probs(target_probs, sample_lens) - - return self.forward_native(draft_token_ids_tensor, draft_probs, - bonus_token_ids_tensor, target_probs, - sampling_metadata) - - # TODO: The following method can be optimized for better performance. - def forward_native( - self, - draft_token_ids_tensor: torch.Tensor, - # [batch_size, max_spec_len, vocab_size] - draft_probs: Optional[torch.Tensor], - bonus_token_ids_tensor: torch.Tensor, - # [batch_size, max_spec_len + 1, vocab_size] - target_probs: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - # Add 1 to include the 'bonus' token. - if sampling_metadata.all_greedy: - # Produce a mask that remains 1 (True) until the first - # mismatch (cumprod turns 0 after a mismatch). - target_token_ids_tensor = target_probs.argmax(dim=-1) - accept_mask = (target_token_ids_tensor[:, :-1] == - draft_token_ids_tensor).cumprod(dim=1) - - # Identify valid positions (non-padding). - valid_mask = target_token_ids_tensor != INVALID_TOKEN_ID - # Generate mask with bonus token. - generate_mask = torch.cat([ - accept_mask, - torch.zeros(accept_mask.size(0), 1, device=accept_mask.device) - ], - dim=1).to(torch.bool) & valid_mask - zeros_mask = (generate_mask == 0) - first_zero_idx = zeros_mask.float().argmax(dim=1) - # Figure out which rows actually contain at least one zero. - rows_with_zero = zeros_mask.any(dim=1) - # Use indexing to set the first zero in each of those rows to 1. - generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1 - - output_token_ids = target_token_ids_tensor - output_token_ids[~generate_mask] = INVALID_TOKEN_ID - else: - # Reference: https://arxiv.org/pdf/2211.17192 - # 1. Extract the probabilities of the draft tokens. - # [batch_size, max_spec_len] - batch_size = draft_token_ids_tensor.size(0) - max_spec_len = draft_token_ids_tensor.size(1) - invalid_idx = draft_token_ids_tensor == INVALID_TOKEN_ID - draft_token_ids_tensor[invalid_idx] = 0 - assert draft_probs is not None - draft_token_probs = draft_probs.gather( - dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1) - target_token_probs = target_probs.gather( - dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1) - # Force the probabilities of invalid tokens to inf - # so that they are not accepted. - draft_token_probs[invalid_idx] = float('inf') - - # 2. Generate uniform samples. - # [batch_size, max_spec_len + 1] - uniform_samples = _create_uniform_samples( - sampling_metadata.generators, batch_size, max_spec_len, - target_probs.device) - - # 3. Accept or reject the samples. - # [batch_size, max_spec_len] - # If the draft token probabilities are 0, set them to the smallest - # positive normal value representable by float32. - safe_draft_probs = torch.where(draft_token_probs > 0, - draft_token_probs, - torch.finfo(torch.float32).tiny) - accepted = uniform_samples <= target_token_probs / safe_draft_probs - accept_mask = accepted.cumprod(dim=1) - # Set the token ids to the draft token ids if accepted, otherwise - # set them to INVALID_TOKEN_ID. - accepted_token_ids = (draft_token_ids_tensor * accept_mask + - INVALID_TOKEN_ID * (1 - accept_mask)) - - # 4. 
Adjust the distribution for the recovered tokens. - # Clamp the bonus probabilities to the smallest positive normal - # value representable by float32. - bonus_prob = torch.clamp(target_probs[:, :-1, :] - draft_probs, - min=torch.finfo(torch.float32).tiny) - normalized_bonus_prob = bonus_prob / bonus_prob.sum(dim=-1, - keepdim=True) - - # 5. Sample recovered token ids. - recovered_token_ids = random_sample( - normalized_bonus_prob, - sampling_metadata.generators).reshape(batch_size, max_spec_len) - - # 6. Get the final output token ids. - # output_token_ids = accepted_token_ids + - # recovered_token_ids + - # bonus_token_id - recovered_bonus_token_ids = torch.cat( - [recovered_token_ids, bonus_token_ids_tensor], dim=1) - # Generate mask with bonus tokens. - generate_mask = torch.cat([ - accept_mask, - torch.zeros(batch_size, 1, device=accept_mask.device) - ], - dim=1).to(torch.bool) - zeros_mask = (generate_mask == 0) - first_zero_idx = zeros_mask.float().argmax(dim=1) - output_token_ids = torch.cat([ - accepted_token_ids, - torch.full((batch_size, 1), - fill_value=INVALID_TOKEN_ID, - device=accept_mask.device) - ], - dim=1) - output_token_ids[torch.arange(batch_size), - first_zero_idx] = recovered_bonus_token_ids[ - torch.arange(batch_size), first_zero_idx] - + output_token_ids = rejection_sample( + metadata.draft_token_ids, + metadata.num_draft_tokens, + metadata.max_spec_len, + metadata.cu_num_draft_tokens, + draft_probs, + target_probs, + bonus_token_ids, + sampling_metadata, + ) return output_token_ids - def compute_probs(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_lens: list[int]) -> torch.Tensor: - """ - Compute probability distribution from logits based on sampling metadata. - - This function applies temperature scaling to the logits and converts - them to probabilities using softmax. Note that division by - temperature is not performed inplace to preserve the original logits - tensor, which will be used by the original sampler to get bonus tokens. - - Args: - logits: Input logits tensor to be converted to probabilities - sampling_metadata: Metadata containing sampling parameters such - as temperature and whether greedy sampling is used - sample_lens: List of sample lengths used for repeating - temperature values - - Returns: - torch.Tensor: Probability distribution (softmax of scaled logits) - if non-greedy sampling is used, otherwise returns the - original logits - """ - if sampling_metadata.all_greedy: - return logits - assert sampling_metadata.temperature is not None - # We should optimize the following code as - # it will cause CPU -> GPU synchronization. - temperature = torch.repeat_interleave( - sampling_metadata.temperature, - torch.tensor(sample_lens, - device=sampling_metadata.temperature.device)) - temperature = temperature.unsqueeze(dim=1) - logits = logits / temperature - return logits.softmax(dim=-1, dtype=torch.float32) + @staticmethod + def parse_output( + output_token_ids: torch.Tensor, + vocab_size: int, + ) -> list[list[int]]: + output_token_ids_np = output_token_ids.cpu().numpy() + # Create mask for valid tokens. 
+ valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & + (output_token_ids_np < vocab_size)) + outputs = [ + row[valid_mask[i]].tolist() + for i, row in enumerate(output_token_ids_np) + ] + return outputs -def _create_greedy_token_probs( - token_ids: torch.Tensor, - vocab_size: int, - out_device: torch.device, +def rejection_sample( + # [num_tokens] + draft_token_ids: torch.Tensor, + # [batch_size] + num_draft_tokens: list[int], + max_spec_len: int, + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - batch_size, num_tokens = token_ids.shape + assert draft_token_ids.ndim == 1 + assert draft_probs is None or draft_probs.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + assert target_probs.ndim == 2 - token_probs = torch.zeros(batch_size, - num_tokens, - vocab_size, - dtype=torch.float, - device=out_device) + batch_size = len(num_draft_tokens) + num_tokens = draft_token_ids.shape[0] + vocab_size = target_probs.shape[-1] + device = target_probs.device + assert draft_token_ids.is_contiguous() + assert draft_probs is None or draft_probs.is_contiguous() + assert target_probs.is_contiguous() + assert bonus_token_ids.is_contiguous() + assert target_probs.shape == (num_tokens, vocab_size) - # Ignore INVALID_TOKEN_ID. - valid_mask = (token_ids != INVALID_TOKEN_ID) - valid_indices = token_ids.clone() - valid_indices[~valid_mask] = 0 + # Create output buffer. + output_token_ids = torch.empty( + (batch_size, max_spec_len + 1), + dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. + device=device, + ) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) - token_probs.scatter_(dim=2, - index=valid_indices.unsqueeze(-1), - src=valid_mask.unsqueeze(-1).float()) + if sampling_metadata.all_greedy: + is_greedy = None + else: + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE + if not sampling_metadata.all_random: + # Rejection sampling for greedy sampling requests. + target_argmax = target_probs.argmax(dim=-1) + rejection_greedy_sample_kernel[(batch_size, )]( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + num_warps=1, + ) + if sampling_metadata.all_greedy: + return output_token_ids - return token_probs + # Generate uniform probabilities for rejection sampling. + # [num_tokens] + uniform_probs = generate_uniform_probs( + num_tokens, + num_draft_tokens, + sampling_metadata.generators, + device, + ) + + # Sample recovered tokens for each position. + # [num_tokens] + recovered_token_ids = sample_recovered_tokens( + max_spec_len, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device, + ) + + # Rejection sampling for random sampling requests. 
+ rejection_random_sample_kernel[(batch_size, )]( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + num_warps=1, + ) + return output_token_ids -def _convert_2d_probs( - probs: torch.Tensor, # [num_total_tokens, vocab_size] - sample_lens: list[int]) -> torch.Tensor: +def compute_probs( + logits: torch.Tensor, # [num_tokens, vocab_size] + cu_num_draft_tokens: torch.Tensor, # [batch_size] + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + """Compute probability distribution from logits based on sampling metadata. + + This function applies temperature scaling to the logits and converts + them to probabilities using softmax. For greedy decoding, it returns + the original logits. + + Args: + logits: Input logits tensor to be converted to probabilities. + cu_num_draft_tokens: Cumulative number of draft tokens. + sampling_metadata: Metadata containing sampling parameters such as + temperature and whether greedy sampling is used. + + Returns: + torch.Tensor: Probability distribution (softmax of scaled logits) + if non-greedy sampling is used, otherwise returns the + original logits. """ - Converts a 2D tensor of probabilities to a 3D tensor with padding. - [num_total_tokens, vocab_size] -> - [batch_size, max_spec_len + 1, vocab_size] + assert logits.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + if sampling_metadata.all_greedy: + return logits + + num_tokens = logits.shape[0] + batch_size = cu_num_draft_tokens.shape[0] + expanded_temperature = torch.empty( + (num_tokens, 1), + dtype=torch.float32, + device=logits.device, + ) + expand_kernel[(batch_size, )]( + expanded_temperature, + sampling_metadata.temperature, + cu_num_draft_tokens, + GREEDY_TEMPERATURE, # replace_from + 1, # replace_to + MAX_NUM_TOKENS=MAX_SPEC_LEN, + num_warps=1, + ) + output_prob = compiled_softmax(logits, expanded_temperature) + return output_prob + + +def generate_uniform_probs( + num_tokens: int, + num_draft_tokens: list[int], + generators: dict[int, torch.Generator], + device: torch.device, +) -> torch.Tensor: """ - cumulative_lens = torch.cumsum(torch.tensor(sample_lens, - device=probs.device), - dim=0) - split_indices = cumulative_lens[:-1].tolist() # Exclude last index + Generates a batch of uniform random samples, with optional seeding + if available. - # Split into chunks without loops - chunks = torch.tensor_split(probs, split_indices, dim=0) + This method creates a tensor of shape `(num_tokens, )` filled + with uniform random values in the range [0, 1). If `generators` is provided, + the requests with their own seeds will use the provided `torch.Generator` + for reproducibility. The samples for the other requests will be generated + without a seed. - # Pad all sequences to maximum length - padded_probs = pad_sequence(chunks, batch_first=True, padding_value=0.0) - return padded_probs - - -def _create_uniform_samples(seeded_seqs: dict[int, torch.Generator], - batch_size: int, k: int, - device: torch.device) -> torch.Tensor: + Args: + num_tokens : int + Total number of tokens. + num_draft_tokens : List[List[int]] + Number of draft tokens per request. + generators : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. + device : torch.device + The device on which to allocate the tensor. 
+ Returns: + uniform_rand : torch.Tensor + A tensor of shape `(num_tokens, )` containing uniform + random values in the range [0, 1). """ - Generates a batch of uniform random samples, with optional seeding - for specific sequences. + uniform_probs = torch.rand( + (num_tokens, ), + dtype=torch.float32, + device=device, + ) + start_idx = 0 + for req_idx, n in enumerate(num_draft_tokens): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. + if n == 0: + continue + end_idx = start_idx + n + generator = generators.get(req_idx) + if generator is not None: + uniform_probs[start_idx:end_idx].uniform_(generator=generator) + start_idx = end_idx + return uniform_probs - This method creates a tensor of shape `(batch_size, k)` filled - with uniform random values in the range [0, 1). If `seeded_seqs` - is provided, the sequences corresponding to specific indices - will be generated using the provided `torch.Generator` for - reproducibility. The other sequences will be generated without - a seed. - Args: - seeded_seqs : Optional[Dict[int, torch.Generator]] - A dictionary mapping indices in the batch to - `torch.Generator` objects. - batch_size : int - The number of sequences to generate. - k : int - The number of random samples per sequence. - device : torch.device - The device on which to allocate the tensor. +def sample_recovered_tokens( + max_spec_len: int, + num_draft_tokens: list[int], + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens] + draft_token_ids: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: + # NOTE(woosuk): Create only one distribution for each request. + batch_size = len(num_draft_tokens) + vocab_size = target_probs.shape[-1] + q = torch.empty( + (batch_size, vocab_size), + dtype=torch.float32, + device=device, + ) + q.exponential_() + for i, generator in sampling_metadata.generators.items(): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. + if num_draft_tokens[i] > 0: + q[i].exponential_(generator=generator) - Returns: - uniform_rand : torch.Tensor - A tensor of shape `(batch_size, k)` containing uniform - random values in the range [0, 1). - """ + recovered_token_ids = torch.empty_like(draft_token_ids) + sample_recovered_tokens_kernel[(batch_size, max_spec_len)]( + recovered_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + q, + vocab_size, + triton.next_power_of_2(vocab_size), + IS_NGRAM=draft_probs is None, + ) + return recovered_token_ids - uniform_rand = torch.rand(batch_size, - k, - dtype=torch.float32, - device=device) - # Apply seeded generators only where needed - if seeded_seqs: - for idx, generator in seeded_seqs.items(): - uniform_rand[idx].uniform_(0, 1, generator=generator) - return uniform_rand + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. 
+@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_greedy_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + target_argmax_ptr, # [num_tokens] + bonus_token_ids_ptr, # [batch_size] + is_greedy_ptr, # [batch_size] or None + max_spec_len, +): + req_idx = tl.program_id(0) + # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, + # re-compilation may happen during runtime when is_greedy_ptr is None. + if is_greedy_ptr is None: + is_greedy = True + else: + is_greedy = tl.load(is_greedy_ptr + req_idx) + if not is_greedy: + # Early exit for non-greedy sampling requests. + return + + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + target_argmax_id) + if draft_token_id != target_argmax_id: + # Reject. + rejected = True + + if not rejected: + # If all tokens are accepted, append the bonus token. + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, bonus_token_id) + + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. +@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_random_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + IS_NGRAM: tl.constexpr, +): + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + if is_greedy: + # Early exit for greedy sampling requests. + return + + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + if IS_NGRAM: + draft_prob = 1 + else: + draft_prob = tl.load(draft_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) + # NOTE(woosuk): While the draft probability should never be 0, + # we check it to avoid NaNs. If it happens to be 0, we reject. + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + # Accept. + token_id = draft_token_id + else: + # Reject. Use recovered token. + rejected = True + token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + token_id) + + if not rejected: + # If all tokens are accepted, append the bonus token. 
+ bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, bonus_token_id) + + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. +@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +def expand_kernel( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + MAX_NUM_TOKENS: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: # noqa: SIM108 + start_idx = 0 + else: + start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_tokens_ptr + req_idx) + num_tokens = end_idx - start_idx + + src_val = tl.load(input_ptr + req_idx) + src_val = tl.where(src_val == replace_from, replace_to, src_val) + offset = tl.arange(0, MAX_NUM_TOKENS) + tl.store(output_ptr + start_idx + offset, + src_val, + mask=offset < num_tokens) + + +@triton.jit +def sample_recovered_tokens_kernel( + output_token_ids_ptr, # [num_tokens] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + q_ptr, # [batch_size, vocab_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, + IS_NGRAM: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + # Early exit for out-of-range positions. + pos = tl.program_id(1) + if pos >= num_draft_tokens: + return + + vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) + if IS_NGRAM: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + draft_token_id) + # Temporarily zero out the probability of the draft token. + # This is essentially the same as target_prob - draft_prob, except that + # n-gram does not have draft_prob. We regard it as 1. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + 0) + prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + else: + draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0) + prob = tl.maximum(target_prob - draft_prob, 0) + # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. + + q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf")) + recovered_id = tl.argmax(prob / q, axis=-1) + tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) + + if IS_NGRAM: + # Restore the original probability. 
+ tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + orig_prob) diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py new file mode 100644 index 0000000000000..1cf650d5fa569 --- /dev/null +++ b/vllm/v1/spec_decode/metadata.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import numpy as np +import torch + + +@dataclass +class SpecDecodeMetadata: + + # [num_tokens] + draft_token_ids: torch.Tensor + # [batch_size] + num_draft_tokens: list[int] + # [batch_size] + cu_num_draft_tokens: torch.Tensor + # [num_tokens] + target_logits_indices: torch.Tensor + # [batch_size] + bonus_logits_indices: torch.Tensor + # [num_tokens + batch_size] + logits_indices: torch.Tensor + + def __post_init__(self): + self.max_spec_len = max(self.num_draft_tokens) + + @classmethod + def make_dummy( + cls, + draft_token_ids: list[list[int]], + device: torch.device, + ) -> "SpecDecodeMetadata": + batch_size = len(draft_token_ids) + num_draft_tokens = [len(ids) for ids in draft_token_ids] + flattened_draft_token_ids = sum(draft_token_ids, []) + num_tokens = len(flattened_draft_token_ids) + + draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids, + dtype=torch.int32, + device=device) + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to( + device) + + target_logits_indices = torch.zeros(num_tokens, + dtype=torch.int32, + device=device) + bonus_logits_indices = torch.zeros(batch_size, + dtype=torch.int32, + device=device) + logits_indices = torch.zeros(num_tokens + batch_size, + dtype=torch.int32, + device=device) + return cls( + draft_token_ids=draft_token_ids_tensor, + num_draft_tokens=num_draft_tokens, + cu_num_draft_tokens=cu_num_draft_tokens_tensor, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 5841401367788..d5329ef7b5abf 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa from vllm.v1.worker.gpu_input_batch import InputBatch diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 66015382bfe85..657333c6d84c8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler +from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache @@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.use_spec_decode = False if self.speculative_config: self.use_spec_decode = True - self.rejection_sampler = RejectionSampler() # TODO: find a better way to check if we are using ngram. assert self.speculative_config.ngram_prompt_lookup_min, \ "Currently, only ngram spec decode is supported in V1." 
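For reference, a minimal usage sketch of the SpecDecodeMetadata.make_dummy helper added above; the toy draft tokens and CUDA device are assumptions for illustration, and the three *_indices fields are zero-filled placeholders in the dummy path:

import torch

from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

# Two hypothetical requests with 3 and 1 draft tokens, respectively.
meta = SpecDecodeMetadata.make_dummy([[1, 2, 3], [4]],
                                     device=torch.device("cuda"))

print(meta.draft_token_ids)      # tensor([1, 2, 3, 4], dtype=torch.int32)
print(meta.num_draft_tokens)     # [3, 1]
print(meta.cu_num_draft_tokens)  # tensor([3, 4], dtype=torch.int32)
print(meta.max_spec_len)         # 3 (set in __post_init__)
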
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.num_speculative_tokens, ) + self.rejection_sampler = RejectionSampler() # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[FlashAttentionMetadata, torch.Tensor]: + ) -> tuple[FlashAttentionMetadata, torch.Tensor, + Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin): use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 - if use_spec_decode: - logits_indices = self._calc_spec_decode_metadata( - scheduler_output, cu_num_tokens) - else: + if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = attn_metadata.query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return attn_metadata, logits_indices + return attn_metadata, logits_indices, spec_decode_metadata def _compute_cascade_attn_prefix_len( self, @@ -732,49 +745,78 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _calc_spec_decode_metadata( self, - scheduler_output: "SchedulerOutput", - cu_num_tokens: np.ndarray, - ) -> torch.Tensor: - # Get the number of spec decode tokens for each request. - num_reqs = self.input_batch.num_reqs - num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32) - for i, req_id in enumerate(self.input_batch.req_ids): - num_spec_decode_tokens[i] = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + num_draft_tokens: np.ndarray, + cu_num_scheduled_tokens: np.ndarray, + ) -> SpecDecodeMetadata: + # Inputs: + # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] + # num_draft_tokens: [ 3, 0, 2, 0, 1] + # Outputs: + # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] + # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, + # 206, 207, 208] + # target_logits_indices: [ 0, 1, 2, 5, 6, 9] + # bonus_logits_indices: [ 3, 4, 7, 8, 10] - # Get spec decode logits indices. 
- # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] - # cu_num_tokens: [4, 104, 107, 207, 209] - # num_spec_tokens_list: [3, 0, 2, 0, 1] - # num_sampled_tokens: [4, 1, 3, 1, 2] - # spec_decode_logits_indices: - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - num_sampled_tokens = num_spec_decode_tokens + 1 - # logits_start_loc: [0, 103, 104, 206, 207] - logits_start_loc = cu_num_tokens - num_sampled_tokens - # [0, 103, 104, 206, 207] -> - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] - logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) - # The following three lines: - # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] - cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) - # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] - # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - cumsums_sampled_offsets = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) - # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - total_num_sampled_tokens = num_sampled_tokens.sum() - sampled_arange = (self.arange_np[:total_num_sampled_tokens] - - cumsums_sampled_offsets) + # Compute the logits indices. + # [4, 1, 3, 1, 2] + num_sampled_tokens = num_draft_tokens + 1 + # Step 1. [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) + total_num_sampled_tokens = cu_num_sampled_tokens[-1] + # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, + num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets + # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_indices = np.repeat( + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + logits_indices += arange - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - spec_decode_logits_indices = logits_start_loc + sampled_arange - return torch.from_numpy(spec_decode_logits_indices).to( + # Compute the bonus logits indices. + bonus_logits_indices = cu_num_sampled_tokens - 1 + + # Compute the draft logits indices. + # [3, 3, 5, 5, 6] + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + total_num_draft_tokens = cu_num_draft_tokens[-1] + # [0, 0, 0, 3, 3, 5] + cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, + num_draft_tokens) + # [0, 1, 2, 0, 1, 0] + arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets + # [0, 0, 0, 5, 5, 9] + target_logits_indices = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + # [0, 1, 2, 5, 6, 9] + target_logits_indices += arange + + # TODO: Optimize the CPU -> GPU copy. + cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( self.device, non_blocking=True) + logits_indices = torch.from_numpy(logits_indices).to(self.device, + non_blocking=True) + target_logits_indices = torch.from_numpy(target_logits_indices).to( + self.device, non_blocking=True) + bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( + self.device, non_blocking=True) + + # Compute the draft token ids. 
+ # draft_token_indices: [ 1, 2, 3, 105, 106, 208] + draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = draft_token_ids[target_logits_indices + 1] + + metadata = SpecDecodeMetadata( + draft_token_ids=draft_token_ids, + num_draft_tokens=num_draft_tokens.tolist(), + cu_num_draft_tokens=cu_num_draft_tokens, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) + return metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs @@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): encoder_outputs = [] # Prepare the decoder inputs. - attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + attn_metadata, logits_indices, spec_decode_metadata = ( + self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata - if not self.use_spec_decode: + if spec_decode_metadata is None: sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, ) else: - draft_token_ids = [ - scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) - for req_id in self.input_batch.req_ids - ] - sample_lens = [len(tokens) + 1 for tokens in draft_token_ids] - recover_logits_idx = np.cumsum(sample_lens) - 1 - target_probs = self.rejection_sampler.compute_probs( - logits, sampling_metadata, sample_lens) + # TODO(woosuk): Optimize the memory usage. + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] sampler_output = self.model.sample( - logits=logits[recover_logits_idx, :], + logits=bonus_logits, sampling_metadata=sampling_metadata, ) bonus_token_ids = sampler_output.sampled_token_ids + + # TODO(woosuk): Optimize the memory usage. + target_logits = logits[spec_decode_metadata.target_logits_indices] output_token_ids = self.rejection_sampler( - draft_token_ids, + spec_decode_metadata, None, # draft_probs + target_logits, bonus_token_ids, - target_probs, - sampling_metadata) + sampling_metadata, + ) sampler_output.sampled_token_ids = output_token_ids # TODO(woosuk): The following loop can be slow since it iterates over @@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): valid_sampled_token_ids = sampled_token_ids.tolist() else: # Includes spec decode tokens. - valid_mask = sampled_token_ids != INVALID_TOKEN_ID - gen_lens = valid_mask.sum(dim=1).tolist() - # TODO(woosuk): Optimize this. 
- valid_sampled_token_ids = [ - seq.tolist() - for seq in sampled_token_ids[valid_mask].split(gen_lens) - ] + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, self.input_batch.vocab_size) if not self.use_spec_decode: spec_token_ids = None @@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin): "initializing the engine.") from e else: raise e + if self.use_spec_decode: + draft_token_ids = [[0] for _ in range(num_reqs)] + dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids, self.device) + + num_tokens = sum(len(ids) for ids in draft_token_ids) + # draft_probs = torch.randn( + # num_tokens, logits.shape[-1], device=self.device, + # dtype=logits.dtype) + draft_probs = None + target_logits = torch.randn(num_tokens, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) + # NOTE(woosuk): Here, we should use int32 because the sampler uses + # int32 for bonus_token_ids. If the dtype mismatches, re-compilation + # will occur at runtime. + bonus_token_ids = torch.zeros(num_reqs, + device=self.device, + dtype=torch.int32) + self.rejection_sampler( + dummy_spec_decode_metadata, + draft_probs, + target_logits, + bonus_token_ids, + dummy_metadata, + ) return sampler_output def profile_run(self) -> None:
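The index arithmetic in _calc_spec_decode_metadata can be checked in isolation. The following standalone NumPy sketch mirrors the steps above and reproduces the worked example from the comments; it is illustrative and not part of the patch:

import numpy as np

# Worked example from the comments in _calc_spec_decode_metadata.
cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)
num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

num_sampled_tokens = num_draft_tokens + 1
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
arange = (np.arange(cu_num_sampled_tokens[-1], dtype=np.int32) -
          np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                    num_sampled_tokens))
logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                           num_sampled_tokens) + arange
bonus_logits_indices = cu_num_sampled_tokens - 1

cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
draft_arange = (np.arange(cu_num_draft_tokens[-1], dtype=np.int32) -
                np.repeat(cu_num_draft_tokens - num_draft_tokens,
                          num_draft_tokens))
target_logits_indices = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                                  num_draft_tokens) + draft_arange

# These match the values shown in the comments above.
assert logits_indices.tolist() == [
    0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208
]
assert bonus_logits_indices.tolist() == [3, 4, 7, 8, 10]
assert cu_num_draft_tokens.tolist() == [3, 3, 5, 5, 6]
assert target_logits_indices.tolist() == [0, 1, 2, 5, 6, 9]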