diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 190927745f1fe..84139a40b544a 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -1,37 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional import pytest import torch +import torch.nn.functional as F from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler +DEVICE = "cpu" + @pytest.fixture def sampler(): return RejectionSampler() -def create_logits_tensor(token_ids: list[int], +def create_logits_tensor(token_ids: list[list[int]], vocab_size: int = 100) -> torch.Tensor: """Helper function to create logits tensor that will produce desired token ids on argmax""" - logits = torch.full((len(token_ids), vocab_size), -100.0).cuda() - for i, token_id in enumerate(token_ids): - logits[i, token_id] = 100.0 + num_total_tokens = sum(len(tokens) for tokens in token_ids) + logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) + start_loc = 0 + for tokens in token_ids: + for j, token_id in enumerate(tokens): + logits[start_loc + j, token_id] = 100.0 + start_loc += len(tokens) return logits -def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata: - batch_size = len(spec_tokens) +def create_sampling_metadata( + all_greedy: bool, + generators: Optional[dict[int, Any]] = None) -> SamplingMetadata: + """Create a v1 sampling metadata object with all_greedy set + to the given value. Either all greedy or all random sampling + is used. 
+ """ + generators = generators or {} return SamplingMetadata( temperature=torch.tensor([]), - all_greedy=True, - all_random=False, + all_greedy=all_greedy, + all_random=not all_greedy, top_p=None, top_k=None, - min_p=torch.empty(batch_size, ), - generators={}, + min_p=torch.empty(1, ), + generators=generators, max_num_logprobs=0, no_penalties=False, prompt_token_ids=None, @@ -40,129 +54,310 @@ def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata: repetition_penalties=torch.tensor([]), output_token_ids=[], min_tokens={}, - logit_bias=[None] * batch_size, + logit_bias=[None], allowed_token_ids_mask=None, bad_words_token_ids={}, ) +########################### Tests for Greedy Sampling ################### def test_perfect_match(sampler): """Test when output tokens perfectly match speculated tokens""" spec_tokens = [[1, 2, 3]] - output_tokens = [1, 2, 3, 4] # 4 is the bonus token + output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token - metadata = create_sampling_metadata(spec_tokens) + metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) - output = sampler(spec_tokens, logits, metadata) + output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected) + assert torch.equal(output, expected) def test_early_mismatch(sampler): """Test when there's an early mismatch in tokens""" spec_tokens = [[1, 2, 3]] - output_tokens = [1, 5, 3, 4] # Mismatch at position 1 + output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1 - metadata = create_sampling_metadata(spec_tokens) + metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) - output = sampler(spec_tokens, 
logits, metadata) + output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected) + assert torch.equal(output, expected) def test_multiple_sequences(sampler): """Test handling multiple sequences of speculated tokens""" spec_tokens = [[1, 2], [3]] - output_tokens = [1, 2, 5, 3, 4] # Two sequences with bonus tokens 5 and 4 + output_tokens = [[1, 2, 5], [3, + 4]] # Two sequences with bonus tokens 5 and 4 - metadata = create_sampling_metadata(spec_tokens) + metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor( + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - output = sampler(spec_tokens, logits, metadata) + output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]], dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected) + assert torch.equal(output, expected) def test_single_token_sequence(sampler): """Test handling sequences with single token""" spec_tokens = [[1]] - output_tokens = [1, 2] # Single token with bonus token 2 + output_tokens = [[1, 2]] # Single token with bonus token 2 - metadata = create_sampling_metadata(spec_tokens) + metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) - output = sampler(spec_tokens, logits, metadata) + output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected) + assert torch.equal(output, expected) def test_empty_sequence(sampler): """Test handling empty sequence of 
speculated tokens""" spec_tokens: list[list[int]] = [[]] - output_tokens = [5] # Just the bonus token + output_tokens = [[5]] # Just the bonus token - metadata = create_sampling_metadata(spec_tokens) + metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) - output = sampler(spec_tokens, logits, metadata) + output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected) + assert torch.equal(output, expected) def test_multiple_mismatches(sampler): """Test handling multiple sequences with mismatches""" spec_tokens = [[1, 2, 3], [4, 5, 6]] - output_tokens = [1, 2, 7, 6, 4, 8, 6, 9] # Mismatches in both sequences + output_tokens = [[1, 2, 7, 6], [4, 8, 6, + 9]] # Mismatches in both sequences - metadata = create_sampling_metadata(spec_tokens) + metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor( + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - output = sampler(spec_tokens, logits, metadata) + output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID], [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected) + assert torch.equal(output, expected) @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ - ([[1, 2]], [1, 2, 3], [[1, 2, 3]]), # Perfect match with bonus - ([[1]], [2, 3], [[2, INVALID_TOKEN_ID]]), # First mismatch - ([[1, 2], [3, 4]], [1, 5, 6, 3, 4, 7], [[1, 5, INVALID_TOKEN_ID], - [3, 4, 7]]), # Mixed matches + ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus + ([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]), # First mismatch 
+ ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], + [[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]), # Mixed matches ]) def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): """Parametrized test for various matching scenarios""" - metadata = create_sampling_metadata(spec_tokens) + metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], + device=logits.device) - output = sampler(spec_tokens, logits, metadata) + output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected_tensor) + assert torch.equal(output, expected_tensor) -def test_logits_shape_handling(sampler): - """Test handling of different logits tensor shapes""" - spec_tokens = [[1, 2]] - output_tokens = [1, 2, 3] - vocab_size = 1000 +########################### Tests for Random Sampling ################### +@pytest.mark.parametrize("k", [1, 3, 5]) +@pytest.mark.parametrize("vocab_size", [1000]) +@pytest.mark.parametrize("batch_size", [1, 4, 8]) +@pytest.mark.parametrize("frac_seeded", [0.0, 0.5]) +@pytest.mark.parametrize("n_rep", [20]) +def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, + batch_size: int, frac_seeded: float, + n_rep: int): + draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size * (k + 1), + vocab_size, + dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) - metadata = create_sampling_metadata(spec_tokens) - logits = create_logits_tensor(output_tokens, vocab_size) + seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded - output = sampler(spec_tokens, 
logits, metadata) - expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device) - assert torch.equal(output.sampled_token_ids, expected) - assert logits.shape[-1] == vocab_size + results = [] + for _ in range(n_rep): + seeded_seqs = { + i: torch.Generator(device=DEVICE).manual_seed(i) + for i in range(batch_size) if seeded_mask[i] + } + + sampling_metadata = create_sampling_metadata(all_greedy=False, + generators=seeded_seqs) + rep_result = sampler(draft_token_ids.tolist(), draft_probs, + bonus_token_ids, target_probs, sampling_metadata) + + results.append(rep_result) + + for i in range(batch_size): + if seeded_mask[i]: + for j in range(1, n_rep): + assert torch.equal(results[j][i], results[0][i]) + + +def test_rejection_sampling_approximates_target_distribution(): + """Verify rejection sampling approximates target distribution, + despite sampling from a potentially distinct draft distribution. + + This is done by first creating a random target probability + distribution and a random draft probability distribution. We then + sample token ids from the rejection sampler using these draft + and target distributions. The samples are used to estimate + the output probability distribution, which we expect to approximate + the target distribution. + + A basic distance metric is used to determine similarity between + distributions. + + We expect that as we increase the number of samples, + the distance between the observed distribution and the target + distribution decreases. To measure this, we compare the distance + of the observed distribution against both the target distribution + and a uniform random distribution. We expect the distance between + the observed distribution and the target distribution to improve + much more than the distance improvement between the observed + distribution and the random distribution. 
+ """ + torch.set_default_device(DEVICE) + vocab_size = 10 + k = 2 + num_reference_probs = 100 + + # Prepare draft, target, and reference probability distributions + draft_probs, target_probs = (F.softmax( + torch.rand(vocab_size, dtype=torch.float32), + dim=-1, + ) for _ in range(2)) + reference_probs = F.softmax( + torch.rand(num_reference_probs, vocab_size, dtype=torch.float32), + dim=-1, + ) + + sample_sizes = [10, 100, 1_000, 10_000, 100_000] + distance_wrt_reference: list[float] = [] + distance_wrt_target: list[float] = [] + + for num_samples in sample_sizes: + # Sample using rejection sampling. + rej_sample_probs = estimate_rejection_sampling_pdf( + draft_probs, target_probs, k, vocab_size, num_samples) + rej_sample_probs = rej_sample_probs.to(DEVICE) + + # Average distance from reference probs. + reference_vs_rejsample_dist = torch.dist( + reference_probs, + rej_sample_probs).item() / reference_probs.shape[0] + target_vs_rejsample_dist = torch.dist(target_probs, + rej_sample_probs).item() + + distance_wrt_reference.append(reference_vs_rejsample_dist) + distance_wrt_target.append(target_vs_rejsample_dist) + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " + f"{reference_vs_rejsample_dist=:.05f}") + print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " + f"{relative_change_in_distance_wrt_reference=:.02f}") + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + expected_improvement_multiplier = 20 + assert (relative_change_in_distance_wrt_target + > relative_change_in_distance_wrt_reference * + expected_improvement_multiplier) + + +def get_ratio_first_to_last(elements: list[float]) -> float: + return 
elements[0] / elements[-1] + + +def estimate_rejection_sampling_pdf( + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + k: int, + vocab_size: int, + num_samples: int, +) -> torch.Tensor: + """Estimate the probability distribution of the output tokens + using rejection sampling. + + Args: + draft_probs: Draft probability distribution. + target_probs: Target probability distribution. + num_samples: Number of samples to draw. + + Returns: + Estimated probability distribution of the output tokens. + """ + sampler = RejectionSampler() + # Repeat draft probs num_samples times. + draft_probs = draft_probs.reshape(1, 1, + vocab_size).repeat(num_samples, k, 1) + + # Repeat target probs num_samples * (k + 1) times. + target_probs = target_probs.reshape(1, 1, vocab_size).repeat( + num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size) + + # Randomly sample draft token ids from draft probs. + draft_token_ids = torch.multinomial(draft_probs[:, 0, :], + num_samples=k, + replacement=True).reshape( + num_samples, k) + + # Bonus tokens not used but required. 
+ bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, + device=DEVICE).repeat(num_samples, 1) + + sampling_metadata = create_sampling_metadata(all_greedy=False) + output_token_ids = sampler(draft_token_ids.tolist(), draft_probs, + bonus_token_ids, target_probs, + sampling_metadata) + output_token_ids = output_token_ids[:, :-1].flatten() + + hist = torch.histogram(output_token_ids.to(dtype=torch.float, + device="cpu"), + bins=vocab_size, + range=(0, vocab_size), + density=True) + + return hist.hist diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index ea7f3353c115f..5601c62e91fc0 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,87 +1,89 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence -from vllm import envs from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata - -try: - import flashinfer.sampling as fs - is_flashinfer_available = True -except ImportError: - is_flashinfer_available = False +from vllm.v1.spec_decode.utils import random_sample logger = init_logger(__name__) INVALID_TOKEN_ID = -1 class RejectionSampler(nn.Module): + """ + The implementation strictly follows the algorithm described in + https://arxiv.org/abs/2211.17192. + However, we want to clarify the terminology used in the implementation: + accepted tokens: tokens that are accepted based on the relationship + between the "raw" draft and target probabilities. + recovered tokens: tokens that are sampled based on the adjusted probability + distribution, which is derived from both the draft and target + probabilities. + bonus tokens: + If all proposed tokens are accepted, the bonus token is added to the + end of the sequence. The bonus token is only sampled from the target + probabilities. 
We pass in the bonus tokens instead of sampling them + in the rejection sampler to allow for more flexibility in the + sampling process. For example, we can use top_p, top_k sampling for + bonus tokens, while spec decode does not support these sampling + strategies. + output tokens: + Tokens are finally generated with the rejection sampler. + output tokens = accepted tokens + recovered tokens + bonus tokens + """ def __init__(self): super().__init__() - if current_platform.is_cuda(): - if is_flashinfer_available: - if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: - # FIXME(woosuk): Currently, we have errors when using - # FlashInfer for rejection sampling. As a workaround, we - # disable FlashInfer for rejection sampling by default. - logger.info("Currently, FlashInfer rejection sampler is " - "disabled because of a bug. Falling back to " - "the PyTorch-native implementation of " - "rejection sampling.") - self.forward_method = self.forward_native - # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for - # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by - # default it is unused). For backward compatibility, we set - # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and - # interpret it differently in V0 and V1 samplers: In V0, - # None means False, while in V1, None means True. This is - # why we use the condition - # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. - # logger.info("Using FlashInfer for rejection sampling.") - # self.forward_method = self.flashinfer_sample - else: - logger.warning( - "FlashInfer is available, but it is not enabled. " - "Falling back to the PyTorch-native implementation of " - "rejection sampling. For the best performance, " - "please set VLLM_USE_FLASHINFER_SAMPLER=1.") - self.forward_method = self.forward_native - else: - logger.warning( - "FlashInfer is not available. Falling back to the PyTorch-" - "native implementation of rejection sampling. 
For the " - "best performance, please install FlashInfer.") - self.forward_method = self.forward_native - else: - self.forward_method = self.forward_native - - def forward(self, draft_token_ids: list[list[int]], - target_probs: torch.Tensor, - sampling_metadata: SamplingMetadata) -> SamplerOutput: - if not sampling_metadata.all_greedy: - raise NotImplementedError( - "Currently, only greedy sampling is supported by " - "rejection sampler.") - return self.forward_method(draft_token_ids, target_probs, - sampling_metadata) - - def flashinfer_sample( + def forward( self, draft_token_ids: list[list[int]], - target_probs: torch.Tensor, + draft_probs: Optional[torch.Tensor], + bonus_token_ids_tensor: torch.Tensor, # [batch_size, 1] + target_probs: torch.Tensor, # [num_total_tokens, vocab_size] sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> torch.Tensor: + ''' + Args: + draft_token_ids (List[List[int]]): + A 2D list of token IDs for each request in the batch. + Each request might have different number of draft tokens. + It may also contain empty lists for requests that have + no draft tokens. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [batch_size, max_spec_len, vocab_size]. Can be None if + probabilities are not provided, which is the case for + ngram spec decode. + bonus_token_ids_tensor (torch.Tensor): + A tensor containing bonus tokens. Shape is [batch_size, 1]. + Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling + process such as top_p, top_k sampling. + target_probs (torch.Tensor): + Target model probability distribution. + Shape is [num_total_tokens, vocab_size]. num_total_tokens + is the total number of tokens from all requests. 
Here, + probabilities from different requests are flattened into + a single tensor because this is the shape of the output + logits. + sampling_metadata (SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + ''' + # NOTE: The following input preparationg can be moved # to the model runner with a persistent manner for better # performance. - sample_lens = [len(x) + 1 for x in draft_token_ids] # Convert draft token IDs to a tensor, split by sample_lens, then pad. draft_token_ids = [ torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids @@ -90,90 +92,171 @@ class RejectionSampler(nn.Module): batch_first=True, padding_value=INVALID_TOKEN_ID) - if sampling_metadata.all_greedy: - target_token_ids = target_probs.argmax(dim=-1).view(-1) - target_token_ids = target_token_ids.split(sample_lens) - target_token_ids = pad_sequence(target_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) + # NOTE: CPU <-> GPU synchronization happens here. + draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device) + # Create one-hot tensor for draft token ids. + # This is used for ngram where we don't have draft_probs. + if draft_probs is None and not sampling_metadata.all_greedy: vocab_size = target_probs.size(-1) - # NOTE: CPU <-> GPU synchronization happens here. 
- draft_token_ids_tensor = draft_token_ids_tensor.to( - target_probs.device) draft_probs = _create_greedy_token_probs(draft_token_ids_tensor, vocab_size, target_probs.device) - target_probs = _create_greedy_token_probs(target_token_ids, - vocab_size, - target_probs.device) - uniform_samples = torch.zeros(draft_token_ids_tensor.size(0), - draft_token_ids_tensor.size(1) + 1, - device=target_probs.device) - else: - raise NotImplementedError( - "Currently, only greedy sampling is supported by " - "rejection sampler.") + sample_lens = [len(x) + 1 for x in draft_token_ids] + target_probs = _convert_2d_probs(target_probs, sample_lens) - sampled_token_ids, _, _ = fs.chain_speculative_sampling( - draft_probs, - draft_token_ids_tensor, - uniform_samples, - target_probs, - ) - return SamplerOutput(sampled_token_ids=sampled_token_ids, - logprobs_tensors=None) + return self.forward_native(draft_token_ids_tensor, draft_probs, + bonus_token_ids_tensor, target_probs, + sampling_metadata) # TODO: The following method can be optimized for better performance. def forward_native( self, - draft_token_ids: list[list[int]], + draft_token_ids_tensor: torch.Tensor, + # [batch_size, max_spec_len, vocab_size] + draft_probs: Optional[torch.Tensor], + bonus_token_ids_tensor: torch.Tensor, + # [batch_size, max_spec_len + 1, vocab_size] target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: - sample_lens = [len(x) + 1 for x in draft_token_ids] - # Convert draft token IDs to a tensor, split by sample_lens, then pad. - draft_token_ids = [ - torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids - ] - draft_token_ids_tensor = pad_sequence(draft_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) - draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device) + ) -> torch.Tensor: # Add 1 to include the 'bonus' token. 
if sampling_metadata.all_greedy: - output_token_ids = target_probs.argmax(dim=-1).view(-1) - output_token_ids = output_token_ids.split(sample_lens) - output_token_ids = pad_sequence(output_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) # Produce a mask that remains 1 (True) until the first # mismatch (cumprod turns 0 after a mismatch). - accept_mask = ( - output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod( - dim=1) - else: - raise NotImplementedError( - "Currently, only greedy sampling is supported by " - "rejection sampler.") - # Identify valid positions (non-padding). - valid_mask = output_token_ids != INVALID_TOKEN_ID - # Generate mask with bonus token. - generate_mask = torch.cat([ - accept_mask, - torch.zeros(accept_mask.size(0), 1, device=accept_mask.device) - ], - dim=1).to(torch.bool) & valid_mask - zeros_mask = (generate_mask == 0) - first_zero_idx = zeros_mask.float().argmax(dim=1) - # Figure out which rows actually contain at least one zero. - rows_with_zero = zeros_mask.any(dim=1) - # Use indexing to set the first zero in each of those rows to 1. - generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1 + target_token_ids_tensor = target_probs.argmax(dim=-1) + accept_mask = (target_token_ids_tensor[:, :-1] == + draft_token_ids_tensor).cumprod(dim=1) - output_token_ids[~generate_mask] = INVALID_TOKEN_ID - return SamplerOutput(sampled_token_ids=output_token_ids, - logprobs_tensors=None) + # Identify valid positions (non-padding). + valid_mask = target_token_ids_tensor != INVALID_TOKEN_ID + # Generate mask with bonus token. + generate_mask = torch.cat([ + accept_mask, + torch.zeros(accept_mask.size(0), 1, device=accept_mask.device) + ], + dim=1).to(torch.bool) & valid_mask + zeros_mask = (generate_mask == 0) + first_zero_idx = zeros_mask.float().argmax(dim=1) + # Figure out which rows actually contain at least one zero. 
+ rows_with_zero = zeros_mask.any(dim=1) + # Use indexing to set the first zero in each of those rows to 1. + generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1 + + output_token_ids = target_token_ids_tensor + output_token_ids[~generate_mask] = INVALID_TOKEN_ID + else: + # Reference: https://arxiv.org/pdf/2211.17192 + # 1. Extract the probabilities of the draft tokens. + # [batch_size, max_spec_len] + batch_size = draft_token_ids_tensor.size(0) + max_spec_len = draft_token_ids_tensor.size(1) + invalid_idx = draft_token_ids_tensor == INVALID_TOKEN_ID + draft_token_ids_tensor[invalid_idx] = 0 + assert draft_probs is not None + draft_token_probs = draft_probs.gather( + dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1) + target_token_probs = target_probs.gather( + dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1) + # Force the probabilities of invalid tokens to inf + # so that they are not accepted. + draft_token_probs[invalid_idx] = float('inf') + + # 2. Generate uniform samples. + # [batch_size, max_spec_len] + uniform_samples = _create_uniform_samples( + sampling_metadata.generators, batch_size, max_spec_len, + target_probs.device) + + # 3. Accept or reject the samples. + # [batch_size, max_spec_len] + # If the draft token probabilities are 0, set them to the smallest + # positive normal value representable by float32. + safe_draft_probs = torch.where(draft_token_probs > 0, + draft_token_probs, + torch.finfo(torch.float32).tiny) + accepted = uniform_samples <= target_token_probs / safe_draft_probs + accept_mask = accepted.cumprod(dim=1) + # Set the token ids to the draft token ids if accepted, otherwise + # set them to INVALID_TOKEN_ID. + accepted_token_ids = (draft_token_ids_tensor * accept_mask + + INVALID_TOKEN_ID * (1 - accept_mask)) + + # 4. Adjust the distribution for the recovered tokens. + # Clamp the bonus probabilities to the smallest positive normal + # value representable by float32. 
+ bonus_prob = torch.clamp(target_probs[:, :-1, :] - draft_probs, + min=torch.finfo(torch.float32).tiny) + normalized_bonus_prob = bonus_prob / bonus_prob.sum(dim=-1, + keepdim=True) + + # 5. Sample recovered token ids. + recovered_token_ids = random_sample( + normalized_bonus_prob, + sampling_metadata.generators).reshape(batch_size, max_spec_len) + + # 6. Get the final output token ids. + # output_token_ids = accepted_token_ids + + # recovered_token_ids + + # bonus_token_id + recovered_bonus_token_ids = torch.cat( + [recovered_token_ids, bonus_token_ids_tensor], dim=1) + # Generate mask with bonus tokens. + generate_mask = torch.cat([ + accept_mask, + torch.zeros(batch_size, 1, device=accept_mask.device) + ], + dim=1).to(torch.bool) + zeros_mask = (generate_mask == 0) + first_zero_idx = zeros_mask.float().argmax(dim=1) + output_token_ids = torch.cat([ + accepted_token_ids, + torch.full((batch_size, 1), + fill_value=INVALID_TOKEN_ID, + device=accept_mask.device) + ], + dim=1) + output_token_ids[torch.arange(batch_size), + first_zero_idx] = recovered_bonus_token_ids[ + torch.arange(batch_size), first_zero_idx] + + return output_token_ids + + def compute_probs(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + sample_lens: list[int]) -> torch.Tensor: + """ + Compute probability distribution from logits based on sampling metadata. + + This function applies temperature scaling to the logits and converts + them to probabilities using softmax. Note that division by + temperature is not performed inplace to preserve the original logits + tensor, which will be used by the original sampler to get bonus tokens. 
+ + Args: + logits: Input logits tensor to be converted to probabilities + sampling_metadata: Metadata containing sampling parameters such + as temperature and whether greedy sampling is used + sample_lens: List of sample lengths used for repeating + temperature values + + Returns: + torch.Tensor: Probability distribution (softmax of scaled logits) + if non-greedy sampling is used, otherwise returns the + original logits + """ + if sampling_metadata.all_greedy: + return logits + assert sampling_metadata.temperature is not None + # We should optimize the following code as + # it will cause CPU -> GPU synchronization. + temperature = torch.repeat_interleave( + sampling_metadata.temperature, + torch.tensor(sample_lens, + device=sampling_metadata.temperature.device)) + temperature = temperature.unsqueeze(dim=1) + logits = logits / temperature + return logits.softmax(dim=-1, dtype=torch.float32) def _create_greedy_token_probs( @@ -199,3 +282,66 @@ def _create_greedy_token_probs( src=valid_mask.unsqueeze(-1).float()) return token_probs + + +def _convert_2d_probs( + probs: torch.Tensor, # [num_total_tokens, vocab_size] + sample_lens: list[int]) -> torch.Tensor: + """ + Converts a 2D tensor of probabilities to a 3D tensor with padding. + [num_total_tokens, vocab_size] -> + [batch_size, max_spec_len + 1, vocab_size] + """ + cumulative_lens = torch.cumsum(torch.tensor(sample_lens, + device=probs.device), + dim=0) + split_indices = cumulative_lens[:-1].tolist() # Exclude last index + + # Split into chunks without loops + chunks = torch.tensor_split(probs, split_indices, dim=0) + + # Pad all sequences to maximum length + padded_probs = pad_sequence(chunks, batch_first=True, padding_value=0.0) + return padded_probs + + +def _create_uniform_samples(seeded_seqs: dict[int, torch.Generator], + batch_size: int, k: int, + device: torch.device) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + for specific sequences. 
+ + This method creates a tensor of shape `(batch_size, k)` filled + with uniform random values in the range [0, 1). If `seeded_seqs` + is provided, the sequences corresponding to specific indices + will be generated using the provided `torch.Generator` for + reproducibility. The other sequences will be generated without + a seed. + + Args: + seeded_seqs : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. + batch_size : int + The number of sequences to generate. + k : int + The number of random samples per sequence. + device : torch.device + The device on which to allocate the tensor. + + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(batch_size, k)` containing uniform + random values in the range [0, 1). + """ + + uniform_rand = torch.rand(batch_size, + k, + dtype=torch.float32, + device=device) + # Apply seeded generators only where needed + if seeded_seqs: + for idx, generator in seeded_seqs.items(): + uniform_rand[idx].uniform_(0, 1, generator=generator) + return uniform_rand diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 96f6d807b10ce..d91c057083f31 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -119,14 +119,6 @@ class Sampler(nn.Module): ) return sampled - def compute_probs(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - if sampling_metadata.all_greedy: - return logits - # Apply temperature. This is an in-place op changing logits. 
- logits = self.apply_temperature(logits, sampling_metadata.temperature) - return logits.softmax(dim=-1, dtype=torch.float32) - def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py new file mode 100644 index 0000000000000..5841401367788 --- /dev/null +++ b/vllm/v1/spec_decode/utils.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa +from vllm.v1.worker.gpu_input_batch import InputBatch + + +def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: + if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs: + # Spec decode doesn't support top_p/top_k sampling. + return False + elif req_id in input_batch.min_p_reqs: + # Spec decode doesn't support min_p sampling. + return False + elif (req_id in input_batch.frequency_penalties_reqs + or req_id in input_batch.presence_penalties_reqs + or req_id in input_batch.repetition_penalties_reqs): + # Spec decode doesn't support penalties. + return False + elif req_id in input_batch.num_logprobs: + # Spec decode doesn't support logprobs. 
+ return False + + return True diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4059d5b17b71b..2a98bea562dcb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -37,6 +37,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -1020,15 +1021,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): sampling_metadata=sampling_metadata, ) else: - target_probs = self.model.sampler.compute_probs( - logits, sampling_metadata) draft_token_ids = [ scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) for req_id in self.input_batch.req_ids ] - sampler_output = self.rejection_sampler(draft_token_ids, - target_probs, - sampling_metadata) + sample_lens = [len(tokens) + 1 for tokens in draft_token_ids] + recover_logits_idx = np.cumsum(sample_lens) - 1 + target_probs = self.rejection_sampler.compute_probs( + logits, sampling_metadata, sample_lens) + sampler_output = self.model.sample( + logits=logits[recover_logits_idx, :], + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + output_token_ids = self.rejection_sampler( + draft_token_ids, + None, # draft_probs + bonus_token_ids, + target_probs, + sampling_metadata) + sampler_output.sampled_token_ids = output_token_ids # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. 
@@ -1075,7 +1087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): spec_token_ids = None else: spec_token_ids = self.generate_draft_token_ids( - valid_sampled_token_ids) + valid_sampled_token_ids, sampling_metadata) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1089,6 +1101,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # TODO(woosuk): Optimize. draft_token_ids: list[list[int]] = [] @@ -1099,6 +1112,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): draft_token_ids.append([]) continue + # Skip requests that require top-p, top-k, etc. + req_id = self.input_batch.req_ids[i] + if not is_spec_decode_supported(req_id, self.input_batch): + draft_token_ids.append([]) + continue + # Add sampled_token_ids to token_ids_cpu. start_idx = self.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids