# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers


def apply_grammar_bitmask(
    logits: torch.Tensor,
    req_ids: list[str],
    grammar_req_ids: list[str],
    grammar_bitmask: np.ndarray,
    input_buffers: InputBuffers,
) -> None:
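    """Mask disallowed tokens in the logits of grammar-constrained requests.

    Stages the host-side grammar bitmask in `input_buffers`, builds a
    per-row mapping from logits rows to bitmask rows (-1 for requests
    without a grammar), and launches a Triton kernel that sets the logits
    of disallowed tokens to -inf in place.
    """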
    # Copy the grammar bitmask to the GPU via the host-side staging buffer.
    input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
    input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])

    batch_size = logits.shape[0]
    grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
    # Map each logits row to its bitmask row; -1 marks requests without a grammar.
    mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
    input_buffers.bitmask_indices.np[:batch_size] = mapping
    input_buffers.bitmask_indices.copy_to_gpu(batch_size)
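
    # Launch one kernel program per (logits row, BLOCK_SIZE block of tokens).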
    vocab_size = logits.shape[-1]
    BLOCK_SIZE = 8192
    grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
    _apply_grammar_bitmask_kernel[grid](
        logits,
        logits.stride(0),
        input_buffers.grammar_bitmask.gpu,
        input_buffers.grammar_bitmask.gpu.stride(0),
        input_buffers.bitmask_indices.gpu,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )


# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
    logits_ptr,
    logits_stride,
    bitmask_ptr,
    bitmask_stride,
    bitmask_indices_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
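    # Each int32 in the bitmask packs 32 tokens, least-significant bit first;
    # a zero bit marks a token that the grammar disallows.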
    logits_idx = tl.program_id(0)
    bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
    if bitmask_idx == -1:
        # No bitmask to apply.
        return

    # Load the packed bitmask words covering this block of tokens.
    block_id = tl.program_id(1)
    bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
    packed_bitmask = tl.load(
        bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
        mask=bitmask_offset < bitmask_stride,
    )
    # Unpack the bitmask: True marks tokens whose bit is 0, i.e. disallowed.
    bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
    bitmask = bitmask.reshape(BLOCK_SIZE)

    # Set the logits of disallowed tokens to -inf.
    block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(
        logits_ptr + logits_idx * logits_stride + block_offset,
        -float("inf"),
        mask=bitmask & (block_offset < vocab_size),
    )