diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index efd51c79c37c..3405aaebf6a8 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -260,6 +260,7 @@ async def async_request_openai_completions( if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, + "repetition_penalty": 1.0, "max_tokens": request_func_input.output_len, "logprobs": request_func_input.logprobs, "stream": True, diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 74ee00ec8930..7c40e39ac810 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -123,6 +123,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, copy.deepcopy(schema) for _ in range(args.num_prompts) ] for i in range(len(json_schemas)): + if "properties" not in json_schemas[i]: + json_schemas[i]["properties"] = {} json_schemas[i]["properties"][ f"__optional_field_{uuid.uuid4()}"] = { "type": @@ -134,7 +136,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, json_schemas = [schema] * args.num_prompts def gen_prompt(index: int): - return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501 + return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501 def get_schema(index: int): return json_schemas[index % len(json_schemas)] @@ -231,7 +233,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, idx -= len_dataset schema = dataset["schema"][idx] prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], - tokenize=False) + tokenize=False, + add_generation_prompt=True) input_len = len(tokenizer(prompt).input_ids) completion = dataset["completion"][idx] @@ -849,7 +852,7 @@ if __name__ == "__main__": 'json', 'json-unique', 'grammar', 'regex', 'choice', 'xgrammar_bench' ]) - parser.add_argument("--json_schema_path", + parser.add_argument("--json-schema-path", type=str, default=None, help="Path to json schema.") diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 29ec6088ee8b..d25699591145 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -16,13 +16,31 @@ from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams +NGRAM_SPEC_CONFIG = { + "model": "[ngram]", + "num_speculative_tokens": 5, + "prompt_lookup_max": 5, + "prompt_lookup_min": 1, +} + +EAGLE_SPEC_CONFIG = { + "method": "eagle", + "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", + "num_speculative_tokens": 5, +} + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"), - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"), + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), #FIXME: This test is flaky on CI thus disabled #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", + NGRAM_SPEC_CONFIG), + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), + ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", + EAGLE_SPEC_CONFIG) ] PARAMS_MODELS_TOKENIZER_MODE = [ @@ -45,8 +63,9 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode", - PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) +@pytest.mark.parametrize( + "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) def test_structured_output( monkeypatch: pytest.MonkeyPatch, sample_json_schema: dict[str, Any], @@ -58,6 +77,7 @@ def test_structured_output( guided_decoding_backend: str, tokenizer_mode: str, model_name: str, + speculative_config: dict[str, Any], ): monkeypatch.setenv("VLLM_USE_V1", "1") @@ -71,7 +91,8 @@ def test_structured_output( max_model_len=1024, guided_decoding_backend=guided_decoding_backend, guided_decoding_disable_any_whitespace=True, - tokenizer_mode=tokenizer_mode) + tokenizer_mode=tokenizer_mode, + speculative_config=speculative_config) # # Test 1: Generate JSON output based on a provided schema diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 21711c9292f9..7ebbb4954f51 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -441,7 +441,7 @@ class Scheduler(SchedulerInterface): grammar_bitmask = self.structured_output_manager.grammar_bitmask( self.requests, structured_output_request_ids, - len(self.running), + scheduled_spec_decode_tokens, ) # Construct the scheduler output. new_reqs_data = [ @@ -682,10 +682,6 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager.free_encoder_input( request, input_id) - # Add newly generated spec token ids to the request. - if spec_token_ids is not None: - request.spec_token_ids = spec_token_ids[req_index] - stopped = False new_logprobs = None new_token_ids = generated_token_ids @@ -717,6 +713,17 @@ class Scheduler(SchedulerInterface): request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] req_id, new_token_ids) + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + if request.use_structured_output: + metadata = request.structured_output_request + assert metadata is not None and metadata.grammar is not None + # Needs to happen after new_token_ids are accepted. + request.spec_token_ids = metadata.grammar.validate_tokens( + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] + # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 47ae4c4f03ee..3183edb7c94e 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -27,6 +27,7 @@ class StructuredOutputManager: def __init__(self, vllm_config: VllmConfig): self.backend: Optional[StructuredOutputBackend] = None self.vllm_config = vllm_config + self._grammar_bitmask: Optional[torch.Tensor] = None # The default max_workers if not specified is the number of CPUs * 5, @@ -80,7 +81,7 @@ class StructuredOutputManager: self, requests: dict[str, Request], structured_output_request_ids: dict[str, int], - batch_len: int, + scheduled_spec_decode_tokens: dict[str, list[int]], ) -> Optional[npt.NDArray[np.int32]]: # Prepare the structured output bitmask for this batch. if not structured_output_request_ids: @@ -88,20 +89,52 @@ class StructuredOutputManager: if self._grammar_bitmask is None: assert self.backend is not None - self._grammar_bitmask = self.backend.allocate_token_bitmask( - self.vllm_config.scheduler_config.max_num_seqs) + max_batch_size = self.vllm_config.scheduler_config.max_num_seqs + if self.vllm_config.speculative_config is not None: + max_num_spec_tokens = self.vllm_config.\ + speculative_config.num_speculative_tokens + else: + max_num_spec_tokens = 0 - # Fill the bitmask using the index of each request equal to its - # position in the batch. Resize the bitmask down to the size of - # the batch. - bitmask_tensor = self._grammar_bitmask - for req_id, batch_index in structured_output_request_ids.items(): + # Allocate a bitmask for each token needing to be checked: + # one for each speculative position, and one more for the + # bonus token / non-speculative token. + self._grammar_bitmask = \ + self.backend.allocate_token_bitmask( + max_batch_size * (1 + max_num_spec_tokens)) + + # Generate a batched bitmask for all structured output requests. + # When speculative decoding is enabled, we need to include multiple + # masks for each request, one for each possible bonus token position. + # These are stored inline in the tensor and unpacked by the gpu runner. + cumulative_index = 0 + ordered_seq = sorted(structured_output_request_ids.items(), + key=lambda x: x[1]) + # NOTE: This outer loop can likely be parallelized to improve + # performance of bitmask generation for large batches. + for req_id, _ in ordered_seq: request = requests[req_id].structured_output_request assert request is not None and request.grammar is not None - if not request.grammar.is_terminated(): - request.grammar.fill_bitmask(bitmask_tensor, batch_index) - if batch_len < self._grammar_bitmask.shape[0]: - bitmask_tensor = self._grammar_bitmask[:batch_len] + state_advancements = 0 + req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] + for i, token in enumerate(req_tokens): + if not request.grammar.is_terminated(): + request.grammar.fill_bitmask(self._grammar_bitmask, + cumulative_index) + if token is not None: + # In order to generate the correct bitmask for each + # position in the speculative sequence, we advance + # the FSM state for each speculative token and rollback + # to restore the previous state when we are finished. + assert request.grammar.accept_tokens(req_id, [token]) + state_advancements += 1 + cumulative_index += 1 + if state_advancements > 0: + request.grammar.rollback(state_advancements) + + bitmask_tensor = self._grammar_bitmask + if cumulative_index < self._grammar_bitmask.shape[0]: + bitmask_tensor = self._grammar_bitmask[:cumulative_index] # After finishing with the xgrammar operations, we convert to # np.ndarray, because that is much more efficient for serialization diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 8fb3e56bcb95..0ab175e781e7 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -144,6 +144,27 @@ class GuidanceGrammar(StructuredOutputGrammar): return r + def validate_tokens(self, tokens: list[int]) -> list[int]: + """Checks if the list of tokens are accepted by the parser in sequence. + Will not advance the parser. + + Returns the prefix list of tokens that are accepted by the parser. + """ + if len(tokens) == 0: + return [] + if self.ll_matcher.is_stopped(): + return [] + + num_tokens = self.ll_matcher.validate_tokens(tokens) + + self.check_error() + + return tokens[:num_tokens] + + def rollback(self, num_tokens: int) -> None: + self.ll_matcher.rollback(num_tokens) + self.check_error() + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: # this will automatically return [EOS] mask if the matcher is stopped # or otherwise in an error state diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 6330bcbf20c3..33ca9f8cf484 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -35,6 +35,30 @@ class StructuredOutputGrammar(ABC): bool: True if the tokens are accepted, False otherwise. """ + @abstractmethod + def validate_tokens(self, tokens: list[int]) -> list[int]: + """ + Validates the provided tokens against the grammar. + Will not advance the FSM. + + Args: + tokens (list[int]): A list of token IDs to validate. + + Returns: + list[int]: A list of accepted token IDs. Will be a prefix + of the input tokens, and empty if none are accepted. + """ + + @abstractmethod + def rollback(self, num_tokens: int) -> None: + """ + Rolls back the state of the grammar by a specified number of tokens. + Will also revert counters for the number of processed tokens. + + Args: + num_tokens (int): The number of tokens to roll back. + """ + @abstractmethod def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: """ diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 50a7d1683acd..c82a3cab2fa3 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -40,6 +40,11 @@ class XgrammarBackend(StructuredOutputBackend): self.disable_any_whitespace = \ vllm_config.decoding_config.disable_any_whitespace + self.num_speculative_tokens = 0 + if self.vllm_config.speculative_config is not None: + self.num_speculative_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens + tokenizer = tokenizer_group.get_lora_tokenizer(None) self.vocab_size = vllm_config.model_config.get_vocab_size() if isinstance(tokenizer, MistralTokenizer): @@ -118,7 +123,10 @@ class XgrammarBackend(StructuredOutputBackend): f"grammar is not of valid supported types. ({request_type!s})") return XgrammarGrammar( - matcher=xgr.GrammarMatcher(ctx), + matcher=xgr.GrammarMatcher( + ctx, + max_rollback_tokens=self.num_speculative_tokens, + ), vocab_size=self.vocab_size, ctx=ctx, ) @@ -136,7 +144,6 @@ class XgrammarGrammar(StructuredOutputGrammar): # supporting different backends, in the future. # For now, just xgrammar. # - # TODO: support max_rollback_tokens # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string # for jump-forward decoding @@ -163,6 +170,27 @@ class XgrammarGrammar(StructuredOutputGrammar): self.num_processed_tokens += 1 return True + def validate_tokens(self, tokens: list[int]) -> list[int]: + """Checks if the list of tokens are accepted by the FSM in sequence. + Will not advance the FSM. + + Returns the prefix list of tokens that are accepted by the FSM. + """ + accepted_tokens = [] + for token in tokens: + if self.matcher.accept_token(token): + accepted_tokens.append(token) + else: + break + if len(accepted_tokens) > 0: + # Rollback the FSM to the initial state + self.matcher.rollback(len(accepted_tokens)) + return accepted_tokens + + def rollback(self, num_tokens: int) -> None: + self.matcher.rollback(num_tokens) + self.num_processed_tokens -= num_tokens + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: self.matcher.fill_next_token_bitmask(bitmask, idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 41de305a016e..97d8c91b4659 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -957,46 +957,58 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", logits: torch.Tensor, ): - # Serialization of np.ndarray is much more efficient than a tensor, - # so we receive it in that format. grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is None: return - # We receive the structured output bitmask from the scheduler, but the - # indices of the requests in the batch may not match the indices of - # the bitmask since the scheduler doesn't know how the gpu runner is - # ordering the requests in the batch. We need to sort the bitmask to - # match the order of the requests used here. + # We receive the structured output bitmask from the scheduler, + # compacted to contain bitmasks only for structured output requests. + # The order of the requests in the bitmask is not guaranteed to be the + # same as the order of the requests in the gpu runner's batch. We need + # to sort the bitmask to match the order of the requests used here. + + # Get the batch indices of the structured output requests. + # Keep track of the number of speculative tokens scheduled for every + # request in the batch, as the logit indices are offset by this amount. struct_out_req_batch_indices: dict[str, int] = {} - indices_match = True - for req_id in self.input_batch.req_ids: - mask_index = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_index is None: - # not a structured output request - continue - batch_index = self.input_batch.req_id_to_index[req_id] - if batch_index != mask_index: - indices_match = False - struct_out_req_batch_indices[req_id] = batch_index + cumulative_offset = 0 + seq = sorted(self.input_batch.req_id_to_index.items(), + key=lambda x: x[1]) + for req_id, batch_index in seq: + logit_index = batch_index + cumulative_offset + cumulative_offset += len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + if req_id in scheduler_output.structured_output_request_ids: + struct_out_req_batch_indices[req_id] = logit_index - if not indices_match: - # Sort the bitmask to match the order of the requests - sorted_bitmask = np.zeros_like(grammar_bitmask) - for req_id, batch_index in struct_out_req_batch_indices.items(): - orig_index = scheduler_output.structured_output_request_ids[ - req_id] - sorted_bitmask[batch_index] = grammar_bitmask[orig_index] - grammar_bitmask = sorted_bitmask + out_indices = [] + # Reorder the bitmask to match the order of the requests in the batch. + sorted_bitmask = np.zeros_like(grammar_bitmask, + shape=(logits.shape[0], + grammar_bitmask.shape[1])) + cumulative_index = 0 + seq = sorted(scheduler_output.structured_output_request_ids.items(), + key=lambda x: x[1]) + for req_id, _ in seq: + logit_index = struct_out_req_batch_indices[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + for i in range(1 + num_spec_tokens): + sorted_bitmask[logit_index + i] = \ + grammar_bitmask[cumulative_index + i] + out_indices.append(logit_index + i) + cumulative_index += 1 + num_spec_tokens + grammar_bitmask = sorted_bitmask + + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. grammar_bitmask = torch.from_numpy(grammar_bitmask) - # TODO: compatibility with spec decode xgr.apply_token_bitmask_inplace( logits, grammar_bitmask.to(self.device, non_blocking=True), - indices=list(struct_out_req_batch_indices.values()), + indices=out_indices, ) @torch.inference_mode()