diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index c701ab1d35a5..07b422814e13 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -149,31 +149,37 @@ class StructuredOutputManager:
         # NOTE: This outer loop can likely be parallelized to improve
         # performance of bitmask generation for large batches.
         for req_id, _ in ordered_seq:
-            request = requests[req_id].structured_output_request
-            if TYPE_CHECKING:
-                assert request is not None
-                assert request.grammar is not None
+            request = requests[req_id]
+            structured_output_request = request.structured_output_request
 
-            apply_bitmask = (
-                request.reasoning_ended if self.reasoner is not None else True
-            )  # noqa: E501
+            if TYPE_CHECKING:
+                assert structured_output_request is not None
+                assert structured_output_request.grammar is not None
+            apply_bitmask: bool = True
+            if self.reasoner is not None:
+                if structured_output_request.reasoning_ended is None:
+                    structured_output_request.reasoning_ended = \
+                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
+                apply_bitmask = structured_output_request.reasoning_ended
 
             state_advancements = 0
             req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
             for i, token in enumerate(req_tokens):
-                if apply_bitmask and not request.grammar.is_terminated():
-                    request.grammar.fill_bitmask(bitmask_tensor,
-                                                 cumulative_index)
+                if apply_bitmask and not \
+                        structured_output_request.grammar.is_terminated():
+                    structured_output_request.grammar.fill_bitmask(
+                        bitmask_tensor, cumulative_index)
                 if token is not None:
                     # In order to generate the correct bitmask for each
                     # position in the speculative sequence, we advance
                     # the FSM state for each speculative token and rollback
                     # to restore the previous state when we are finished.
-                    assert request.grammar.accept_tokens(req_id, [token])
+                    assert structured_output_request.grammar.accept_tokens(
+                        req_id, [token])
                     state_advancements += 1
                 cumulative_index += 1
             if state_advancements > 0:
-                request.grammar.rollback(state_advancements)
+                structured_output_request.grammar.rollback(state_advancements)
 
         if cumulative_index < bitmask_tensor.shape[0]:
             bitmask_tensor = bitmask_tensor[:cumulative_index]
diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py
index c16320b9e74c..9a7e30d41aaa 100644
--- a/vllm/v1/structured_output/request.py
+++ b/vllm/v1/structured_output/request.py
@@ -20,7 +20,7 @@ class StructuredOutputRequest:
     sampling_params: SamplingParams
     _grammar: Optional[Union[Future[StructuredOutputGrammar],
                              StructuredOutputGrammar]] = None
-    reasoning_ended: bool = False
+    reasoning_ended: Optional[bool] = None
 
     def _check_grammar_completion(self) -> bool:
         # NOTE: We have to lazy import to gate circular imports
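
The behavioral effect of the diff, in brief: `reasoning_ended` becomes tri-state (`None` meaning "not yet determined") and is resolved lazily from the prompt token ids the first time a bitmask is built, so a request whose prompt already contains the end-of-reasoning marker gets the grammar bitmask applied from its first decode step. Below is a minimal, self-contained sketch of that decision logic under stated assumptions; `DummyReasoner`, its `END_TOKEN_ID`, and `should_apply_bitmask` are hypothetical stand-ins for illustration, not vLLM's actual classes or API.

```python
# Sketch of the tri-state reasoning_ended flag introduced above:
# None = unknown, resolved once from the prompt via the reasoner.
from dataclasses import dataclass
from typing import Optional


class DummyReasoner:
    """Hypothetical stand-in for a reasoning parser."""

    END_TOKEN_ID = 42  # hypothetical end-of-reasoning token id

    def is_reasoning_end(self, token_ids: list[int]) -> bool:
        # True once the end-of-reasoning marker has been seen.
        return self.END_TOKEN_ID in token_ids


@dataclass
class DummyStructuredOutputRequest:
    prompt_token_ids: list[int]
    # None = not yet checked; resolved lazily on the first bitmask pass.
    reasoning_ended: Optional[bool] = None


def should_apply_bitmask(
    req: DummyStructuredOutputRequest,
    reasoner: Optional[DummyReasoner],
) -> bool:
    # Mirrors the new logic: without a reasoner the bitmask always applies;
    # with one, reasoning_ended is resolved from the prompt exactly once.
    if reasoner is None:
        return True
    if req.reasoning_ended is None:
        req.reasoning_ended = reasoner.is_reasoning_end(req.prompt_token_ids)
    return req.reasoning_ended


if __name__ == "__main__":
    reasoner = DummyReasoner()
    # Prompt already contains the marker: bitmask applies immediately.
    req = DummyStructuredOutputRequest(prompt_token_ids=[1, 42, 7])
    print(should_apply_bitmask(req, reasoner))  # True
    # No marker yet: bitmask is skipped until reasoning ends.
    req2 = DummyStructuredOutputRequest(prompt_token_ids=[1, 2, 3])
    print(should_apply_bitmask(req2, reasoner))  # False
```

The design point this illustrates is the switch from a `False` default, which forced every request to be treated as mid-reasoning regardless of its prompt, to a lazily initialized value that also covers prompts that already close the reasoning section.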