[V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Benjamin Chislett 2025-04-29 17:02:10 -07:00 committed by GitHub
parent 7489ec0bab
commit 34120f5acd
9 changed files with 207 additions and 57 deletions

View File

@@ -260,6 +260,7 @@ async def async_request_openai_completions(
             if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
+            "repetition_penalty": 1.0,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
             "stream": True,

View File

@@ -123,6 +123,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
                 copy.deepcopy(schema) for _ in range(args.num_prompts)
             ]
             for i in range(len(json_schemas)):
+                if "properties" not in json_schemas[i]:
+                    json_schemas[i]["properties"] = {}
                 json_schemas[i]["properties"][
                     f"__optional_field_{uuid.uuid4()}"] = {
                         "type":
@@ -134,7 +136,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
            json_schemas = [schema] * args.num_prompts

    def gen_prompt(index: int):
-        return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}"  # noqa: E501
+        return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}"  # noqa: E501

    def get_schema(index: int):
        return json_schemas[index % len(json_schemas)]
@@ -231,7 +233,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
            idx -= len_dataset
        schema = dataset["schema"][idx]
        prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
-                                               tokenize=False)
+                                               tokenize=False,
+                                               add_generation_prompt=True)
        input_len = len(tokenizer(prompt).input_ids)
        completion = dataset["completion"][idx]
@@ -849,7 +852,7 @@ if __name__ == "__main__":
                            'json', 'json-unique', 'grammar', 'regex',
                            'choice', 'xgrammar_bench'
                        ])
-    parser.add_argument("--json_schema_path",
+    parser.add_argument("--json-schema-path",
                        type=str,
                        default=None,
                        help="Path to json schema.")

View File

@@ -16,13 +16,31 @@
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

+NGRAM_SPEC_CONFIG = {
+    "model": "[ngram]",
+    "num_speculative_tokens": 5,
+    "prompt_lookup_max": 5,
+    "prompt_lookup_min": 1,
+}
+
+EAGLE_SPEC_CONFIG = {
+    "method": "eagle",
+    "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+    "num_speculative_tokens": 5,
+}
+
 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
-    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
     #FIXME: This test is flaky on CI thus disabled
     #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
+     NGRAM_SPEC_CONFIG),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto",
+     EAGLE_SPEC_CONFIG)
 ]

 PARAMS_MODELS_TOKENIZER_MODE = [
@@ -45,8 +63,9 @@ class CarDescription(BaseModel):
 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode",
-                         PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
+@pytest.mark.parametrize(
+    "model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
+    PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
 def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
@@ -58,6 +77,7 @@ def test_structured_output(
     guided_decoding_backend: str,
     tokenizer_mode: str,
     model_name: str,
+    speculative_config: dict[str, Any],
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
@@ -71,7 +91,8 @@ def test_structured_output(
                   max_model_len=1024,
                   guided_decoding_backend=guided_decoding_backend,
                   guided_decoding_disable_any_whitespace=True,
-                  tokenizer_mode=tokenizer_mode)
+                  tokenizer_mode=tokenizer_mode,
+                  speculative_config=speculative_config)

     #
     # Test 1: Generate JSON output based on a provided schema
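The new parametrization threads a speculative_config dict into the LLM constructor alongside the guided decoding options, so every structured-output check in this test also runs with drafting enabled. A hedged sketch of what the ngram case exercises, written as a plain script rather than a pytest case (the schema, prompt, and sampling settings here are illustrative, not taken from the test):

import os

os.environ["VLLM_USE_V1"] = "1"  # the test sets this via monkeypatch

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    max_model_len=1024,
    guided_decoding_backend="xgrammar",
    guided_decoding_disable_any_whitespace=True,
    tokenizer_mode="auto",
    # n-gram prompt-lookup drafting, mirroring NGRAM_SPEC_CONFIG above.
    speculative_config={
        "model": "[ngram]",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 1,
    },
)

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
sampling = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate("Generate a JSON object describing a person:",
                       sampling_params=sampling)
print(outputs[0].outputs[0].text)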

View File

@@ -441,7 +441,7 @@ class Scheduler(SchedulerInterface):
         grammar_bitmask = self.structured_output_manager.grammar_bitmask(
             self.requests,
             structured_output_request_ids,
-            len(self.running),
+            scheduled_spec_decode_tokens,
         )

         # Construct the scheduler output.
         new_reqs_data = [
@@ -682,10 +682,6 @@
                     self.encoder_cache_manager.free_encoder_input(
                         request, input_id)

-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                request.spec_token_ids = spec_token_ids[req_index]
-
             stopped = False
             new_logprobs = None
             new_token_ids = generated_token_ids
@@ -717,6 +713,17 @@
                     request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                         req_id, new_token_ids)

+            # Add newly generated spec token ids to the request.
+            if spec_token_ids is not None:
+                if request.use_structured_output:
+                    metadata = request.structured_output_request
+                    assert metadata is not None and metadata.grammar is not None
+                    # Needs to happen after new_token_ids are accepted.
+                    request.spec_token_ids = metadata.grammar.validate_tokens(
+                        spec_token_ids[req_index])
+                else:
+                    request.spec_token_ids = spec_token_ids[req_index]
+
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids:
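The scheduler change above moves draft-token attachment to after the grammar has accepted the newly verified tokens, and trims the drafts to the prefix the grammar would accept: a draft token the grammar must reject can never survive verification, so proposing it is wasted work. A toy illustration of that trimming (ToyGrammar and the token values are made up, not vLLM code):

class ToyGrammar:
    """Accepts only tokens from a fixed allowed set."""

    def __init__(self, allowed: set[int]):
        self.allowed = allowed

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        # Longest prefix the grammar would accept, without advancing state,
        # mirroring StructuredOutputGrammar.validate_tokens.
        accepted: list[int] = []
        for tok in tokens:
            if tok not in self.allowed:
                break
            accepted.append(tok)
        return accepted


grammar = ToyGrammar(allowed={10, 11, 12})
draft_tokens = [10, 11, 99, 12]  # 99 violates the grammar

# The scheduler keeps only [10, 11]; everything after the first invalid
# draft token is dropped before it ever reaches verification.
print(grammar.validate_tokens(draft_tokens))  # -> [10, 11]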

View File

@@ -27,6 +27,7 @@
     def __init__(self, vllm_config: VllmConfig):
         self.backend: Optional[StructuredOutputBackend] = None
         self.vllm_config = vllm_config
+
         self._grammar_bitmask: Optional[torch.Tensor] = None

         # The default max_workers if not specified is the number of CPUs * 5,
@@ -80,7 +81,7 @@
         self,
         requests: dict[str, Request],
         structured_output_request_ids: dict[str, int],
-        batch_len: int,
+        scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
         # Prepare the structured output bitmask for this batch.
         if not structured_output_request_ids:
@@ -88,20 +89,52 @@
         if self._grammar_bitmask is None:
             assert self.backend is not None
-            self._grammar_bitmask = self.backend.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs)
+            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+            if self.vllm_config.speculative_config is not None:
+                max_num_spec_tokens = self.vllm_config.\
+                    speculative_config.num_speculative_tokens
+            else:
+                max_num_spec_tokens = 0
+
+            # Allocate a bitmask for each token needing to be checked:
+            # one for each speculative position, and one more for the
+            # bonus token / non-speculative token.
+            self._grammar_bitmask = \
+                self.backend.allocate_token_bitmask(
+                    max_batch_size * (1 + max_num_spec_tokens))

-        # Fill the bitmask using the index of each request equal to its
-        # position in the batch. Resize the bitmask down to the size of
-        # the batch.
-        bitmask_tensor = self._grammar_bitmask
-        for req_id, batch_index in structured_output_request_ids.items():
+        # Generate a batched bitmask for all structured output requests.
+        # When speculative decoding is enabled, we need to include multiple
+        # masks for each request, one for each possible bonus token position.
+        # These are stored inline in the tensor and unpacked by the gpu runner.
+        cumulative_index = 0
+        ordered_seq = sorted(structured_output_request_ids.items(),
+                             key=lambda x: x[1])
+        # NOTE: This outer loop can likely be parallelized to improve
+        # performance of bitmask generation for large batches.
+        for req_id, _ in ordered_seq:
             request = requests[req_id].structured_output_request
             assert request is not None and request.grammar is not None
-            if not request.grammar.is_terminated():
-                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
-        if batch_len < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:batch_len]
+            state_advancements = 0
+            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
+            for i, token in enumerate(req_tokens):
+                if not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(self._grammar_bitmask,
+                                                 cumulative_index)
+                    if token is not None:
+                        # In order to generate the correct bitmask for each
+                        # position in the speculative sequence, we advance
+                        # the FSM state for each speculative token and rollback
+                        # to restore the previous state when we are finished.
+                        assert request.grammar.accept_tokens(req_id, [token])
+                        state_advancements += 1
+                cumulative_index += 1
+            if state_advancements > 0:
+                request.grammar.rollback(state_advancements)
+
+        bitmask_tensor = self._grammar_bitmask
+        if cumulative_index < self._grammar_bitmask.shape[0]:
+            bitmask_tensor = self._grammar_bitmask[:cumulative_index]

         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
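After this change the bitmask holds one row per (request, position) rather than one row per request: for each structured-output request, a row for every scheduled draft token plus one for the bonus token, packed contiguously in batch order and finally sliced down to the rows actually used. A standalone sketch of that row bookkeeping, using made-up request IDs and draft counts:

scheduled_spec_decode_tokens = {
    "req-a": [101, 102],  # two draft tokens scheduled this step
    "req-b": [],          # no drafts
    "req-c": [201],       # one draft token
}
# Requests in batch order, as produced by ordered_seq above.
structured_output_request_ids = {"req-a": 0, "req-b": 1, "req-c": 2}

row = 0
rows_per_request = {}
for req_id, _ in sorted(structured_output_request_ids.items(),
                        key=lambda x: x[1]):
    # One mask row per draft token, plus one for the bonus position.
    num_rows = len(scheduled_spec_decode_tokens.get(req_id, [])) + 1
    rows_per_request[req_id] = list(range(row, row + num_rows))
    row += num_rows

print(rows_per_request)  # {'req-a': [0, 1, 2], 'req-b': [3], 'req-c': [4, 5]}
print(row)               # 6 rows used; the preallocated tensor is cut to this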

View File

@@ -144,6 +144,27 @@ class GuidanceGrammar(StructuredOutputGrammar):

         return r

+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """Checks if the list of tokens are accepted by the parser in sequence.
+        Will not advance the parser.
+
+        Returns the prefix list of tokens that are accepted by the parser.
+        """
+        if len(tokens) == 0:
+            return []
+
+        if self.ll_matcher.is_stopped():
+            return []
+
+        num_tokens = self.ll_matcher.validate_tokens(tokens)
+        self.check_error()
+
+        return tokens[:num_tokens]
+
+    def rollback(self, num_tokens: int) -> None:
+        self.ll_matcher.rollback(num_tokens)
+        self.check_error()
+
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         # this will automatically return [EOS] mask if the matcher is stopped
         # or otherwise in an error state

View File

@@ -35,6 +35,30 @@ class StructuredOutputGrammar(ABC):
             bool: True if the tokens are accepted, False otherwise.
         """

+    @abstractmethod
+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """
+        Validates the provided tokens against the grammar.
+        Will not advance the FSM.
+
+        Args:
+            tokens (list[int]): A list of token IDs to validate.
+
+        Returns:
+            list[int]: A list of accepted token IDs. Will be a prefix
+                of the input tokens, and empty if none are accepted.
+        """
+
+    @abstractmethod
+    def rollback(self, num_tokens: int) -> None:
+        """
+        Rolls back the state of the grammar by a specified number of tokens.
+        Will also revert counters for the number of processed tokens.
+
+        Args:
+            num_tokens (int): The number of tokens to roll back.
+        """
+
     @abstractmethod
     def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
         """

View File

@@ -40,6 +40,11 @@ class XgrammarBackend(StructuredOutputBackend):
         self.disable_any_whitespace = \
             vllm_config.decoding_config.disable_any_whitespace

+        self.num_speculative_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            self.num_speculative_tokens = \
+                self.vllm_config.speculative_config.num_speculative_tokens
+
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
@@ -118,7 +123,10 @@
                 f"grammar is not of valid supported types. ({request_type!s})")

         return XgrammarGrammar(
-            matcher=xgr.GrammarMatcher(ctx),
+            matcher=xgr.GrammarMatcher(
+                ctx,
+                max_rollback_tokens=self.num_speculative_tokens,
+            ),
             vocab_size=self.vocab_size,
             ctx=ctx,
         )
@@ -136,7 +144,6 @@ class XgrammarGrammar(StructuredOutputGrammar):
     # supporting different backends, in the future.
     # For now, just xgrammar.
     #
-    # TODO: support max_rollback_tokens
     # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
     #   for jump-forward decoding
@@ -163,6 +170,27 @@
         self.num_processed_tokens += 1
         return True

+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """Checks if the list of tokens are accepted by the FSM in sequence.
+        Will not advance the FSM.
+
+        Returns the prefix list of tokens that are accepted by the FSM.
+        """
+        accepted_tokens = []
+        for token in tokens:
+            if self.matcher.accept_token(token):
+                accepted_tokens.append(token)
+            else:
+                break
+        if len(accepted_tokens) > 0:
+            # Rollback the FSM to the initial state
+            self.matcher.rollback(len(accepted_tokens))
+        return accepted_tokens
+
+    def rollback(self, num_tokens: int) -> None:
+        self.matcher.rollback(num_tokens)
+        self.num_processed_tokens -= num_tokens
+
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(bitmask, idx)

View File

@@ -957,46 +957,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         logits: torch.Tensor,
     ):
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
         grammar_bitmask = scheduler_output.grammar_bitmask
         if grammar_bitmask is None:
             return

-        # We receive the structured output bitmask from the scheduler, but the
-        # indices of the requests in the batch may not match the indices of
-        # the bitmask since the scheduler doesn't know how the gpu runner is
-        # ordering the requests in the batch. We need to sort the bitmask to
-        # match the order of the requests used here.
+        # We receive the structured output bitmask from the scheduler,
+        # compacted to contain bitmasks only for structured output requests.
+        # The order of the requests in the bitmask is not guaranteed to be the
+        # same as the order of the requests in the gpu runner's batch. We need
+        # to sort the bitmask to match the order of the requests used here.
+
+        # Get the batch indices of the structured output requests.
+        # Keep track of the number of speculative tokens scheduled for every
+        # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
-        indices_match = True
-        for req_id in self.input_batch.req_ids:
-            mask_index = scheduler_output.structured_output_request_ids.get(
-                req_id)
-            if mask_index is None:
-                # not a structured output request
-                continue
-            batch_index = self.input_batch.req_id_to_index[req_id]
-            if batch_index != mask_index:
-                indices_match = False
-            struct_out_req_batch_indices[req_id] = batch_index
-
-        if not indices_match:
-            # Sort the bitmask to match the order of the requests
-            sorted_bitmask = np.zeros_like(grammar_bitmask)
-            for req_id, batch_index in struct_out_req_batch_indices.items():
-                orig_index = scheduler_output.structured_output_request_ids[
-                    req_id]
-                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
-            grammar_bitmask = sorted_bitmask
+        cumulative_offset = 0
+        seq = sorted(self.input_batch.req_id_to_index.items(),
+                     key=lambda x: x[1])
+        for req_id, batch_index in seq:
+            logit_index = batch_index + cumulative_offset
+            cumulative_offset += len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            if req_id in scheduler_output.structured_output_request_ids:
+                struct_out_req_batch_indices[req_id] = logit_index
+
+        out_indices = []
+
+        # Reorder the bitmask to match the order of the requests in the batch.
+        sorted_bitmask = np.zeros_like(grammar_bitmask,
+                                       shape=(logits.shape[0],
+                                              grammar_bitmask.shape[1]))
+        cumulative_index = 0
+        seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                     key=lambda x: x[1])
+        for req_id, _ in seq:
+            logit_index = struct_out_req_batch_indices[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            for i in range(1 + num_spec_tokens):
+                sorted_bitmask[logit_index + i] = \
+                    grammar_bitmask[cumulative_index + i]
+                out_indices.append(logit_index + i)
+            cumulative_index += 1 + num_spec_tokens
+        grammar_bitmask = sorted_bitmask

+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)

-        # TODO: compatibility with spec decode
         xgr.apply_token_bitmask_inplace(
             logits,
             grammar_bitmask.to(self.device, non_blocking=True),
-            indices=list(struct_out_req_batch_indices.values()),
+            indices=out_indices,
         )

     @torch.inference_mode()
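With drafting enabled, each request now contributes 1 + num_draft_tokens rows of logits, so a request's first logit row is its batch index plus the draft-token count of every request ahead of it; the loop above rebuilds exactly that mapping before scattering the bitmask rows into place. A standalone sketch of the index arithmetic with made-up request IDs and draft counts:

req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}
scheduled_spec_decode_tokens = {"req-a": [7, 8], "req-c": [9]}
structured_output_request_ids = {"req-a": 0, "req-c": 1}  # req-b unconstrained

struct_out_logit_rows = {}
cumulative_offset = 0
for req_id, batch_index in sorted(req_id_to_index.items(),
                                  key=lambda x: x[1]):
    logit_index = batch_index + cumulative_offset
    num_spec = len(scheduled_spec_decode_tokens.get(req_id, []))
    cumulative_offset += num_spec
    if req_id in structured_output_request_ids:
        # One logit row for the base position and one per draft position.
        struct_out_logit_rows[req_id] = [logit_index + i
                                         for i in range(1 + num_spec)]

print(struct_out_logit_rows)  # {'req-a': [0, 1, 2], 'req-c': [4, 5]}
# req-b's single logit row sits at index 3; rows 0-5 cover the whole batch.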