mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 06:09:10 +08:00
[Structured Output][Refactor] Move apply_grammar_bitmask() method from ModelRunner to structured output utils (#21999)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
parent
21da73343a
commit
470484a4f5
@ -8,7 +8,9 @@ import importlib.metadata
|
|||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import regex as re
|
import regex as re
|
||||||
|
import torch
|
||||||
from cachetools import LRUCache
|
from cachetools import LRUCache
|
||||||
from diskcache import Cache
|
from diskcache import Cache
|
||||||
|
|
||||||
@ -20,9 +22,13 @@ if TYPE_CHECKING:
|
|||||||
import outlines_core as oc
|
import outlines_core as oc
|
||||||
import transformers.file_utils as file_utils
|
import transformers.file_utils as file_utils
|
||||||
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
|
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
|
||||||
|
import xgrammar as xgr
|
||||||
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
else:
|
else:
|
||||||
|
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||||
oc = LazyLoader("oc", globals(), "outlines_core")
|
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||||
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
|
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
|
||||||
tokenization_gpt2 = LazyLoader(
|
tokenization_gpt2 = LazyLoader(
|
||||||
@ -36,6 +42,80 @@ logger = init_logger(__name__)
|
|||||||
CACHE = None
|
CACHE = None
|
||||||
|
|
||||||
|
|
||||||
|
def apply_grammar_bitmask(
|
||||||
|
scheduler_output: SchedulerOutput,
|
||||||
|
input_batch: InputBatch,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Apply grammar bitmask to output logits of the model with xgrammar function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler_output (SchedulerOutput): The result of engine scheduling.
|
||||||
|
input_batch (InputBatch): The input of model runner.
|
||||||
|
logits (torch.Tensor): The output logits of model forward.
|
||||||
|
device (torch.device): The device that model runner running on.
|
||||||
|
"""
|
||||||
|
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||||
|
if grammar_bitmask is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# We receive the structured output bitmask from the scheduler,
|
||||||
|
# compacted to contain bitmasks only for structured output requests.
|
||||||
|
# The order of the requests in the bitmask is not guaranteed to be the
|
||||||
|
# same as the order of the requests in the gpu runner's batch. We need
|
||||||
|
# to sort the bitmask to match the order of the requests used here.
|
||||||
|
|
||||||
|
# Get the batch indices of the structured output requests.
|
||||||
|
# Keep track of the number of speculative tokens scheduled for every
|
||||||
|
# request in the batch, as the logit indices are offset by this amount.
|
||||||
|
struct_out_req_batch_indices: dict[str, int] = {}
|
||||||
|
cumulative_offset = 0
|
||||||
|
seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
|
||||||
|
for req_id, batch_index in seq:
|
||||||
|
logit_index = batch_index + cumulative_offset
|
||||||
|
cumulative_offset += len(
|
||||||
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
||||||
|
if req_id in scheduler_output.structured_output_request_ids:
|
||||||
|
struct_out_req_batch_indices[req_id] = logit_index
|
||||||
|
|
||||||
|
out_indices = []
|
||||||
|
|
||||||
|
# Reorder the bitmask to match the order of the requests in the batch.
|
||||||
|
sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]),
|
||||||
|
fill_value=-1,
|
||||||
|
dtype=grammar_bitmask.dtype)
|
||||||
|
cumulative_index = 0
|
||||||
|
seq = sorted(scheduler_output.structured_output_request_ids.items(),
|
||||||
|
key=lambda x: x[1])
|
||||||
|
for req_id, _ in seq:
|
||||||
|
logit_index = struct_out_req_batch_indices[req_id]
|
||||||
|
num_spec_tokens = len(
|
||||||
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
||||||
|
for i in range(1 + num_spec_tokens):
|
||||||
|
sorted_bitmask[logit_index + i] = \
|
||||||
|
grammar_bitmask[cumulative_index + i]
|
||||||
|
out_indices.append(logit_index + i)
|
||||||
|
cumulative_index += 1 + num_spec_tokens
|
||||||
|
grammar_bitmask = sorted_bitmask
|
||||||
|
|
||||||
|
# If the length of out indices and the logits have the same shape
|
||||||
|
# we don't need to pass indices to the kernel,
|
||||||
|
# since the bitmask is already aligned with the logits.
|
||||||
|
skip_out_indices = len(out_indices) == logits.shape[0]
|
||||||
|
|
||||||
|
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||||
|
# so we receive it in that format.
|
||||||
|
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
|
||||||
|
|
||||||
|
xgr.apply_token_bitmask_inplace(
|
||||||
|
logits,
|
||||||
|
grammar_bitmask.to(device, non_blocking=True),
|
||||||
|
indices=out_indices if not skip_out_indices else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OutlinesVocabulary:
|
class OutlinesVocabulary:
|
||||||
"""
|
"""
|
||||||
Wrapper class for `outlines_core.Vocabulary`,
|
Wrapper class for `outlines_core.Vocabulary`,
|
||||||
|
|||||||
@ -54,7 +54,7 @@ from vllm.sampling_params import SamplingType
|
|||||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
|
GiB_bytes, check_use_alibi, get_dtype_size,
|
||||||
is_pin_memory_available, round_up, supports_dynamo)
|
is_pin_memory_available, round_up, supports_dynamo)
|
||||||
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||||
@ -85,6 +85,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
|
|||||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
|
from vllm.v1.structured_output.utils import apply_grammar_bitmask
|
||||||
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
|
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||||
@ -101,12 +102,8 @@ from .utils import (AttentionGroup, MultiModalBudget,
|
|||||||
scatter_mm_placeholders)
|
scatter_mm_placeholders)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr
|
|
||||||
|
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
else:
|
|
||||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -1617,71 +1614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
return tuple(tasks)
|
return tuple(tasks)
|
||||||
|
|
||||||
def apply_grammar_bitmask(
|
|
||||||
self,
|
|
||||||
scheduler_output: "SchedulerOutput",
|
|
||||||
logits: torch.Tensor,
|
|
||||||
):
|
|
||||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
|
||||||
if grammar_bitmask is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# We receive the structured output bitmask from the scheduler,
|
|
||||||
# compacted to contain bitmasks only for structured output requests.
|
|
||||||
# The order of the requests in the bitmask is not guaranteed to be the
|
|
||||||
# same as the order of the requests in the gpu runner's batch. We need
|
|
||||||
# to sort the bitmask to match the order of the requests used here.
|
|
||||||
|
|
||||||
# Get the batch indices of the structured output requests.
|
|
||||||
# Keep track of the number of speculative tokens scheduled for every
|
|
||||||
# request in the batch, as the logit indices are offset by this amount.
|
|
||||||
struct_out_req_batch_indices: dict[str, int] = {}
|
|
||||||
cumulative_offset = 0
|
|
||||||
seq = sorted(self.input_batch.req_id_to_index.items(),
|
|
||||||
key=lambda x: x[1])
|
|
||||||
for req_id, batch_index in seq:
|
|
||||||
logit_index = batch_index + cumulative_offset
|
|
||||||
cumulative_offset += len(
|
|
||||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
|
||||||
if req_id in scheduler_output.structured_output_request_ids:
|
|
||||||
struct_out_req_batch_indices[req_id] = logit_index
|
|
||||||
|
|
||||||
out_indices = []
|
|
||||||
|
|
||||||
# Reorder the bitmask to match the order of the requests in the batch.
|
|
||||||
sorted_bitmask = np.full(shape=(logits.shape[0],
|
|
||||||
grammar_bitmask.shape[1]),
|
|
||||||
fill_value=-1,
|
|
||||||
dtype=grammar_bitmask.dtype)
|
|
||||||
cumulative_index = 0
|
|
||||||
seq = sorted(scheduler_output.structured_output_request_ids.items(),
|
|
||||||
key=lambda x: x[1])
|
|
||||||
for req_id, _ in seq:
|
|
||||||
logit_index = struct_out_req_batch_indices[req_id]
|
|
||||||
num_spec_tokens = len(
|
|
||||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
|
||||||
for i in range(1 + num_spec_tokens):
|
|
||||||
sorted_bitmask[logit_index + i] = \
|
|
||||||
grammar_bitmask[cumulative_index + i]
|
|
||||||
out_indices.append(logit_index + i)
|
|
||||||
cumulative_index += 1 + num_spec_tokens
|
|
||||||
grammar_bitmask = sorted_bitmask
|
|
||||||
|
|
||||||
# If the length of out indices and the logits have the same shape
|
|
||||||
# we don't need to pass indices to the kernel,
|
|
||||||
# since the bitmask is already aligned with the logits.
|
|
||||||
skip_out_indices = len(out_indices) == logits.shape[0]
|
|
||||||
|
|
||||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
|
||||||
# so we receive it in that format.
|
|
||||||
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
|
|
||||||
|
|
||||||
xgr.apply_token_bitmask_inplace(
|
|
||||||
logits,
|
|
||||||
grammar_bitmask.to(self.device, non_blocking=True),
|
|
||||||
indices=out_indices if not skip_out_indices else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sync_and_slice_intermediate_tensors(
|
def sync_and_slice_intermediate_tensors(
|
||||||
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
|
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
|
||||||
sync_self: bool) -> IntermediateTensors:
|
sync_self: bool) -> IntermediateTensors:
|
||||||
@ -2232,7 +2164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
# Apply structured output bitmasks if present
|
# Apply structured output bitmasks if present
|
||||||
if scheduler_output.grammar_bitmask is not None:
|
if scheduler_output.grammar_bitmask is not None:
|
||||||
self.apply_grammar_bitmask(scheduler_output, logits)
|
apply_grammar_bitmask(scheduler_output, self.input_batch,
|
||||||
|
logits, self.device)
|
||||||
|
|
||||||
with record_function_or_nullcontext("Sample"):
|
with record_function_or_nullcontext("Sample"):
|
||||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user