From 46ce1356f7108f27af9ebb153ddbe02d1a3fa97d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 24 Feb 2023 11:44:40 +0000 Subject: [PATCH] Add max_num_steps to SamplingParams --- cacheflow/sampling_params.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 143f0ab2a4dd1..241d248a0b602 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -10,6 +10,7 @@ class SamplingParams: top_p: float = 1.0, use_beam_search: bool = False, stop_token_ids: Set[int] = [], + max_num_steps: int = 16, # From OpenAI API. max_context_len: Optional[int] = None, ) -> None: assert n >= 1 @@ -23,6 +24,7 @@ class SamplingParams: # Zero temperature means greedy decoding. assert n == 1 assert top_p == 1.0 + assert max_num_steps >= 1 assert max_context_len is None or max_context_len >= 0 self.n = n @@ -30,4 +32,5 @@ class SamplingParams: self.top_p = top_p self.use_beam_search = use_beam_search self.stop_token_ids = stop_token_ids + self.max_num_steps = max_num_steps self.max_context_len = max_context_len