[Structured Output][Refactor] Move apply_grammar_bitmask() method from ModelRunner to structured output utils (#21999)

Signed-off-by: shen-shanshan <467638484@qq.com>
Shanshan Shen 2025-09-18 20:44:31 +08:00 committed by GitHub
parent 21da73343a
commit 470484a4f5
2 changed files with 84 additions and 71 deletions
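
For orientation before the file diffs: the bitmask-application logic moves out of GPUModelRunner and becomes a module-level function, so it can be reused without going through the GPU runner class. A quick way to inspect the relocated entry point (a sketch; assumes a vLLM build that already contains this commit):

import inspect

from vllm.v1.structured_output.utils import apply_grammar_bitmask

# The helper now takes the batch and device explicitly instead of reading
# them off the runner instance; roughly:
# (scheduler_output, input_batch, logits, device) -> None
print(inspect.signature(apply_grammar_bitmask))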

vllm/v1/structured_output/utils.py

@@ -8,7 +8,9 @@ import importlib.metadata
import os
from typing import TYPE_CHECKING
import numpy as np
import regex as re
import torch
from cachetools import LRUCache
from diskcache import Cache
@@ -20,9 +22,13 @@ if TYPE_CHECKING:
    import outlines_core as oc
    import transformers.file_utils as file_utils
    import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
    import xgrammar as xgr

    from vllm.transformers_utils.tokenizer import AnyTokenizer
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
    oc = LazyLoader("oc", globals(), "outlines_core")
    file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
    tokenization_gpt2 = LazyLoader(
@@ -36,6 +42,80 @@ logger = init_logger(__name__)
CACHE = None
def apply_grammar_bitmask(
    scheduler_output: SchedulerOutput,
    input_batch: InputBatch,
    logits: torch.Tensor,
    device: torch.device,
) -> None:
    """
    Apply the grammar bitmask to the model's output logits, using xgrammar.

    Args:
        scheduler_output (SchedulerOutput): The result of engine scheduling.
        input_batch (InputBatch): The model runner's input batch.
        logits (torch.Tensor): The logits from the model forward pass.
        device (torch.device): The device the model runner is running on.
    """
    grammar_bitmask = scheduler_output.grammar_bitmask
    if grammar_bitmask is None:
        return

    # We receive the structured output bitmask from the scheduler,
    # compacted to contain bitmasks only for structured output requests.
    # The order of the requests in the bitmask is not guaranteed to be the
    # same as the order of the requests in the gpu runner's batch, so we
    # sort the bitmask to match the order of the requests used here.

    # Get the batch indices of the structured output requests.
    # Keep track of the number of speculative tokens scheduled for every
    # request in the batch, as the logit indices are offset by this amount.
    struct_out_req_batch_indices: dict[str, int] = {}
    cumulative_offset = 0
    seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
    for req_id, batch_index in seq:
        logit_index = batch_index + cumulative_offset
        cumulative_offset += len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        if req_id in scheduler_output.structured_output_request_ids:
            struct_out_req_batch_indices[req_id] = logit_index

    out_indices = []

    # Reorder the bitmask to match the order of the requests in the batch.
    sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]),
                             fill_value=-1,
                             dtype=grammar_bitmask.dtype)
    cumulative_index = 0
    seq = sorted(scheduler_output.structured_output_request_ids.items(),
                 key=lambda x: x[1])
    for req_id, _ in seq:
        logit_index = struct_out_req_batch_indices[req_id]
        num_spec_tokens = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        for i in range(1 + num_spec_tokens):
            sorted_bitmask[logit_index + i] = \
                grammar_bitmask[cumulative_index + i]
            out_indices.append(logit_index + i)
        cumulative_index += 1 + num_spec_tokens
    grammar_bitmask = sorted_bitmask

    # If out_indices already covers every row of the logits, the bitmask is
    # aligned with the logits and we don't need to pass indices to the kernel.
    skip_out_indices = len(out_indices) == logits.shape[0]

    # Serialization of an np.ndarray is much more efficient than a tensor,
    # so we receive the bitmask in that format.
    grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()

    xgr.apply_token_bitmask_inplace(
        logits,
        grammar_bitmask.to(device, non_blocking=True),
        indices=out_indices if not skip_out_indices else None,
    )
class OutlinesVocabulary:
"""
Wrapper class for `outlines_core.Vocabulary`,

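To make the index arithmetic in apply_grammar_bitmask() concrete, here is a standalone toy rerun of the same reordering (hypothetical request IDs and sizes, numpy only; not vLLM code). Request "B" has one speculative token, so its logits occupy two consecutive rows, while the unconstrained request "A" keeps the -1 "no mask" filler row:

import numpy as np

req_id_to_index = {"A": 0, "B": 1, "C": 2}      # batch order
structured_request_ids = {"B": 0, "C": 1}       # compacted bitmask row order
spec_tokens = {"B": [123]}                      # one spec token for "B"

num_logit_rows = 4                              # A: 1 row, B: 2 rows, C: 1 row
bitmask_words = 2                               # toy vocab of 64 tokens
grammar_bitmask = np.arange(3 * bitmask_words,  # compacted rows: B, B+spec, C
                            dtype=np.int32).reshape(3, bitmask_words)

# Step 1: each request's first logit row, offset by earlier spec tokens.
row_of: dict[str, int] = {}
offset = 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    row = batch_index + offset
    offset += len(spec_tokens.get(req_id, []))
    if req_id in structured_request_ids:
        row_of[req_id] = row

# Step 2: scatter compacted bitmask rows to logit rows; -1 means "no mask".
sorted_bitmask = np.full((num_logit_rows, bitmask_words), -1, dtype=np.int32)
out_indices = []
src = 0
for req_id, _ in sorted(structured_request_ids.items(), key=lambda x: x[1]):
    for i in range(1 + len(spec_tokens.get(req_id, []))):
        sorted_bitmask[row_of[req_id] + i] = grammar_bitmask[src + i]
        out_indices.append(row_of[req_id] + i)
    src += 1 + len(spec_tokens.get(req_id, []))

print(out_indices)        # [1, 2, 3]
print(sorted_bitmask[0])  # [-1, -1]: request "A" is unconstrained

Because out_indices covers only 3 of the 4 logit rows here, the real helper would pass the indices on to the xgrammar kernel rather than skipping them.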
vllm/v1/worker/gpu_model_runner.py

@@ -54,7 +54,7 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                        GiB_bytes, check_use_alibi, get_dtype_size,
                        is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
@@ -85,6 +85,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
@@ -101,12 +102,8 @@ from .utils import (AttentionGroup, MultiModalBudget,
                    scatter_mm_placeholders)

if TYPE_CHECKING:
    import xgrammar as xgr

    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
@@ -1617,71 +1614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        return tuple(tasks)

    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        # We receive the structured output bitmask from the scheduler,
        # compacted to contain bitmasks only for structured output requests.
        # The order of the requests in the bitmask is not guaranteed to be the
        # same as the order of the requests in the gpu runner's batch. We need
        # to sort the bitmask to match the order of the requests used here.

        # Get the batch indices of the structured output requests.
        # Keep track of the number of speculative tokens scheduled for every
        # request in the batch, as the logit indices are offset by this amount.
        struct_out_req_batch_indices: dict[str, int] = {}
        cumulative_offset = 0
        seq = sorted(self.input_batch.req_id_to_index.items(),
                     key=lambda x: x[1])
        for req_id, batch_index in seq:
            logit_index = batch_index + cumulative_offset
            cumulative_offset += len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            if req_id in scheduler_output.structured_output_request_ids:
                struct_out_req_batch_indices[req_id] = logit_index

        out_indices = []

        # Reorder the bitmask to match the order of the requests in the batch.
        sorted_bitmask = np.full(shape=(logits.shape[0],
                                        grammar_bitmask.shape[1]),
                                 fill_value=-1,
                                 dtype=grammar_bitmask.dtype)
        cumulative_index = 0
        seq = sorted(scheduler_output.structured_output_request_ids.items(),
                     key=lambda x: x[1])
        for req_id, _ in seq:
            logit_index = struct_out_req_batch_indices[req_id]
            num_spec_tokens = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            for i in range(1 + num_spec_tokens):
                sorted_bitmask[logit_index + i] = \
                    grammar_bitmask[cumulative_index + i]
                out_indices.append(logit_index + i)
            cumulative_index += 1 + num_spec_tokens
        grammar_bitmask = sorted_bitmask

        # If the length of out indices and the logits have the same shape
        # we don't need to pass indices to the kernel,
        # since the bitmask is already aligned with the logits.
        skip_out_indices = len(out_indices) == logits.shape[0]

        # Serialization of np.ndarray is much more efficient than a tensor,
        # so we receive it in that format.
        grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()

        xgr.apply_token_bitmask_inplace(
            logits,
            grammar_bitmask.to(self.device, non_blocking=True),
            indices=out_indices if not skip_out_indices else None,
        )

    def sync_and_slice_intermediate_tensors(
            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
            sync_self: bool) -> IntermediateTensors:
@@ -2232,7 +2164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)
            apply_grammar_bitmask(scheduler_output, self.input_batch,
                                  logits, self.device)

        with record_function_or_nullcontext("Sample"):
            sampler_output = self._sample(logits, spec_decode_metadata)
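
Two call-site details are easy to miss. First, torch.from_numpy() wraps the ndarray buffer without copying, so the only real cost after deserialization is the single non-blocking host-to-device transfer. Second, the effect of xgr.apply_token_bitmask_inplace() is to set the logits of disallowed tokens to -inf in place; the following pure-torch sketch illustrates those semantics, assuming xgrammar's 32-tokens-per-int32 bit packing (a reference illustration, not the actual kernel):

import torch

def apply_token_bitmask_reference(logits: torch.Tensor,
                                  bitmask: torch.Tensor) -> None:
    # Bit v of packed int32 word v // 32 is 1 iff token v is allowed;
    # disallowed logits are masked to -inf in place.
    vocab = logits.shape[-1]
    words = torch.arange(vocab, device=logits.device) // 32
    bits = torch.arange(vocab, device=logits.device) % 32
    allowed = (bitmask[:, words] >> bits) & 1      # (rows, vocab) of 0/1
    logits.masked_fill_(allowed == 0, float("-inf"))

logits = torch.zeros(2, 32)
mask = torch.tensor([[0b1011], [-1]], dtype=torch.int32)  # -1: all 32 bits set
apply_token_bitmask_reference(logits, mask)
print(logits[0, :4])  # tokens 0, 1, 3 allowed; token 2 masked to -inf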