diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 2fe177ea4e126..02b2aa3ea6778 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -264,9 +264,12 @@ class InputBatch: self.top_p_cpu[req_index] = sampling_params.top_p if sampling_params.top_p < 1: self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: + top_k = sampling_params.top_k + if 0 < top_k < self.vocab_size: self.top_k_reqs.add(req_id) + else: + top_k = self.vocab_size + self.top_k_cpu[req_index] = top_k self.min_p_cpu[req_index] = sampling_params.min_p self.frequency_penalties_cpu[ req_index] = sampling_params.frequency_penalty