[V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Benjamin Chislett 2025-04-29 17:02:10 -07:00 committed by GitHub
parent 7489ec0bab
commit 34120f5acd
9 changed files with 207 additions and 57 deletions

View File

@ -260,6 +260,7 @@ async def async_request_openai_completions(
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"repetition_penalty": 1.0,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,

View File

@ -123,6 +123,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
copy.deepcopy(schema) for _ in range(args.num_prompts)
]
for i in range(len(json_schemas)):
if "properties" not in json_schemas[i]:
json_schemas[i]["properties"] = {}
json_schemas[i]["properties"][
f"__optional_field_{uuid.uuid4()}"] = {
"type":
@ -134,7 +136,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
json_schemas = [schema] * args.num_prompts
def gen_prompt(index: int):
return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
def get_schema(index: int):
return json_schemas[index % len(json_schemas)]
@ -231,7 +233,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
idx -= len_dataset
schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
tokenize=False)
tokenize=False,
add_generation_prompt=True)
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]
@ -849,7 +852,7 @@ if __name__ == "__main__":
'json', 'json-unique', 'grammar', 'regex',
'choice', 'xgrammar_bench'
])
parser.add_argument("--json_schema_path",
parser.add_argument("--json-schema-path",
type=str,
default=None,
help="Path to json schema.")

View File

@ -16,13 +16,31 @@ from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
NGRAM_SPEC_CONFIG = {
"model": "[ngram]",
"num_speculative_tokens": 5,
"prompt_lookup_max": 5,
"prompt_lookup_min": 1,
}
EAGLE_SPEC_CONFIG = {
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
"num_speculative_tokens": 5,
}
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
#FIXME: This test is flaky on CI and thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
NGRAM_SPEC_CONFIG),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto",
EAGLE_SPEC_CONFIG)
]
PARAMS_MODELS_TOKENIZER_MODE = [
@ -45,8 +63,9 @@ class CarDescription(BaseModel):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
@pytest.mark.parametrize(
"model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
def test_structured_output(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
@ -58,6 +77,7 @@ def test_structured_output(
guided_decoding_backend: str,
tokenizer_mode: str,
model_name: str,
speculative_config: dict[str, Any],
):
monkeypatch.setenv("VLLM_USE_V1", "1")
@ -71,7 +91,8 @@ def test_structured_output(
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
tokenizer_mode=tokenizer_mode)
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
#
# Test 1: Generate JSON output based on a provided schema
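For context on what this new parametrization exercises end to end: a V1 engine built with both a guided-decoding backend and a speculative_config, using the same constructor keywords that appear in the diff above. This is a minimal sketch assuming the public vllm.LLM / SamplingParams / GuidedDecodingParams API; the model, schema, and prompt are placeholders.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Assumes the V1 engine (VLLM_USE_V1=1), as in the test above.
# The speculative_config mirrors NGRAM_SPEC_CONFIG; the model is a placeholder.
llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    max_model_len=1024,
    guided_decoding_backend="xgrammar",
    speculative_config={
        "model": "[ngram]",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 1,
    },
)

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate(
    f"Generate a JSON user profile matching this schema: {schema}",
    sampling_params,
)
print(outputs[0].outputs[0].text)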

View File

@ -441,7 +441,7 @@ class Scheduler(SchedulerInterface):
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
len(self.running),
scheduled_spec_decode_tokens,
)
# Construct the scheduler output.
new_reqs_data = [
@ -682,10 +682,6 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.free_encoder_input(
request, input_id)
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
request.spec_token_ids = spec_token_ids[req_index]
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
@ -717,6 +713,17 @@ class Scheduler(SchedulerInterface):
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if request.use_structured_output:
metadata = request.structured_output_request
assert metadata is not None and metadata.grammar is not None
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens(
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids:
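The effect of the validate_tokens call above is that only a grammar-legal prefix of the drafted tokens is attached to the request. A self-contained sketch of that prefix semantics, using a made-up stub in place of a real grammar backend:

class _StubGrammar:
    """Hypothetical stand-in for a structured-output grammar backend."""

    def __init__(self, legal_continuation: list[int]):
        self._legal = legal_continuation

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        # Return the longest prefix of `tokens` the grammar would accept,
        # without advancing any state.
        accepted: list[int] = []
        for tok, expected in zip(tokens, self._legal):
            if tok != expected:
                break
            accepted.append(tok)
        return accepted


grammar = _StubGrammar(legal_continuation=[11, 12, 13, 14])
draft = [11, 12, 99, 13]                    # proposed by the draft model
spec_token_ids = grammar.validate_tokens(draft)
assert spec_token_ids == [11, 12]           # the grammar-illegal tail is dropped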

View File

@ -27,6 +27,7 @@ class StructuredOutputManager:
def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None
self.vllm_config = vllm_config
self._grammar_bitmask: Optional[torch.Tensor] = None
# The default max_workers if not specified is the number of CPUs * 5,
@ -80,7 +81,7 @@ class StructuredOutputManager:
self,
requests: dict[str, Request],
structured_output_request_ids: dict[str, int],
batch_len: int,
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> Optional[npt.NDArray[np.int32]]:
# Prepare the structured output bitmask for this batch.
if not structured_output_request_ids:
@ -88,20 +89,52 @@ class StructuredOutputManager:
if self._grammar_bitmask is None:
assert self.backend is not None
self._grammar_bitmask = self.backend.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs)
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
if self.vllm_config.speculative_config is not None:
max_num_spec_tokens = self.vllm_config.\
speculative_config.num_speculative_tokens
else:
max_num_spec_tokens = 0
# Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of
# the batch.
bitmask_tensor = self._grammar_bitmask
for req_id, batch_index in structured_output_request_ids.items():
# Allocate a bitmask for each token needing to be checked:
# one for each speculative position, and one more for the
# bonus token / non-speculative token.
self._grammar_bitmask = \
self.backend.allocate_token_bitmask(
max_batch_size * (1 + max_num_spec_tokens))
# Generate a batched bitmask for all structured output requests.
# When speculative decoding is enabled, we need to include multiple
# masks for each request, one for each possible bonus token position.
# These are stored inline in the tensor and unpacked by the gpu runner.
cumulative_index = 0
ordered_seq = sorted(structured_output_request_ids.items(),
key=lambda x: x[1])
# NOTE: This outer loop can likely be parallelized to improve
# performance of bitmask generation for large batches.
for req_id, _ in ordered_seq:
request = requests[req_id].structured_output_request
assert request is not None and request.grammar is not None
if not request.grammar.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len]
state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
for i, token in enumerate(req_tokens):
if not request.grammar.is_terminated():
request.grammar.fill_bitmask(self._grammar_bitmask,
cumulative_index)
if token is not None:
# In order to generate the correct bitmask for each
# position in the speculative sequence, we advance
the FSM state for each speculative token and roll back
# to restore the previous state when we are finished.
assert request.grammar.accept_tokens(req_id, [token])
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
request.grammar.rollback(state_advancements)
bitmask_tensor = self._grammar_bitmask
if cumulative_index < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:cumulative_index]
# After finishing with the xgrammar operations, we convert to
# np.ndarray, because that is much more efficient for serialization
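To make the row layout of the shared bitmask explicit: in batch order, each structured-output request occupies 1 + len(scheduled draft tokens) consecutive rows, and cumulative_index is the running total of those rows. A small illustrative helper (the function name is invented) that reproduces this layout:

def bitmask_row_ranges(
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
) -> dict[str, range]:
    """Map each structured-output request to its rows in the shared bitmask."""
    rows: dict[str, range] = {}
    cumulative_index = 0
    for req_id, _ in sorted(structured_output_request_ids.items(),
                            key=lambda x: x[1]):
        # One row per scheduled draft token, plus one for the bonus token.
        num_rows = 1 + len(scheduled_spec_decode_tokens.get(req_id, []))
        rows[req_id] = range(cumulative_index, cumulative_index + num_rows)
        cumulative_index += num_rows
    return rows


# Request "a" has two draft tokens scheduled this step, "b" has none.
print(bitmask_row_ranges({"a": 0, "b": 1}, {"a": [7, 9]}))
# {'a': range(0, 3), 'b': range(3, 4)}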

View File

@ -144,6 +144,27 @@ class GuidanceGrammar(StructuredOutputGrammar):
return r
def validate_tokens(self, tokens: list[int]) -> list[int]:
"""Checks if the list of tokens are accepted by the parser in sequence.
Will not advance the parser.
Returns the prefix list of tokens that are accepted by the parser.
"""
if len(tokens) == 0:
return []
if self.ll_matcher.is_stopped():
return []
num_tokens = self.ll_matcher.validate_tokens(tokens)
self.check_error()
return tokens[:num_tokens]
def rollback(self, num_tokens: int) -> None:
self.ll_matcher.rollback(num_tokens)
self.check_error()
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
# this will automatically return [EOS] mask if the matcher is stopped
# or otherwise in an error state

View File

@ -35,6 +35,30 @@ class StructuredOutputGrammar(ABC):
bool: True if the tokens are accepted, False otherwise.
"""
@abstractmethod
def validate_tokens(self, tokens: list[int]) -> list[int]:
"""
Validates the provided tokens against the grammar.
Will not advance the FSM.
Args:
tokens (list[int]): A list of token IDs to validate.
Returns:
list[int]: A list of accepted token IDs. Will be a prefix
of the input tokens, and empty if none are accepted.
"""
@abstractmethod
def rollback(self, num_tokens: int) -> None:
"""
Rolls back the state of the grammar by a specified number of tokens.
Will also revert counters for the number of processed tokens.
Args:
num_tokens (int): The number of tokens to roll back.
"""
@abstractmethod
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
"""

View File

@ -40,6 +40,11 @@ class XgrammarBackend(StructuredOutputBackend):
self.disable_any_whitespace = \
vllm_config.decoding_config.disable_any_whitespace
self.num_speculative_tokens = 0
if self.vllm_config.speculative_config is not None:
self.num_speculative_tokens = \
self.vllm_config.speculative_config.num_speculative_tokens
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
@ -118,7 +123,10 @@ class XgrammarBackend(StructuredOutputBackend):
f"grammar is not of valid supported types. ({request_type!s})")
return XgrammarGrammar(
matcher=xgr.GrammarMatcher(ctx),
matcher=xgr.GrammarMatcher(
ctx,
max_rollback_tokens=self.num_speculative_tokens,
),
vocab_size=self.vocab_size,
ctx=ctx,
)
@ -136,7 +144,6 @@ class XgrammarGrammar(StructuredOutputGrammar):
# supporting different backends, in the future.
# For now, just xgrammar.
#
# TODO: support max_rollback_tokens
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding
@ -163,6 +170,27 @@ class XgrammarGrammar(StructuredOutputGrammar):
self.num_processed_tokens += 1
return True
def validate_tokens(self, tokens: list[int]) -> list[int]:
"""Checks if the list of tokens are accepted by the FSM in sequence.
Will not advance the FSM.
Returns the prefix list of tokens that are accepted by the FSM.
"""
accepted_tokens = []
for token in tokens:
if self.matcher.accept_token(token):
accepted_tokens.append(token)
else:
break
if len(accepted_tokens) > 0:
# Roll back the FSM to the state it was in before this call
self.matcher.rollback(len(accepted_tokens))
return accepted_tokens
def rollback(self, num_tokens: int) -> None:
self.matcher.rollback(num_tokens)
self.num_processed_tokens -= num_tokens
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(bitmask, idx)
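Why max_rollback_tokens=num_speculative_tokens is enough: validate_tokens accepts at most one draft's worth of tokens before rolling them all back, and the bitmask path in StructuredOutputManager advances the matcher at most num_speculative_tokens times per request before a single rollback of the same depth. A toy check of that invariant against a hypothetical matcher (not the real xgr.GrammarMatcher):

class _CountingMatcher:
    """Hypothetical matcher whose state is just the committed-token depth."""

    def __init__(self, max_rollback_tokens: int):
        self.max_rollback_tokens = max_rollback_tokens
        self.depth = 0

    def accept_token(self, token: int) -> bool:
        self.depth += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        assert num_tokens <= self.max_rollback_tokens, "rollback window exceeded"
        self.depth -= num_tokens


k = 5                                  # num_speculative_tokens
matcher = _CountingMatcher(max_rollback_tokens=k)
depth_before = matcher.depth
draft = [101, 102, 103]                # len(draft) <= k by construction
accepted = []
for token in draft:
    if not matcher.accept_token(token):
        break
    accepted.append(token)
matcher.rollback(len(accepted))        # mirrors validate_tokens()
assert matcher.depth == depth_before   # net matcher state is unchanged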

View File

@ -957,46 +957,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
logits: torch.Tensor,
):
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return
# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the gpu runner is
# ordering the requests in the batch. We need to sort the bitmask to
# match the order of the requests used here.
# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
# The order of the requests in the bitmask is not guaranteed to be the
# same as the order of the requests in the gpu runner's batch. We need
# to sort the bitmask to match the order of the requests used here.
# Get the batch indices of the structured output requests.
# Keep track of the number of speculative tokens scheduled for every
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
indices_match = True
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
# not a structured output request
continue
batch_index = self.input_batch.req_id_to_index[req_id]
if batch_index != mask_index:
indices_match = False
struct_out_req_batch_indices[req_id] = batch_index
cumulative_offset = 0
seq = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
for req_id, batch_index in seq:
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
if req_id in scheduler_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index
if not indices_match:
# Sort the bitmask to match the order of the requests
sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, batch_index in struct_out_req_batch_indices.items():
orig_index = scheduler_output.structured_output_request_ids[
req_id]
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
grammar_bitmask = sorted_bitmask
out_indices = []
# Reorder the bitmask to match the order of the requests in the batch.
sorted_bitmask = np.zeros_like(grammar_bitmask,
shape=(logits.shape[0],
grammar_bitmask.shape[1]))
cumulative_index = 0
seq = sorted(scheduler_output.structured_output_request_ids.items(),
key=lambda x: x[1])
for req_id, _ in seq:
logit_index = struct_out_req_batch_indices[req_id]
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
for i in range(1 + num_spec_tokens):
sorted_bitmask[logit_index + i] = \
grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask)
# TODO: compatibility with spec decode
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(self.device, non_blocking=True),
indices=list(struct_out_req_batch_indices.values()),
indices=out_indices,
)
@torch.inference_mode()
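The index bookkeeping above is easier to see in isolation: each request contributes 1 + num_draft_tokens rows of logits, so a structured-output request's first row equals its batch index plus the total draft tokens of all requests ahead of it in the batch. An illustrative helper (name and request IDs are made up) that mirrors that arithmetic:

def structured_logit_indices(
        req_id_to_index: dict[str, int],                # gpu runner batch order
        structured_output_request_ids: dict[str, int],  # scheduler's bitmask order
        scheduled_spec_decode_tokens: dict[str, list[int]],
) -> dict[str, int]:
    """First logit row of every structured-output request in the batch."""
    indices: dict[str, int] = {}
    cumulative_offset = 0
    for req_id, batch_index in sorted(req_id_to_index.items(),
                                      key=lambda x: x[1]):
        logit_index = batch_index + cumulative_offset
        cumulative_offset += len(
            scheduled_spec_decode_tokens.get(req_id, []))
        if req_id in structured_output_request_ids:
            indices[req_id] = logit_index
    return indices


# Batch order: "a" (2 draft tokens, structured), "b" (plain), "c" (1 draft, structured).
print(structured_logit_indices(
    {"a": 0, "b": 1, "c": 2},
    {"a": 0, "c": 1},
    {"a": [5, 6], "c": [9]},
))
# {'a': 0, 'c': 4}  ->  "a" covers logit rows 0-2, "b" row 3, "c" rows 4-5.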