[Perf] Parallelize fill_bitmask to accelerate high-throughput guided decoding (#21862)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Benjamin Chislett 2025-08-05 22:57:49 -04:00 committed by GitHub
parent 8e6c7e873f
commit 7e6544c797
3 changed files with 102 additions and 39 deletions
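The change gates bitmask filling behind a batch-size threshold and, above it, chunks the requests into fixed-size batches submitted to a small ThreadPoolExecutor. Below is a minimal, self-contained sketch of that pattern; the constants mirror the commit, while fill_one, fill_batch, and fill_all are hypothetical stand-ins for the manager's _fill_bitmasks and _async_submit_fill_bitmask.

import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor

PARALLEL_THRESHOLD = 128    # below this many requests, fill serially
PARALLEL_BATCH_SIZE = 16    # requests bundled into each submitted task

# Mirror the commit's worker heuristic: at least 1 worker,
# at most min(cpu_count // 2, 8).
max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
executor = ThreadPoolExecutor(max_workers=max_workers)

def fill_one(index: int) -> None:
    # Hypothetical stand-in for grammar.fill_bitmask(bitmask, index).
    pass

def fill_batch(batch: list[int]) -> None:
    for index in batch:
        fill_one(index)

def fill_all(indices: list[int]) -> None:
    if len(indices) <= PARALLEL_THRESHOLD:
        fill_batch(indices)          # small batches stay on the caller thread
        return
    promises: list[Future] = []
    for start in range(0, len(indices), PARALLEL_BATCH_SIZE):
        promises.append(
            executor.submit(fill_batch,
                            indices[start:start + PARALLEL_BATCH_SIZE]))
    for promise in promises:
        promise.result()             # wait for completion, surface exceptions

if __name__ == "__main__":
    fill_all(list(range(512)))       # 512 rows -> 32 tasks of 16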


@@ -3,7 +3,7 @@
from __future__ import annotations
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
@@ -40,6 +40,17 @@ class StructuredOutputManager:
self._grammar_bitmask: Optional[torch.Tensor] = None
self._full_mask = torch.tensor(-1, dtype=torch.int32)
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
self.fill_bitmask_parallel_threshold = 128
if self.fill_bitmask_parallel_threshold < max_batch_size:
self.fill_bitmask_parallel_batch_size = 16
# Use:
# - at least 1 worker
# - at most half the number of CPUs or 8, whichever is less
max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
self.executor_for_fillmask = ThreadPoolExecutor(
max_workers=max_workers)
if not self.vllm_config.model_config.skip_tokenizer_init:
# The default max_workers if not specified is the number of
# CPUs * 5, which is way too high since these tasks are CPU-bound,
@@ -120,6 +131,26 @@ class StructuredOutputManager:
assert self.backend is not None
return self.backend.compile_grammar(request_type, grammar_spec)
def _fill_bitmasks(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> None:
assert self._grammar_bitmask is not None
for grammar, index, apply_bitmask in batch:
if apply_bitmask and not grammar.is_terminated():
grammar.fill_bitmask(self._grammar_bitmask, index)
else:
# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for subsequent
# requests here.
self._grammar_bitmask[index].fill_(self._full_mask)
def _async_submit_fill_bitmask(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> Future:
return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
def grammar_bitmask(
self,
requests: dict[str, Request],
@@ -146,7 +177,6 @@ class StructuredOutputManager:
self.backend.allocate_token_bitmask(
max_batch_size * (1 + max_num_spec_tokens))
bitmask_tensor = self._grammar_bitmask
# Generate a batched bitmask for all structured output requests.
# When speculative decoding is enabled, we need to include multiple
# masks for each request, one for each possible bonus token position.
@@ -155,47 +185,61 @@ class StructuredOutputManager:
ordered_seq = sorted(structured_output_request_ids.items(),
key=lambda x: x[1])
# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for subsequent
# requests here.
bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
self._full_mask)
# Optimized parallel filling of bitmasks for
# non-spec, large-batch-size cases
if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
max_num_spec_tokens == 0:
promises = []
batch = []
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
# NOTE: This outer loop can likely be parallelized to improve
# performance of bitmask generation for large batches.
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request
apply_bitmask = self.should_fill_bitmask(request)
batch.append((structured_output_request.grammar,
cumulative_index, apply_bitmask))
if len(batch) == self.fill_bitmask_parallel_batch_size:
promises.append(self._async_submit_fill_bitmask(batch))
batch = []
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask: bool = True
if self.reasoner is not None:
if structured_output_request.reasoning_ended is None:
structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
apply_bitmask = structured_output_request.reasoning_ended
cumulative_index += 1
if batch:
promises.append(self._async_submit_fill_bitmask(batch))
state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
for i, token in enumerate(req_tokens):
if apply_bitmask and not \
structured_output_request.grammar.is_terminated():
structured_output_request.grammar.fill_bitmask(
bitmask_tensor, cumulative_index)
if token is not None:
# In order to generate the correct bitmask for each
# position in the speculative sequence, we advance
# the FSM state for each speculative token and rollback
# to restore the previous state when we are finished.
# Wait for all bitmask filling tasks to complete.
for promise in promises:
promise.result()
else:
# Fallback to serial filling of bitmasks for small-batch-size cases
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask = self.should_fill_bitmask(request)
state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
for i, token in enumerate(req_tokens + [None]):
self._fill_bitmasks([(structured_output_request.grammar,
cumulative_index, apply_bitmask)])
if apply_bitmask and token is not None and \
not structured_output_request.grammar.is_terminated():
assert structured_output_request.grammar.accept_tokens(
req_id, [token])
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
structured_output_request.grammar.rollback(state_advancements)
cumulative_index += 1
if state_advancements > 0:
structured_output_request.grammar.rollback(
state_advancements)
bitmask_tensor = self._grammar_bitmask
if cumulative_index < bitmask_tensor.shape[0]:
bitmask_tensor = bitmask_tensor[:cumulative_index]
@@ -204,6 +248,15 @@ class StructuredOutputManager:
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()
def should_fill_bitmask(self, request: Request) -> bool:
if self.reasoner is not None:
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
request.structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
return request.structured_output_request.reasoning_ended
return True
def should_advance(self, request: Request) -> bool:
if not request.use_structured_output:
return False
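In the serial fallback above, each request with speculative tokens gets one bitmask row per position: the grammar FSM is advanced one accepted token at a time and rolled back at the end so its state is left untouched. A minimal sketch of that advance-then-rollback bookkeeping, using a hypothetical MockGrammar in place of a real StructuredOutputGrammar:

class MockGrammar:
    def __init__(self) -> None:
        self.state = 0

    def accept_tokens(self, req_id: str, tokens: list[int]) -> bool:
        self.state += len(tokens)
        return True

    def rollback(self, num_tokens: int) -> None:
        self.state -= num_tokens

    def is_terminated(self) -> bool:
        return False

    def fill_bitmask(self, bitmask, index) -> None:
        pass  # stand-in for writing one row of the token bitmask


def fill_rows_for_request(grammar: MockGrammar, spec_tokens: list[int]) -> None:
    # One bitmask row per position: the committed token plus each speculative
    # token. Advance the FSM per accepted token, then roll it all back so the
    # grammar ends up exactly where the scheduler expects it.
    state_advancements = 0
    for row, token in enumerate(spec_tokens + [None]):
        grammar.fill_bitmask(None, row)
        if token is not None and not grammar.is_terminated():
            assert grammar.accept_tokens("req-0", [token])
            state_advancements += 1
    if state_advancements > 0:
        grammar.rollback(state_advancements)


g = MockGrammar()
fill_rows_for_request(g, [11, 22, 33])
assert g.state == 0  # FSM state restored after the rollback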


@@ -148,6 +148,7 @@ class XgrammarGrammar(StructuredOutputGrammar):
repr=False,
hash=False,
init=False)
_is_terminated: bool = field(default=False, repr=False, hash=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
@@ -155,6 +156,8 @@ class XgrammarGrammar(StructuredOutputGrammar):
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
if self._is_terminated:
return False
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
@@ -162,6 +165,7 @@ class XgrammarGrammar(StructuredOutputGrammar):
"for tokens %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
self._is_terminated = self.matcher.is_terminated()
return True
def validate_tokens(self, tokens: list[int]) -> list[int]:
@@ -184,12 +188,13 @@ class XgrammarGrammar(StructuredOutputGrammar):
def rollback(self, num_tokens: int) -> None:
self.matcher.rollback(num_tokens)
self.num_processed_tokens -= num_tokens
self._is_terminated = self.matcher.is_terminated()
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(bitmask, idx)
def is_terminated(self) -> bool:
return self.matcher.is_terminated()
return self._is_terminated
def reset(self):
self.num_processed_tokens = 0
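The xgrammar change above caches matcher.is_terminated() in a _is_terminated field that is refreshed only when the state actually changes (accept_tokens or rollback), so the frequent is_terminated() calls made during bitmask filling become plain attribute reads. A rough sketch of that caching idea, written against a hypothetical MockMatcher rather than xgrammar's real matcher:

from dataclasses import dataclass, field

@dataclass
class MockMatcher:
    target_len: int
    consumed: int = 0

    def accept_token(self, token: int) -> bool:
        self.consumed += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        self.consumed -= num_tokens

    def is_terminated(self) -> bool:
        return self.consumed >= self.target_len


@dataclass
class CachedGrammar:
    matcher: MockMatcher
    _is_terminated: bool = field(default=False, repr=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        if self._is_terminated:
            return False  # never advance past a terminated grammar
        for token in tokens:
            if not self.matcher.accept_token(token):
                return False
        self._is_terminated = self.matcher.is_terminated()
        return True

    def rollback(self, num_tokens: int) -> None:
        self.matcher.rollback(num_tokens)
        self._is_terminated = self.matcher.is_terminated()

    def is_terminated(self) -> bool:
        return self._is_terminated  # cheap cached read on the hot path


g = CachedGrammar(MockMatcher(target_len=2))
assert g.accept_tokens("r0", [1, 2]) and g.is_terminated()
g.rollback(1)
assert not g.is_terminated()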


@@ -1324,9 +1324,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# If the grammar bitmask and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = grammar_bitmask.shape[0] == logits.shape[0]
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask)
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
# Force use of the torch.compile implementation from xgrammar to work
# around issues with the Triton kernel in concurrent structured output
@@ -1334,7 +1339,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
logits,
grammar_bitmask.to(self.device, non_blocking=True),
indices=out_indices,
indices=out_indices if not skip_out_indices else None,
)
def sync_and_slice_intermediate_tensors(
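The runner change above drops the index argument when the bitmask already has exactly one row per row of logits, since the row mapping is then the identity. A hedged sketch of that shortcut using a plain boolean mask and masked_fill, not xgrammar's apply_token_bitmask kernel:

import torch
from typing import Optional

def apply_allow_mask(logits: torch.Tensor,
                     allow: torch.Tensor,
                     indices: Optional[torch.Tensor] = None) -> None:
    # `allow` has one boolean row per masked request; `indices` maps those rows
    # onto rows of `logits`. When allow.shape[0] == logits.shape[0] the mapping
    # is the identity, so the indices argument can be omitted entirely.
    rows = torch.arange(logits.shape[0]) if indices is None else indices
    logits[rows] = logits[rows].masked_fill(~allow, float("-inf"))

logits = torch.zeros(4, 8)
allow = torch.ones(4, 8, dtype=torch.bool)
allow[:, 5:] = False                      # disallow tokens 5..7 for every row
skip_indices = allow.shape[0] == logits.shape[0]
apply_allow_mask(logits, allow, None if skip_indices else torch.arange(4))
assert torch.isinf(logits[:, 5:]).all() and torch.isfinite(logits[:, :5]).all()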