diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 192ddefe102d..2486c26c6071 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata( if req.sampling_params.allowed_token_ids: allowed_token_ids_mask[index_in_input_batch][ req.sampling_params.allowed_token_ids] = True - bad_words_token_ids[ - index_in_input_batch] = req.sampling_params.bad_words_token_ids + if req.sampling_params.bad_words_token_ids: + bad_words_token_ids[ + index_in_input_batch] = req.sampling_params.bad_words_token_ids return SamplingMetadata( temperature=torch.tensor(temperature, dtype=torch.float, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index b0a5777cc8d5..9b474a37b96b 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -235,7 +235,7 @@ class SamplingParams( # Fields used for bad words bad_words: Optional[list[str]] = None - _bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list) + _bad_words_token_ids: Optional[list[list[int]]] = None @staticmethod def from_optional( @@ -464,8 +464,9 @@ class SamplingParams( self.stop_token_ids = list(eos_ids) def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: - if self.bad_words is None: + if not self.bad_words: return + self._bad_words_token_ids = [] for bad_word in self.bad_words: # To prohibit words both at the beginning # and in the middle of text @@ -516,7 +517,7 @@ class SamplingParams( return self._all_stop_token_ids @property - def bad_words_token_ids(self) -> list[list[int]]: + def bad_words_token_ids(self) -> Optional[list[list[int]]]: # For internal use only. Backward compatibility not guaranteed return self._bad_words_token_ids diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 9707cb5774cd..55d5429a8935 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -324,8 +324,9 @@ class InputBatch: self.allowed_token_ids_mask_cpu_tensor[req_index][ sampling_params.allowed_token_ids] = False - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + if sampling_params.bad_words_token_ids: + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids # Add request lora ID if request.lora_request: