From 3209b4903376fd723858b256dcfabb8420a0cc64 Mon Sep 17 00:00:00 2001 From: Nikola Borisov Date: Tue, 23 Jan 2024 22:38:55 -0800 Subject: [PATCH] [Bugfix] fix crash if max_tokens=None (#2570) --- tests/test_regression.py | 13 +++++++++++++ tests/test_sampling_params.py | 13 +++++++++++++ vllm/sampling_params.py | 4 ++-- 3 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 tests/test_sampling_params.py diff --git a/tests/test_regression.py b/tests/test_regression.py index 3bfb2b43f2644..c48e474bd889f 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group(): assert len(prompts) == len(outputs) +def test_max_tokens_none(): + sampling_params = SamplingParams(temperature=0.01, + top_p=0.1, + max_tokens=None) + llm = LLM(model="facebook/opt-125m", + max_num_batched_tokens=4096, + tensor_parallel_size=1) + prompts = ["Just say hello!"] + outputs = llm.generate(prompts, sampling_params=sampling_params) + + assert len(prompts) == len(outputs) + + if __name__ == "__main__": import pytest pytest.main([__file__]) diff --git a/tests/test_sampling_params.py b/tests/test_sampling_params.py new file mode 100644 index 0000000000000..01cbe0c997f29 --- /dev/null +++ b/tests/test_sampling_params.py @@ -0,0 +1,13 @@ +"""Tests for the SamplingParams class. +""" +from vllm import SamplingParams + + +def test_max_tokens_none(): + """max_tokens=None should be allowed""" + SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index b5710eef4ad50..bb7d0002c910c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -108,7 +108,7 @@ class SamplingParams: stop_token_ids: Optional[List[int]] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, - max_tokens: int = 16, + max_tokens: Optional[int] = 16, logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, @@ -183,7 +183,7 @@ class SamplingParams: if not 0.0 <= self.min_p <= 1.0: raise ValueError("min_p must be in [0, 1], got " f"{self.min_p}.") - if self.max_tokens < 1: + if self.max_tokens is not None and self.max_tokens < 1: raise ValueError( f"max_tokens must be at least 1, got {self.max_tokens}.") if self.logprobs is not None and self.logprobs < 0: