From 4cea74c73b2e0981aadfefb3a00e8186d065c897 Mon Sep 17 00:00:00 2001 From: ljss <31004720+beginlner@users.noreply.github.com> Date: Thu, 23 Nov 2023 04:51:09 +0800 Subject: [PATCH] Set top_p=0 and top_k=-1 in greedy sampling (#1748) --- vllm/sampling_params.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 68ba762f899d..f9eca1a9fc43 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -147,6 +147,8 @@ class SamplingParams: self._verify_non_beam_search() if self.temperature < _SAMPLING_EPS: # Zero temperature means greedy sampling. + self.top_p = 1.0 + self.top_k = -1 self._verify_greedy_sampling() def _verify_args(self) -> None: @@ -214,10 +216,6 @@ class SamplingParams: if self.best_of > 1: raise ValueError("best_of must be 1 when using greedy sampling." f"Got {self.best_of}.") - if self.top_p < 1.0 - _SAMPLING_EPS: - raise ValueError("top_p must be 1 when using greedy sampling.") - if self.top_k != -1: - raise ValueError("top_k must be -1 when using greedy sampling.") @cached_property def sampling_type(self) -> SamplingType: