diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index ae7280a147063..802fe566ae2ed 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -721,15 +721,29 @@ class Scheduler(SchedulerInterface):
                 # the outer lists can be of length > 1.
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
 
+            jump_tokens = []
             if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
+                assert request.structured_output_request is not None
+                assert request.structured_output_request.grammar is not None
+                request.structured_output_request.grammar.accept_tokens(
                     req_id, new_token_ids)
+                if not stopped:
+                    jump_tokens = request.structured_output_request.grammar.jump_forward(
+                        req_id)
+                    for token_id in jump_tokens:
+                        request.append_output_token_ids(token_id)
+                        new_token_ids.append(token_id)
+                        stopped = check_stop(request, self.max_model_len)
+                        if stopped:
+                            break
+            if jump_tokens:
+                logger.debug("jump_tokens: %s", jump_tokens)
+
             # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
+            if jump_tokens:
+                request.spec_token_ids.clear()
+            elif spec_token_ids is not None:
                 if request.use_structured_output:
                     metadata = request.structured_output_request
                     assert metadata is not None and metadata.grammar is not None
diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py
index c82a3cab2fa36..45c87bb3c96d3 100644
--- a/vllm/v1/structured_output/backend_xgrammar.py
+++ b/vllm/v1/structured_output/backend_xgrammar.py
@@ -170,6 +170,28 @@ class XgrammarGrammar(StructuredOutputGrammar):
         self.num_processed_tokens += 1
         return True
 
+    def jump_forward(
+        self,
+        request_id: str,
+    ) -> list[int]:
+        """Greedily accept tokens the grammar forces (positions whose bitmask
+        allows exactly one token) and return the accepted token ids."""
+        bitmask = xgr.allocate_token_bitmask(1, self.vocab_size)
+        jump_forward_tokens: list[int] = []
+        while not self.is_terminated():
+            self.fill_bitmask(bitmask, 0)
+            is_single, unique_token_id = xgr.testing._is_single_token_bitmask(
+                bitmask,
+                vocab_size=self.vocab_size,
+                index=0,
+            )
+            if not is_single:
+                break
+
+            self.accept_tokens(request_id, [unique_token_id])
+            jump_forward_tokens.append(unique_token_id)
+        return jump_forward_tokens
+
     def validate_tokens(self, tokens: list[int]) -> list[int]:
         """Checks if the list of tokens are accepted by the FSM in sequence.
         Will not advance the FSM.