[Perf] Parallelize fill_bitmask to accelerate high-throughput guided decoding (#21862)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
parent
8e6c7e873f
commit
7e6544c797
@@ -3,7 +3,7 @@
from __future__ import annotations

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
@@ -40,6 +40,17 @@ class StructuredOutputManager:
self._grammar_bitmask: Optional[torch.Tensor] = None
self._full_mask = torch.tensor(-1, dtype=torch.int32)

max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
self.fill_bitmask_parallel_threshold = 128
if self.fill_bitmask_parallel_threshold < max_batch_size:
self.fill_bitmask_parallel_batch_size = 16
# Use:
# - at least 1 CPU
# - at most half the number of CPUs or 8, whichever is less
max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
self.executor_for_fillmask = ThreadPoolExecutor(
max_workers=max_workers)
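As a point of reference, the pool sizing and the threshold gate above can be read as the following standalone sketch; make_fillmask_executor is a hypothetical helper for illustration, only the max(1, min(cpu_count() // 2, 8)) formula and the 128-sequence threshold come from the diff.

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


def make_fillmask_executor(
        max_num_seqs: int,
        parallel_threshold: int = 128) -> Optional[ThreadPoolExecutor]:
    """Create a thread pool for bitmask filling only when the maximum
    batch size is large enough to amortize thread-dispatch overhead."""
    if max_num_seqs <= parallel_threshold:
        # Small maximum batch: serial filling stays cheaper than dispatch.
        return None
    # At least 1 worker, at most half the CPUs, capped at 8.
    max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
    return ThreadPoolExecutor(max_workers=max_workers)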

if not self.vllm_config.model_config.skip_tokenizer_init:
# The default max_workers if not specified is the number of
# CPUs * 5, which is way too high since these tasks are CPU-bound,
@@ -120,6 +131,26 @@ class StructuredOutputManager:
assert self.backend is not None
return self.backend.compile_grammar(request_type, grammar_spec)

def _fill_bitmasks(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> None:
assert self._grammar_bitmask is not None
for grammar, index, apply_bitmask in batch:
if apply_bitmask and not grammar.is_terminated():
grammar.fill_bitmask(self._grammar_bitmask, index)
else:
# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for subsequent
# requests here.
self._grammar_bitmask[index].fill_(self._full_mask)

def _async_submit_fill_bitmask(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> Future:
return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

def grammar_bitmask(
self,
requests: dict[str, Request],
@@ -146,7 +177,6 @@ class StructuredOutputManager:
self.backend.allocate_token_bitmask(
max_batch_size * (1 + max_num_spec_tokens))

bitmask_tensor = self._grammar_bitmask
# Generate a batched bitmask for all structured output requests.
# When speculative decoding is enabled, we need to include multiple
# masks for each request, one for each possible bonus token position.
@@ -155,47 +185,61 @@ class StructuredOutputManager:
ordered_seq = sorted(structured_output_request_ids.items(),
key=lambda x: x[1])

# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for subsequent
# requests here.
bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
self._full_mask)
# Optimized parallel filling of bitmasks for
# non-spec, large-batch-size cases
if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
max_num_spec_tokens == 0:
promises = []
batch = []
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None

# NOTE: This outer loop can likely be parallelized to improve
# performance of bitmask generation for large batches.
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request
apply_bitmask = self.should_fill_bitmask(request)
batch.append((structured_output_request.grammar,
cumulative_index, apply_bitmask))
if len(batch) == self.fill_bitmask_parallel_batch_size:
promises.append(self._async_submit_fill_bitmask(batch))
batch = []

if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask: bool = True
if self.reasoner is not None:
if structured_output_request.reasoning_ended is None:
structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
apply_bitmask = structured_output_request.reasoning_ended
cumulative_index += 1
if batch:
promises.append(self._async_submit_fill_bitmask(batch))

state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
for i, token in enumerate(req_tokens):
if apply_bitmask and not \
structured_output_request.grammar.is_terminated():
structured_output_request.grammar.fill_bitmask(
bitmask_tensor, cumulative_index)
if token is not None:
# In order to generate the correct bitmask for each
# position in the speculative sequence, we advance
# the FSM state for each speculative token and rollback
# to restore the previous state when we are finished.
# Wait for all bitmask filling tasks to complete.
for promise in promises:
promise.result()
else:
# Fallback to serial filling of bitmasks for small-batch-size cases
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request

if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask = self.should_fill_bitmask(request)

state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
for i, token in enumerate(req_tokens + [None]):
self._fill_bitmasks([(structured_output_request.grammar,
cumulative_index, apply_bitmask)])

if apply_bitmask and token is not None and \
not structured_output_request.grammar.is_terminated():
assert structured_output_request.grammar.accept_tokens(
req_id, [token])
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
structured_output_request.grammar.rollback(state_advancements)
cumulative_index += 1
if state_advancements > 0:
structured_output_request.grammar.rollback(
state_advancements)

bitmask_tensor = self._grammar_bitmask
if cumulative_index < bitmask_tensor.shape[0]:
bitmask_tensor = bitmask_tensor[:cumulative_index]
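The parallel path above reduces to a chunk, submit, and wait pattern: work items are grouped into batches of fill_bitmask_parallel_batch_size, each batch becomes one thread-pool task, and the scheduler blocks on every Future before the bitmask is shipped. A minimal self-contained sketch of that pattern, assuming a toy fill_rows stand-in for _fill_bitmasks (not a vLLM API):

from concurrent.futures import Future, ThreadPoolExecutor

BATCH_SIZE = 16  # mirrors fill_bitmask_parallel_batch_size in the diff


def fill_rows(rows: list[int], out: list[int]) -> None:
    # Stand-in for _fill_bitmasks: mark each assigned row as filled.
    for r in rows:
        out[r] = 1


def parallel_fill(num_rows: int, executor: ThreadPoolExecutor) -> list[int]:
    out = [0] * num_rows
    promises: list[Future] = []
    batch: list[int] = []
    for row in range(num_rows):
        batch.append(row)
        if len(batch) == BATCH_SIZE:
            promises.append(executor.submit(fill_rows, batch, out))
            batch = []
    if batch:  # flush the final partial batch
        promises.append(executor.submit(fill_rows, batch, out))
    for p in promises:
        p.result()  # wait, and surface any worker exception
    return out


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=4) as pool:
        assert all(parallel_fill(200, pool))

Batching several rows per task keeps the per-Future dispatch overhead small relative to the work each thread does, which is presumably why the diff submits sixteen rows at a time rather than one.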
@@ -204,6 +248,15 @@ class StructuredOutputManager:
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()

def should_fill_bitmask(self, request: Request) -> bool:
if self.reasoner is not None:
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
request.structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
return request.structured_output_request.reasoning_ended
return True

def should_advance(self, request: Request) -> bool:
if not request.use_structured_output:
return False
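should_fill_bitmask lazily resolves and then caches whether a request has left its reasoning phase, so the grammar is only enforced once structured output actually begins. A toy sketch of that caching pattern; the Req dataclass and EndTokenReasoner below are illustrative stand-ins, not vLLM types:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Req:
    prompt_token_ids: list[int]
    reasoning_ended: Optional[bool] = None  # cached after the first check


class EndTokenReasoner:
    """Toy reasoner: reasoning ends once a sentinel token appears."""

    def __init__(self, end_token: int) -> None:
        self.end_token = end_token

    def is_reasoning_end(self, token_ids: list[int]) -> bool:
        return self.end_token in token_ids


def should_fill_bitmask(req: Req, reasoner: Optional[EndTokenReasoner]) -> bool:
    if reasoner is not None:
        if req.reasoning_ended is None:  # compute once, then reuse
            req.reasoning_ended = reasoner.is_reasoning_end(req.prompt_token_ids)
        return req.reasoning_ended
    return True  # no reasoner: always constrain decoding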
@@ -148,6 +148,7 @@ class XgrammarGrammar(StructuredOutputGrammar):
repr=False,
hash=False,
init=False)
_is_terminated: bool = field(default=False, repr=False, hash=False)

def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
@@ -155,6 +156,8 @@ class XgrammarGrammar(StructuredOutputGrammar):
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
if self._is_terminated:
return False
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
@@ -162,6 +165,7 @@ class XgrammarGrammar(StructuredOutputGrammar):
"for tokens %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
self._is_terminated = self.matcher.is_terminated()
return True

def validate_tokens(self, tokens: list[int]) -> list[int]:
@@ -184,12 +188,13 @@ class XgrammarGrammar(StructuredOutputGrammar):
def rollback(self, num_tokens: int) -> None:
self.matcher.rollback(num_tokens)
self.num_processed_tokens -= num_tokens
self._is_terminated = self.matcher.is_terminated()

def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(bitmask, idx)

def is_terminated(self) -> bool:
return self.matcher.is_terminated()
return self._is_terminated

def reset(self):
self.num_processed_tokens = 0
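The XgrammarGrammar change replaces repeated matcher.is_terminated() queries with a cached _is_terminated flag that is refreshed only when the matcher state actually changes, on accept or rollback. The toy sketch below shows the same pattern; CountingMatcher is a made-up stand-in for xgrammar's GrammarMatcher, not its real API:

class CountingMatcher:
    """Toy matcher that terminates after accepting `limit` tokens."""

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self.accepted: list[int] = []

    def accept_token(self, token: int) -> bool:
        if len(self.accepted) >= self.limit:
            return False
        self.accepted.append(token)
        return True

    def rollback(self, n: int) -> None:
        if n > 0:
            del self.accepted[-n:]

    def is_terminated(self) -> bool:
        return len(self.accepted) >= self.limit


class CachedGrammar:
    def __init__(self, matcher: CountingMatcher) -> None:
        self.matcher = matcher
        self._is_terminated = False  # cached; avoids one matcher call per query

    def accept_tokens(self, tokens: list[int]) -> bool:
        if self._is_terminated:
            return False
        for token in tokens:
            if not self.matcher.accept_token(token):
                return False
        self._is_terminated = self.matcher.is_terminated()  # refresh the cache
        return True

    def rollback(self, num_tokens: int) -> None:
        self.matcher.rollback(num_tokens)
        self._is_terminated = self.matcher.is_terminated()

    def is_terminated(self) -> bool:
        return self._is_terminated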
@@ -1324,9 +1324,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask

# If the grammar bitmask and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = grammar_bitmask.shape[0] == logits.shape[0]

# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask)
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()

# Force use of the torch.compile implementation from xgrammar to work
# around issues with the Triton kernel in concurrent structured output
@@ -1334,7 +1339,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
logits,
grammar_bitmask.to(self.device, non_blocking=True),
indices=out_indices,
indices=out_indices if not skip_out_indices else None,
)
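On the runner side, the indices argument is skipped whenever the bitmask already has one row per row of logits, because the two are then positionally aligned. A small torch-only sketch of that decision; apply_bitmask_rowwise is an illustrative dense boolean stand-in, not xgrammar's packed-bitmask kernel:

from typing import Optional

import torch


def apply_bitmask_rowwise(logits: torch.Tensor,
                          bitmask: torch.Tensor,
                          indices: Optional[torch.Tensor]) -> None:
    """Mask logits in place: row i of the boolean bitmask applies to
    logits row i when indices is None, else to logits[indices[i]]."""
    rows = torch.arange(bitmask.shape[0]) if indices is None else indices
    logits[rows] = logits[rows].masked_fill(~bitmask, float("-inf"))


logits = torch.zeros(4, 8)
bitmask = torch.rand(4, 8) > 0.5

# Bitmask rows and logits rows line up one-to-one: no indices needed.
skip_out_indices = bitmask.shape[0] == logits.shape[0]
apply_bitmask_rowwise(logits, bitmask,
                      indices=None if skip_out_indices else torch.arange(4))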

def sync_and_slice_intermediate_tensors(