From 5629f26df7c0fe5126d90463e14292f4a0552a65 Mon Sep 17 00:00:00 2001
From: Lily Liu
Date: Tue, 25 Feb 2025 18:14:48 -0800
Subject: [PATCH] [V1][Spec Decode] Change Spec Decode Rejection Sampling API
 (#13729)

---
 tests/v1/sample/test_rejection_sampler.py |  17 ++-
 tests/v1/sample/test_sampler.py           |   1 -
 tests/v1/worker/test_gpu_input_batch.py   |   1 -
 vllm/v1/sample/metadata.py                |   3 -
 vllm/v1/sample/rejection_sampler.py       | 130 +++++++++++-----------
 vllm/v1/sample/sampler.py                 |  19 ++--
 vllm/v1/worker/gpu_input_batch.py         |  11 --
 vllm/v1/worker/gpu_model_runner.py        |  29 +++--
 8 files changed, 102 insertions(+), 109 deletions(-)

diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py
index 956d91c6daf73..f00585b40ba3f 100644
--- a/tests/v1/sample/test_rejection_sampler.py
+++ b/tests/v1/sample/test_rejection_sampler.py
@@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         temperature=torch.tensor([]),
         all_greedy=True,
         all_random=False,
-        spec_token_ids=spec_tokens,
         top_p=None,
         top_k=None,
         min_p=torch.empty(batch_size, ),
@@ -55,7 +54,7 @@ def test_perfect_match(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 3, 4]],
                             dtype=torch.int,
                             device=logits.device)
@@ -70,7 +69,7 @@ def test_early_mismatch(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
                             dtype=torch.int,
                             device=logits.device)
@@ -85,7 +84,7 @@ def test_multiple_sequences(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
                             dtype=torch.int,
                             device=logits.device)
@@ -100,7 +99,7 @@ def test_single_token_sequence(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
     assert torch.equal(output.sampled_token_ids, expected)
 
@@ -113,7 +112,7 @@ def test_empty_sequence(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
     assert torch.equal(output.sampled_token_ids, expected)
 
@@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
                              [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
                             dtype=torch.int,
@@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected_tensor = torch.tensor(expected,
                                    dtype=torch.int,
                                    device=logits.device)
@@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens, vocab_size)
 
-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
     assert torch.equal(output.sampled_token_ids, expected)
     assert logits.shape[-1] == vocab_size
diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index 34fba5a9f6d7f..435c1b7b5fda9 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -105,7 +105,6 @@ def _create_default_sampling_metadata(
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
-        spec_token_ids=None,
         frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 0aee266264acf..327370e71fffc 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata(
                            dtype=torch.float,
                            device=device),
         output_token_ids=output_token_ids,
-        spec_token_ids=None,
         min_tokens=min_tokens,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index 9f7770bbd078a..b757a1dc60c74 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -13,9 +13,6 @@ class SamplingMetadata:
     all_greedy: bool
     all_random: bool
 
-    # None when there are no speculated tokens.
-    spec_token_ids: Optional[List[List[int]]]
-
     top_p: Optional[torch.Tensor]
     top_k: Optional[torch.Tensor]
     min_p: Optional[torch.Tensor]
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 580ad44297aa0..2e3927345eb5f 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import List
+
 import torch
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
@@ -52,62 +54,62 @@ class RejectionSampler(nn.Module):
         else:
             self.forward_method = self.forward_native
 
-    def forward(self, logits: torch.Tensor,
+    def forward(self, draft_token_ids: List[List[int]],
+                target_probs: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
                 "Currently, only greedy sampling is supported by "
                 "rejection sampler.")
-        return self.forward_method(logits, sampling_metadata)
+        return self.forward_method(draft_token_ids, target_probs,
+                                   sampling_metadata)
 
     def flashinfer_sample(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
         # NOTE: The following input preparationg can be moved
         # to the model runner with a persistent manner for better
         # performance.
-        assert sampling_metadata.spec_token_ids is not None
-        spec_token_ids = sampling_metadata.spec_token_ids
-        max_spec_len = max(len(s) for s in spec_token_ids)
-        batch_size = len(spec_token_ids)
-        draft_token_ids = torch.full((batch_size, max_spec_len),
-                                     INVALID_TOKEN_ID,
-                                     device="cpu",
-                                     dtype=torch.long)
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
+        ]
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
 
-        target_token_ids = torch.full((batch_size, max_spec_len + 1),
-                                      fill_value=INVALID_TOKEN_ID,
-                                      device=logits.device,
-                                      dtype=torch.long)
+        if sampling_metadata.all_greedy:
+            target_token_ids = target_probs.argmax(dim=-1).view(-1)
+            target_token_ids = target_token_ids.split(sample_lens)
+            target_token_ids = pad_sequence(target_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
 
-        # TODO: Vectorize the following loop for better performance.
-        start_loc = 0
-        for i in range(batch_size):
-            num_spec_tokens = len(spec_token_ids[i])
-            draft_token_ids[i, :num_spec_tokens] = torch.tensor(
-                spec_token_ids[i], device="cpu", dtype=torch.long)
-            end_loc = start_loc + num_spec_tokens + 1
-            # Assume greedy sampling.
-            target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
-                logits[start_loc:end_loc], dim=-1)
-            start_loc = end_loc
-
-        vocab_size = logits.size(-1)
-        # NOTE: CPU <-> GPU synchronization happens here.
-        draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
-                                                 logits.device)
-        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
-                                                  logits.device)
-        uniform_samples = torch.zeros(batch_size,
-                                      max_spec_len + 1,
-                                      device=logits.device)
+            vocab_size = target_probs.size(-1)
+            # NOTE: CPU <-> GPU synchronization happens here.
+            draft_token_ids_tensor = draft_token_ids_tensor.to(
+                target_probs.device)
+            draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
+                                                     vocab_size,
+                                                     target_probs.device)
+            target_probs = _create_greedy_token_probs(target_token_ids,
+                                                      vocab_size,
+                                                      target_probs.device)
+            uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
+                                          draft_token_ids_tensor.size(1) + 1,
+                                          device=target_probs.device)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
 
         sampled_token_ids, _, _ = fs.chain_speculative_sampling(
             draft_probs,
-            draft_token_ids,
+            draft_token_ids_tensor,
             uniform_samples,
             target_probs,
         )
@@ -117,35 +119,35 @@ class RejectionSampler(nn.Module):
     # TODO: The following method can be optimized for better performance.
     def forward_native(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        assert sampling_metadata.spec_token_ids is not None
-        spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
-        # Add 1 to include the 'bonus' token.
-        sample_lens = [x + 1 for x in spec_lens]
-
-        output_token_ids = logits.argmax(dim=-1).view(-1)
-        output_token_ids = output_token_ids.split(sample_lens)
-        output_token_ids = pad_sequence(output_token_ids,
-                                        batch_first=True,
-                                        padding_value=INVALID_TOKEN_ID)
-
-        # Convert spec token IDs to a tensor, split by sample_lens, then pad.
-        spec_token_ids = [
-            torch.tensor(x,
-                         dtype=output_token_ids.dtype,
-                         device=output_token_ids.device)
-            for x in sampling_metadata.spec_token_ids
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
         ]
-        spec_token_ids = pad_sequence(spec_token_ids,
-                                      batch_first=True,
-                                      padding_value=INVALID_TOKEN_ID)
-
-        # Produce a mask that remains 1 (True) until the first
-        # mismatch (cumprod turns 0 after a mismatch).
-        accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
-            dim=1)
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
+        draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
+        # Add 1 to include the 'bonus' token.
+        if sampling_metadata.all_greedy:
+            output_token_ids = target_probs.argmax(dim=-1).view(-1)
+            output_token_ids = output_token_ids.split(sample_lens)
+            output_token_ids = pad_sequence(output_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
+            # Produce a mask that remains 1 (True) until the first
+            # mismatch (cumprod turns 0 after a mismatch).
+            accept_mask = (
+                output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
+                    dim=1)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
         # Identify valid positions (non-padding).
         valid_mask = output_token_ids != INVALID_TOKEN_ID
         # Generate mask with bonus token.
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 47ec26d420249..b0eb533ae2e58 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -9,7 +9,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                           apply_min_token_penalties)
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
-from vllm.v1.sample.rejection_sampler import RejectionSampler
 
 _SAMPLING_EPS = 1e-5
 
@@ -19,22 +18,12 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
-        self.rejection_sampler = RejectionSampler()
 
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        if sampling_metadata.spec_token_ids:
-            if sampling_metadata.max_num_logprobs:
-                raise NotImplementedError(
-                    "Rejection sampling does not support logprobs.")
-            return self.rejection_sampler(
-                logits,
-                sampling_metadata,
-            )
-
         # NOTE(woosuk): Use the original logits (before any penalties or
         # temperature scaling) for the top-k logprobs.
         # This is different from the V0 sampler, which uses the logits that
@@ -127,6 +116,14 @@ class Sampler(nn.Module):
         )
         return sampled
 
+    def compute_probs(self, logits: torch.Tensor,
+                      sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        if sampling_metadata.all_greedy:
+            return logits
+        # Apply temperature. This is an in-place op changing logits.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        return logits.softmax(dim=-1, dtype=torch.float32)
+
     def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
         return logits.log_softmax(dim=-1, dtype=torch.float32)
 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index d9fc53490c076..e4e6b88245d0d 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -490,23 +490,12 @@ class InputBatch:
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(List[List[int]], self.req_output_token_ids),
-            spec_token_ids=None,
             min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
             allowed_token_ids_mask=allowed_token_ids_mask,
         )
 
-    def get_sampling_metadata(
-        self,
-        req_id_to_spec_token_ids: Dict[str, List[int]],
-    ) -> SamplingMetadata:
-        # Set the new spec token ids in the cached sampling metadata.
-        self.sampling_metadata.spec_token_ids = [
-            req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
-        ] if req_id_to_spec_token_ids else None
-        return self.sampling_metadata
-
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1fbce3098a346..4d0ae9a205a15 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -32,7 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
+from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -122,7 +122,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_spec_decode = False
         if self.speculative_config:
             self.use_spec_decode = True
-
+            self.rejection_sampler = RejectionSampler()
             # TODO: find a better way to check if we are using ngram.
             assert self.speculative_config.ngram_prompt_lookup_min, \
                 "Currently, only ngram spec decode is supported in V1."
@@ -951,12 +951,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logits = self.model.compute_logits(sample_hidden_states, None)
 
         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.get_sampling_metadata(
-            scheduler_output.scheduled_spec_decode_tokens)
-        sampler_output = self.model.sample(
-            logits=logits,
-            sampling_metadata=sampling_metadata,
-        )
+        sampling_metadata = self.input_batch.sampling_metadata
+        if not self.use_spec_decode:
+            sampler_output = self.model.sample(
+                logits=logits,
+                sampling_metadata=sampling_metadata,
+            )
+        else:
+            target_probs = self.model.sampler.compute_probs(
+                logits, sampling_metadata)
+            scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
+            )
+            draft_token_ids = [
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
+                for req_id in scheduled_request_ids
+            ]
+            sampler_output = self.rejection_sampler(draft_token_ids,
+                                                    target_probs,
+                                                    sampling_metadata)
 
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
@@ -1293,7 +1305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             temperature=dummy_tensors(0.5),
             all_greedy=False,
             all_random=False,
-            spec_token_ids=None,
             top_p=dummy_tensors(0.9),
             top_k=dummy_tensors(logits.size(1) - 1),
             min_p=None,
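
Editor's note, not part of the patch: after this change, RejectionSampler.forward() takes the per-request draft token ids and the target probabilities explicitly, as forward(draft_token_ids, target_probs, sampling_metadata), instead of reading spec_token_ids from SamplingMetadata. The model runner obtains target_probs from Sampler.compute_probs(), which simply returns the logits in the all-greedy case. The sketch below illustrates the new call pattern as a hypothetical extra test; it is written as if it lived in tests/v1/sample/test_rejection_sampler.py so it can reuse that file's existing `sampler` fixture and its create_sampling_metadata / create_logits_tensor helpers. The test name and token values are illustrative only.

import torch

from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID


def test_new_call_signature(sampler):
    # Request 0 drafts [1, 2, 3]; request 1 drafts [4, 5].
    spec_tokens = [[1, 2, 3], [4, 5]]
    # Greedy target output for each draft position plus one bonus position.
    output_tokens = [[1, 2, 3, 4], [4, 6, 7]]

    metadata = create_sampling_metadata(spec_tokens)  # all_greedy=True
    logits = create_logits_tensor(output_tokens)

    # New API: draft token ids are passed explicitly, and the second argument
    # is the target distribution (for greedy decoding, the raw logits).
    output = sampler(spec_tokens, logits, metadata)

    # Request 0 accepts every draft and appends the bonus token; request 1
    # mismatches at its second draft, so the remaining slots are padded
    # with INVALID_TOKEN_ID.
    expected = torch.tensor(
        [[1, 2, 3, 4], [4, 6, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
        dtype=torch.int,
        device=logits.device)
    assert torch.equal(output.sampled_token_ids, expected)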