From 9f88db35da012544a9cd9450c68b8df2e6509b92 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 10 May 2023 12:51:36 -0700
Subject: [PATCH] Support top-k sampling (#94)

---
 cacheflow/model_executor/layers/sampler.py | 53 ++++++++++++------
 cacheflow/sampling_params.py               | 63 +++++++++++++---------
 simple_server.py                           |  5 +-
 3 files changed, 78 insertions(+), 43 deletions(-)

diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py
index b0bc83421d70..7a7e53fbd533 100644
--- a/cacheflow/model_executor/layers/sampler.py
+++ b/cacheflow/model_executor/layers/sampler.py
@@ -46,12 +46,13 @@ class Sampler(nn.Module):
         # Compute the log probabilities (before applying top-p).
         logprobs = torch.log(probs)
 
-        # Apply top-p truncation.
-        top_ps = _get_top_ps(input_metadata)
-        assert len(top_ps) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps):
+        # Apply top-p and top-k truncation.
+        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
             p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            probs = _apply_top_p(probs, p)
+            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
+            probs = _apply_top_p_top_k(probs, p, k)
 
         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
@@ -94,31 +95,51 @@ def _get_temperatures(
     return temperatures
 
 
-def _get_top_ps(
+def _get_top_p_top_k(
     input_metadata: InputMetadata,
-) -> List[float]:
+    vocab_size: int,
+) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
+    top_ks: List[int] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
+        top_p = sampling_params.top_p
+        # k should not be greater than the vocab size.
+        top_k = min(sampling_params.top_k, vocab_size)
+        # k=-1 means no truncation.
+        top_k = vocab_size if top_k == -1 else top_k
         if i < input_metadata.num_prompts:
             # A prompt input.
-            top_ps.append(sampling_params.top_p)
+            top_ps.append(top_p)
+            top_ks.append(top_k)
         else:
             # A generation token.
-            top_ps += [sampling_params.top_p] * len(seq_ids)
-    return top_ps
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+    return top_ps, top_ks
 
 
-def _apply_top_p(
+def _apply_top_p_top_k(
     probs: torch.Tensor,
     p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
     # TODO(woosuk): Optimize.
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+
+    # Apply top-p.
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[mask] = 0.0
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    probs_sort[top_p_mask] = 0.0
+
+    # Apply top-k.
+    # Create a mask for the top-k elements.
+    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
+    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
+    probs_sort[top_k_mask] = 0.0
+
+    # Re-sort the probabilities.
     probs = torch.gather(
         probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
     return probs
@@ -160,7 +181,7 @@ def _sample_from_prompt(
         next_token_id = torch.argmax(prob)
         next_token_ids = [next_token_id.item()]
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample n tokens for the prompt.
         n = sampling_params.n
         next_token_ids = torch.multinomial(
@@ -218,7 +239,7 @@ def _sample_from_generation_tokens(
         next_token_ids = [next_token_id.item()]
         parent_seq_ids = seq_ids
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample 1 token for each sequence in the group.
         next_token_ids = torch.multinomial(
             probs, num_samples=1, replacement=True)
diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py
index 8e64126ae46a..589874ebcf66 100644
--- a/cacheflow/sampling_params.py
+++ b/cacheflow/sampling_params.py
@@ -8,69 +8,82 @@ class SamplingParams:
         n: int,
         temperature: float,
         top_p: float,
+        top_k: int,
         use_beam_search: bool,
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
     ) -> None:
         if n < 1:
-            raise ValueError(f'n must be at least 1, got {n}.')
+            raise ValueError(f"n must be at least 1, got {n}.")
         if temperature < 0.0:
             raise ValueError(
-                f'temperature must be non-negative, got {temperature}.')
+                f"temperature must be non-negative, got {temperature}.")
         if not 0.0 < top_p <= 1.0:
-            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
+        if top_k < -1 or top_k == 0:
+            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+                             f"got {top_k}.")
         if max_num_steps < 1:
             raise ValueError(
-                f'max_num_steps must be at least 1, got {max_num_steps}.')
+                f"max_num_steps must be at least 1, got {max_num_steps}.")
         if num_logprobs < 0:
             raise ValueError(
-                f'num_logprobs must be non-negative, got {num_logprobs}.')
+                f"num_logprobs must be non-negative, got {num_logprobs}.")
 
         if use_beam_search:
             if n == 1:
                 raise ValueError(
-                    'n must be greater than 1 when using beam search.')
+                    "n must be greater than 1 when using beam search.")
             if temperature > 0.0:
                 raise ValueError(
-                    'temperature must be 0 when using beam search.')
+                    "temperature must be 0 when using beam search.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using beam search.')
+                    "top_p must be 1 when using beam search.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using beam search.")
         elif temperature == 0.0:
             # Zero temperature means greedy sampling.
             if n > 1:
                 raise ValueError(
-                    'n must be 1 when using greedy sampling.')
+                    "n must be 1 when using greedy sampling.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using greedy sampling.')
+                    "top_p must be 1 when using greedy sampling.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using greedy sampling.")
 
         self.n = n
         self.temperature = temperature
         self.top_p = top_p
+        self.top_k = top_k
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
 
     def __repr__(self) -> str:
-        return (f'SamplingParams(n={self.n}, '
-                f'temperature={self.temperature}, '
-                f'top_p={self.top_p}, '
-                f'use_beam_search={self.use_beam_search}, '
-                f'stop_token_ids={self.stop_token_ids}, '
-                f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}')
+        return (f"SamplingParams(n={self.n}, "
+                f"temperature={self.temperature}, "
+                f"top_p={self.top_p}, "
+                f"top_k={self.top_k},"
+                f"use_beam_search={self.use_beam_search}, "
+                f"stop_token_ids={self.stop_token_ids}, "
+                f"max_num_steps={self.max_num_steps}, "
+                f"num_logprobs={self.num_logprobs}")
 
     @classmethod
-    def from_dict(cls, d: Dict) -> 'SamplingParams':
+    def from_dict(cls, d: Dict) -> "SamplingParams":
         return cls(
-            n=d.get('n', 1),
-            temperature=d.get('temperature', 1.0),
-            top_p=d.get('top_p', 1.0),
-            use_beam_search=d.get('use_beam_search', False),
-            stop_token_ids=set(d.get('stop_token_ids', set())),
-            max_num_steps=d.get('max_num_steps', 16),
-            num_logprobs=d.get('num_logprobs', 0),
+            n=d.get("n", 1),
+            temperature=d.get("temperature", 1.0),
+            top_p=d.get("top_p", 1.0),
+            top_k=d.get("top_k", -1),
+            use_beam_search=d.get("use_beam_search", False),
+            stop_token_ids=set(d.get("stop_token_ids", set())),
+            max_num_steps=d.get("max_num_steps", 16),
+            num_logprobs=d.get("num_logprobs", 0),
        )
diff --git a/simple_server.py b/simple_server.py
index 0fff63fa6559..9644731cf0c3 100644
--- a/simple_server.py
+++ b/simple_server.py
@@ -11,8 +11,9 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("What is the meaning of life?", {"n": 3, "temperature": 0.8, "top_p": 0.99}),
-        ("It is only with the heart that one can see rightly", {"n": 4, "use_beam_search": True, "temperature": 0.0}),
+        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
+        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+        ("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True:
         if test_inputs:
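
The masking strategy added in _apply_top_p_top_k can be exercised on its own. The sketch below is a standalone, illustrative reproduction (the function name and the example tensors are made up for this note; only torch is assumed): probabilities are sorted in descending order, tokens outside the nucleus p and tokens ranked at position k or later are zeroed out, and the surviving mass is scattered back to the original token order. As in the patched code, the result is intentionally left unnormalized, since torch.multinomial accepts unnormalized weights.

import torch


def apply_top_p_top_k(probs: torch.Tensor,
                      p: torch.Tensor,
                      k: torch.Tensor) -> torch.Tensor:
    # Sort each row in descending order; keep the permutation to undo it later.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

    # Top-p: zero out tokens whose preceding cumulative mass already exceeds p.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    probs_sort[top_p_mask] = 0.0

    # Top-k: zero out tokens ranked at position k or later.
    ranks = torch.arange(probs_sort.shape[-1], device=probs.device)
    top_k_mask = ranks.expand(probs_sort.shape[0], -1) >= k.unsqueeze(dim=1)
    probs_sort[top_k_mask] = 0.0

    # Scatter the surviving probabilities back to the original token order.
    return torch.gather(
        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))


# Two sequences over a 4-token vocabulary: the first uses top_p=0.8 only
# (k=4 keeps every rank), the second uses top_k=2 only (p=1.0 keeps everything).
probs = torch.tensor([[0.5, 0.3, 0.1, 0.1],
                      [0.4, 0.3, 0.2, 0.1]])
p = torch.tensor([0.8, 1.0])
k = torch.tensor([4, 2])
print(apply_top_p_top_k(probs, p, k))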
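
On the API side, top_k joins the existing SamplingParams fields. Per the validation added above, top_k=-1 disables truncation (and is required when beam search or greedy sampling is used), top_k=0 is rejected, and values larger than the vocabulary size are clamped inside the sampler. Below is a minimal usage sketch, assuming the patched constructor signature and from_dict defaults; the argument values are only examples.

from cacheflow.sampling_params import SamplingParams

# Direct construction: every field is explicit in the patched __init__.
params = SamplingParams(
    n=1,
    temperature=0.8,
    top_p=1.0,
    top_k=5,                 # sample only from the 5 most likely tokens
    use_beam_search=False,
    stop_token_ids=set(),
    max_num_steps=16,
    num_logprobs=0,
)

# from_dict fills in defaults; top_k defaults to -1 (no truncation),
# mirroring the test inputs in simple_server.py.
params = SamplingParams.from_dict({"temperature": 0.8, "top_k": 5})

# Invalid: top_k=0 has no meaning, so the constructor raises ValueError.
try:
    SamplingParams.from_dict({"top_k": 0})
except ValueError as e:
    print(e)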