diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py
index 953185a8fc31d..127c8876525b5 100644
--- a/vllm/v1/structured_output/utils.py
+++ b/vllm/v1/structured_output/utils.py
@@ -8,7 +8,9 @@ import importlib.metadata
 import os
 from typing import TYPE_CHECKING
 
+import numpy as np
 import regex as re
+import torch
 from cachetools import LRUCache
 from diskcache import Cache
 
@@ -20,9 +22,13 @@ if TYPE_CHECKING:
     import outlines_core as oc
     import transformers.file_utils as file_utils
     import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
+    import xgrammar as xgr
 
     from vllm.transformers_utils.tokenizer import AnyTokenizer
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.worker.gpu_input_batch import InputBatch
 else:
+    xgr = LazyLoader("xgr", globals(), "xgrammar")
     oc = LazyLoader("oc", globals(), "outlines_core")
     file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
     tokenization_gpt2 = LazyLoader(
@@ -36,6 +42,80 @@ logger = init_logger(__name__)
 
 CACHE = None
 
+def apply_grammar_bitmask(
+    scheduler_output: SchedulerOutput,
+    input_batch: InputBatch,
+    logits: torch.Tensor,
+    device: torch.device,
+) -> None:
+    """
+    Apply the grammar bitmask to the model's output logits via xgrammar.
+
+    Args:
+        scheduler_output (SchedulerOutput): The result of engine scheduling.
+        input_batch (InputBatch): The input batch of the model runner.
+        logits (torch.Tensor): The output logits of the model forward pass.
+        device (torch.device): The device the model runner is running on.
+    """
+    grammar_bitmask = scheduler_output.grammar_bitmask
+    if grammar_bitmask is None:
+        return
+
+    # We receive the structured output bitmask from the scheduler,
+    # compacted to contain bitmasks only for structured output requests.
+    # The order of the requests in the bitmask is not guaranteed to be the
+    # same as the order of the requests in the gpu runner's batch. We need
+    # to sort the bitmask to match the order of the requests used here.
+
+    # Get the batch indices of the structured output requests.
+    # Keep track of the number of speculative tokens scheduled for every
+    # request in the batch, as the logit indices are offset by this amount.
+    struct_out_req_batch_indices: dict[str, int] = {}
+    cumulative_offset = 0
+    seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
+    for req_id, batch_index in seq:
+        logit_index = batch_index + cumulative_offset
+        cumulative_offset += len(
+            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+        if req_id in scheduler_output.structured_output_request_ids:
+            struct_out_req_batch_indices[req_id] = logit_index
+
+    out_indices = []
+
+    # Reorder the bitmask to match the order of the requests in the batch.
+    sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]),
+                             fill_value=-1,
+                             dtype=grammar_bitmask.dtype)
+    cumulative_index = 0
+    seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                 key=lambda x: x[1])
+    for req_id, _ in seq:
+        logit_index = struct_out_req_batch_indices[req_id]
+        num_spec_tokens = len(
+            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+        for i in range(1 + num_spec_tokens):
+            sorted_bitmask[logit_index + i] = \
+                grammar_bitmask[cumulative_index + i]
+            out_indices.append(logit_index + i)
+        cumulative_index += 1 + num_spec_tokens
+    grammar_bitmask = sorted_bitmask
+
+    # If the length of out indices and the logits have the same shape
+    # we don't need to pass indices to the kernel,
+    # since the bitmask is already aligned with the logits.
+    skip_out_indices = len(out_indices) == logits.shape[0]
+
+    # Serialization of np.ndarray is much more efficient than a tensor,
+    # so we receive it in that format.
+    grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
+
+    xgr.apply_token_bitmask_inplace(
+        logits,
+        grammar_bitmask.to(device, non_blocking=True),
+        indices=out_indices if not skip_out_indices else None,
+    )
+
+
 class OutlinesVocabulary:
     """
     Wrapper class for `outlines_core.Vocabulary`,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2e67984cb4327..4873b586724ec 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -54,7 +54,7 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
+                        GiB_bytes, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up, supports_dynamo)
 from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
@@ -85,6 +85,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
@@ -101,12 +102,8 @@ from .utils import (AttentionGroup, MultiModalBudget,
                     scatter_mm_placeholders)
 
 if TYPE_CHECKING:
-    import xgrammar as xgr
-
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
     from vllm.v1.core.sched.output import SchedulerOutput
-else:
-    xgr = LazyLoader("xgr", globals(), "xgrammar")
 
 logger = init_logger(__name__)
 
@@ -1617,71 +1614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         return tuple(tasks)
 
-    def apply_grammar_bitmask(
-        self,
-        scheduler_output: "SchedulerOutput",
-        logits: torch.Tensor,
-    ):
-        grammar_bitmask = scheduler_output.grammar_bitmask
-        if grammar_bitmask is None:
-            return
-
-        # We receive the structured output bitmask from the scheduler,
-        # compacted to contain bitmasks only for structured output requests.
-        # The order of the requests in the bitmask is not guaranteed to be the
-        # same as the order of the requests in the gpu runner's batch. We need
-        # to sort the bitmask to match the order of the requests used here.
-
-        # Get the batch indices of the structured output requests.
-        # Keep track of the number of speculative tokens scheduled for every
-        # request in the batch, as the logit indices are offset by this amount.
-        struct_out_req_batch_indices: dict[str, int] = {}
-        cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
-        for req_id, batch_index in seq:
-            logit_index = batch_index + cumulative_offset
-            cumulative_offset += len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in scheduler_output.structured_output_request_ids:
-                struct_out_req_batch_indices[req_id] = logit_index
-
-        out_indices = []
-
-        # Reorder the bitmask to match the order of the requests in the batch.
-        sorted_bitmask = np.full(shape=(logits.shape[0],
-                                        grammar_bitmask.shape[1]),
-                                 fill_value=-1,
-                                 dtype=grammar_bitmask.dtype)
-        cumulative_index = 0
-        seq = sorted(scheduler_output.structured_output_request_ids.items(),
-                     key=lambda x: x[1])
-        for req_id, _ in seq:
-            logit_index = struct_out_req_batch_indices[req_id]
-            num_spec_tokens = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            for i in range(1 + num_spec_tokens):
-                sorted_bitmask[logit_index + i] = \
-                    grammar_bitmask[cumulative_index + i]
-                out_indices.append(logit_index + i)
-            cumulative_index += 1 + num_spec_tokens
-        grammar_bitmask = sorted_bitmask
-
-        # If the length of out indices and the logits have the same shape
-        # we don't need to pass indices to the kernel,
-        # since the bitmask is already aligned with the logits.
-        skip_out_indices = len(out_indices) == logits.shape[0]
-
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
-
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask.to(self.device, non_blocking=True),
-            indices=out_indices if not skip_out_indices else None,
-        )
-
     def sync_and_slice_intermediate_tensors(
         self, num_tokens: int, intermediate_tensors: IntermediateTensors,
         sync_self: bool) -> IntermediateTensors:
@@ -2232,7 +2164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
             # Apply structured output bitmasks if present
             if scheduler_output.grammar_bitmask is not None:
-                self.apply_grammar_bitmask(scheduler_output, logits)
+                apply_grammar_bitmask(scheduler_output, self.input_batch,
+                                      logits, self.device)
 
         with record_function_or_nullcontext("Sample"):
             sampler_output = self._sample(logits, spec_decode_metadata)
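
For context, the sketch below illustrates the reordering step that the relocated helper performs before handing the bitmask to the xgrammar kernel. It is a toy, NumPy-only walkthrough: the request ids, spec-decode token counts, and the 4-column array are invented stand-ins, not actual vLLM data structures or the scheduler's real packed int32 bitmask.

# Toy sketch of the bitmask reordering (illustrative data only).
import numpy as np

# Batch order as the GPU runner sees it (request id -> batch index).
req_id_to_index = {"req_a": 0, "req_b": 1, "req_c": 2}
# Structured-output requests, in the order the scheduler packed the bitmask.
structured_output_request_ids = {"req_c": 0, "req_a": 1}
# Speculative tokens scheduled per request; each adds one extra logit row.
scheduled_spec_decode_tokens = {"req_b": [101], "req_c": [7, 9]}

# Compacted bitmask: one row per (structured request, token position),
# in scheduler order: req_c -> 3 rows (1 + 2 spec), req_a -> 1 row.
compact = np.arange(4 * 4, dtype=np.int32).reshape(4, 4)
num_logit_rows = 6  # 3 requests + 1 spec token (req_b) + 2 spec tokens (req_c)

# Map each structured request to its first logit row, offsetting by the
# speculative tokens of every request that precedes it in the batch.
first_logit_row: dict[str, int] = {}
offset = 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    row = batch_index + offset
    offset += len(scheduled_spec_decode_tokens.get(req_id, []))
    if req_id in structured_output_request_ids:
        first_logit_row[req_id] = row

# Scatter the compacted rows into logit order; rows left at -1 (all bits set)
# belong to unconstrained requests and mask nothing.
sorted_bitmask = np.full((num_logit_rows, compact.shape[1]), -1,
                         dtype=compact.dtype)
src = 0
for req_id, _ in sorted(structured_output_request_ids.items(),
                        key=lambda x: x[1]):
    n_rows = 1 + len(scheduled_spec_decode_tokens.get(req_id, []))
    dst = first_logit_row[req_id]
    sorted_bitmask[dst:dst + n_rows] = compact[src:src + n_rows]
    src += n_rows

print(sorted_bitmask)
# req_a's single row lands at index 0, req_c's three rows land at 3-5,
# and rows 1-2 (req_b and its speculative token) stay at -1.

In the real function the scattered rows come from the scheduler's compacted bitmask and the -1 rows leave the corresponding logits untouched when the kernel runs; when every logit row received a bitmask row, the indices argument is skipped because the mask is already aligned with the batch.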