Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-15 11:17:54 -07:00
parent 9a6fcca030
commit 8b3c13c485
13 changed files with 886 additions and 2000 deletions

View File

@ -30,7 +30,6 @@ class NewRequestData:
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@ -46,7 +45,6 @@ class NewRequestData:
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
)
@ -57,7 +55,6 @@ class NewRequestData:
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
@ -69,7 +66,6 @@ class NewRequestData:
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
@ -77,52 +73,17 @@ class NewRequestData:
@bc_linter_include
@dataclass
class CachedRequestData:
class SchedulerOutput:
req_ids: list[str]
# If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: list[bool]
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
new_block_ids: list[Optional[tuple[list[int], ...]]]
num_computed_tokens: list[int]
@property
def num_reqs(self) -> int:
return len(self.req_ids)
@classmethod
def make_empty(cls) -> CachedRequestData:
return cls(
req_ids=[],
resumed_from_preemption=[],
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
)
@bc_linter_include
@dataclass
class SchedulerOutput:
cu_new_block_ids: tuple[np.ndarray, ...]
# list of the requests that are scheduled for the first time.
# We cache the request's data in each worker process, so that we don't
# need to re-send it every scheduling step.
scheduled_new_reqs: list[NewRequestData]
# list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs: CachedRequestData
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
num_scheduled_tokens: dict[str, int]
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
@ -136,13 +97,11 @@ class SchedulerOutput:
# This can be used for cascade attention.
num_common_prefix_blocks: list[int]
preempted_req_ids: set[str]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: set[str]
# list of mm_hash strings associated with the encoder outputs to be
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# Dict of request ids to their index within the batch
# for filling the next token bitmask

View File

@ -3,6 +3,7 @@
import copy
from dataclasses import dataclass, fields
from functools import cached_property
from math import prod
from typing import Optional
@ -273,3 +274,10 @@ class KVCacheConfig:
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
kv_cache_groups: list[KVCacheGroupSpec]
@cached_property
def block_sizes(self) -> list[int]:
    """Block size of each KV cache group, in group order (computed once)."""
    sizes: list[int] = []
    for group in self.kv_cache_groups:
        sizes.append(group.kv_cache_spec.block_size)
    return sizes

View File

@ -6,39 +6,14 @@ from typing import Optional
import torch
from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
class SamplingMetadata:
    """Batched sampling parameters for the requests in the current step."""

    # Per-request sampling temperature.
    temperature: torch.Tensor
    # Fast-path flags: all requests greedy / all requests random.
    all_greedy: bool
    all_random: bool

    # Per-request top-p / top-k; None when the feature is unused batch-wide.
    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    # index -> torch.Generator for seeded sampling.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    # Penalty parameters; no_penalties short-circuits penalty application.
    no_penalties: bool
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    # Token bookkeeping used by the penalty computation.
    # NOTE(review): shapes not visible here — confirm against the producer.
    token_ids: Optional[torch.Tensor]
    num_tokens: Optional[torch.Tensor]
    num_prompt_tokens: Optional[torch.Tensor]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors

View File

View File

@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
def get_kv_cache_spec(
    vllm_config: VllmConfig,
    kv_cache_dtype: torch.dtype,
) -> dict[str, KVCacheSpec]:
    """Build a per-layer KV cache spec from the model's attention layers.

    Sliding-window layers get a SlidingWindowSpec; all others get a
    FullAttentionSpec. Only decoder attention is supported.
    """
    cache_block_size = vllm_config.cache_config.block_size
    is_mla = vllm_config.model_config.use_mla
    specs: dict[str, KVCacheSpec] = {}
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for name, layer in attn_layers.items():
        assert layer.attn_type == AttentionType.DECODER
        # Fields shared by both spec types.
        common = dict(
            block_size=cache_block_size,
            num_kv_heads=layer.num_kv_heads,
            head_size=layer.head_size,
            dtype=kv_cache_dtype,
            use_mla=is_mla,
        )
        if layer.sliding_window is None:
            specs[name] = FullAttentionSpec(**common)
        else:
            specs[name] = SlidingWindowSpec(
                sliding_window=layer.sliding_window, **common)
    return specs
def init_attn_backend(vllm_config: VllmConfig):
    # NOTE(review): looks unfinished — the collected attention layers are
    # never used and nothing is returned, yet the model runner assigns this
    # function's result to `attn_metadata_builders`. Confirm intent.
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
def _allocate_kv_cache(
    kv_cache_config: KVCacheConfig,
    device: torch.device,
):
    """Allocate one raw int8 buffer per KV cache tensor spec.

    All layers listed in a spec's `shared_by` alias the same buffer.
    Returns a dict: layer_name -> raw (unshaped) KV cache buffer.
    """
    raw_tensors: dict[str, torch.Tensor] = {}
    for tensor_spec in kv_cache_config.kv_cache_tensors:
        buf = torch.zeros(tensor_spec.size, dtype=torch.int8, device=device)
        for name in tensor_spec.shared_by:
            raw_tensors[name] = buf
    # Every layer referenced by a KV cache group must have received a buffer.
    expected = {
        name
        for group in kv_cache_config.kv_cache_groups
        for name in group.layer_names
    }
    assert expected == set(raw_tensors), \
        "Some layers are not correctly initialized"
    return raw_tensors
def _reshape_kv_cache(
    kv_cache_config: KVCacheConfig,
    kv_cache_raw_tensors: dict[str, torch.Tensor],
):
    # TODO: unimplemented stub — presumably should view each raw int8 buffer
    # with the layer's KV cache shape/dtype. Currently returns None.
    pass
def init_kv_cache(
    kv_cache_config: KVCacheConfig,
    device: torch.device,
):
    """Allocate raw KV cache buffers on `device` and reshape them into
    per-layer caches.

    Returns the reshaped per-layer caches produced by _reshape_kv_cache.
    """
    kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
    kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors)
    # BUG FIX: the result was previously computed and silently discarded,
    # making this function a no-op for its caller.
    return kv_caches

View File

@ -18,7 +18,6 @@ class BlockTables:
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_cached_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
@ -26,21 +25,17 @@ class BlockTables:
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_cached_reqs = max_num_cached_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# [num_kv_cache_groups, max_num_reqs, max_num_blocks]
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
# [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
self.block_table_buffers: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
@ -48,17 +43,16 @@ class BlockTables:
device=self.device,
)
self.block_tables.append(block_table)
block_table_buffer = torch.zeros(
self.max_num_cached_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_table_buffers.append(block_table_buffer)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(
self.input_block_tables)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
@ -67,7 +61,7 @@ class BlockTables:
dtype=torch.int32,
device=self.device)
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
self.max_num_cached_reqs,
self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
@ -107,7 +101,6 @@ class BlockTables:
# [num_reqs]
overwrite: list[bool],
) -> None:
# TODO(woosuk): Optimize & simplify this.
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
@ -115,19 +108,21 @@ class BlockTables:
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound on the number of new blocks.
new_block_ids_cpu = torch.empty(
# no clear upper bound to the number of new blocks in a single step.
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
# guarantee that the buffer is not freed before the copy is completed.
self.new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = new_block_ids_cpu.numpy()
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = new_block_ids_cpu.to(self.device,
non_blocking=True)
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device,
non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
@ -137,33 +132,34 @@ class BlockTables:
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.buffer_ptrs,
self.block_table_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
def compute_block_tables(
def gather_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
batch_size = idx_mapping.shape[0]
_compute_block_tables_kernel[(self.num_kv_cache_groups, batch_size)](
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.buffer_ptrs,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
return tuple(b[:batch_size] for b in self.block_tables)
return tuple(block_table[:num_reqs]
for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
@ -172,7 +168,7 @@ class BlockTables:
self.max_num_batched_tokens,
query_start_loc,
positions,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
@ -180,7 +176,7 @@ class BlockTables:
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
return tuple(x[:num_tokens] for x in self.slot_mappings)
return self.slot_mappings[:, :num_tokens]
@triton.jit
@ -194,8 +190,8 @@ def _append_block_ids_kernel(
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_buffer_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
block_table_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
@ -220,10 +216,9 @@ def _append_block_ids_kernel(
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
tl.int32)
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = (new_block_ids_ptr +
group_id * new_block_ids_stride)
@ -231,18 +226,18 @@ def _append_block_ids_kernel(
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
mask=offset < num_new_blocks)
tl.store(buffer_row_ptr + dst_start_idx + offset,
tl.store(row_ptr + dst_start_idx + offset,
block_ids,
mask=offset < num_new_blocks)
@triton.jit
def _compute_block_tables_kernel(
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):

View File

@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import DeviceMemoryProfiler, GiB_bytes
logger = init_logger(__name__)
def load_model(vllm_config: VllmConfig):
    """Load the model weights, logging memory consumption and wall time."""
    start = time.perf_counter()
    with DeviceMemoryProfiler() as profiler:
        loader = get_model_loader(vllm_config.load_config)
        logger.info("Loading model from scratch...")
        model = loader.load_model(vllm_config=vllm_config,
                                  model_config=vllm_config.model_config)
    elapsed = time.perf_counter() - start
    logger.info("Model loading took %.4f GiB and %.6f seconds",
                profiler.consumed_memory / GiB_bytes, elapsed)
    return model

View File

@ -1,14 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
import numba
import numpy as np
import torch
from numba import types
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
    """Reusable paired CPU/GPU buffers for building each step's model inputs."""

    def __init__(
        self,
        max_num_reqs: int,
        max_num_tokens: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.device = device
        self.pin_memory = pin_memory

        # batch_idx -> req_state_idx.
        self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
        # Flattened token ids / positions for all scheduled tokens.
        self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
        self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
        # Cumulative query lengths; one extra slot for the leading zero.
        self.query_start_loc = self._make_buffer(max_num_reqs + 1,
                                                 dtype=torch.int32)
        self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)

    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
        """Allocate a CpuGpuBuffer using this object's device/pinning settings."""
        return CpuGpuBuffer(*args,
                            dtype=dtype,
                            pin_memory=self.pin_memory,
                            device=self.device)
@dataclass
@ -16,9 +44,7 @@ class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
# req_id -> batch_idx
req_id_to_batch_idx: dict[str, int]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
@ -26,14 +52,18 @@ class InputBatch:
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
total_num_tokens: int
max_query_len: int
num_reqs: int
# sum(num_scheduled_tokens)
num_tokens: int
# [max_num_batched_tokens]
input_ids: torch.Tensor
# [max_num_batched_tokens]
positions: torch.Tensor
# layer_name -> Metadata
attn_metadata: dict[str, Any]
spec_decode_common_attn_metadata: Optional[Any]
spec_decode_metadata: Optional[SpecDecodeMetadata]
# [num_reqs]
logits_indices: torch.Tensor
@ -47,9 +77,9 @@ class InputBatch:
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int64[:], # positions
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
types.int64[:], # positions
)
],
nopython=True,
@ -61,9 +91,9 @@ def prepare_inputs(
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
positions: np.ndarray, # [num_input_tokens]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0

View File

@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import Any
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.attn_utils import get_kv_cache_spec, init_attn_backend
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.init_utils import load_model
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
from vllm.v1.worker.gpu.states import RequestState
logger = init_logger(__name__)
class GPUModelRunner:
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Bind the runner to `device` and set up persistent request state and
    reusable input buffers."""
    self.vllm_config = vllm_config
    # Aliases for frequently accessed sub-configs.
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.compilation_config = vllm_config.compilation_config
    self.lora_config = vllm_config.lora_config
    self.load_config = vllm_config.load_config
    self.parallel_config = vllm_config.parallel_config
    self.scheduler_config = vllm_config.scheduler_config
    self.speculative_config = vllm_config.speculative_config
    self.observability_config = vllm_config.observability_config

    self.device = device
    self.pin_memory = is_pin_memory_available()
    self.dtype = self.model_config.dtype
    # KV cache dtype follows the model dtype unless explicitly overridden.
    if self.cache_config.cache_dtype == "auto":
        self.kv_cache_dtype = self.dtype
    else:
        self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            self.cache_config.cache_dtype]

    self.vocab_size = self.model_config.get_vocab_size()
    self.max_model_len = self.model_config.max_model_len
    self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
    self.max_num_reqs = self.scheduler_config.max_num_seqs

    # Persistent per-request state (token ids, counters, sampling params).
    self.req_states = RequestState(
        max_num_reqs=self.max_num_reqs,
        max_model_len=self.max_model_len,
        max_num_batched_tokens=self.max_num_tokens,
        vocab_size=self.vocab_size,
        device=self.device,
        pin_memory=self.pin_memory,
    )
    # Reusable CPU/GPU staging buffers for per-step inputs.
    self.input_buffers = InputBuffers(
        max_num_reqs=self.max_num_reqs,
        max_num_tokens=self.max_num_tokens,
        device=self.device,
        pin_memory=self.pin_memory,
    )
    self.sampler = Sampler()
def load_model(self) -> None:
    # Delegates to the module-level load_model helper (imported above) and
    # stores the loaded model on the runner.
    self.model = load_model(self.vllm_config)
def get_kv_cache_spec(self):
    # Per-layer KV cache specs derived from the model's attention layers,
    # via the module-level helper from attn_utils.
    return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
    """Set up block tables and attention backends for the given KV cache
    layout."""
    # Deep-copy so later mutation cannot affect the caller's config.
    kv_cache_config = deepcopy(kv_cache_config)
    self.kv_cache_config = kv_cache_config
    # One block size per KV cache group.
    block_sizes = kv_cache_config.block_sizes
    self.block_tables = BlockTables(
        block_sizes=block_sizes,
        max_num_reqs=self.max_num_reqs,
        max_num_batched_tokens=self.max_num_tokens,
        max_model_len=self.max_model_len,
        device=self.device,
        pin_memory=self.pin_memory,
    )
    # NOTE(review): init_attn_backend appears to return None in this
    # revision — confirm attn_metadata_builders is actually populated.
    self.attn_metadata_builders = init_attn_backend(self.vllm_config)
def update_states(self, scheduler_output: SchedulerOutput) -> None:
    """Sync cached request and block-table state with the scheduler output."""
    # Drop state for preempted and finished requests.
    for req_id in scheduler_output.preempted_req_ids:
        self.req_states.remove_request(req_id)
    for req_id in scheduler_output.finished_req_ids:
        self.req_states.remove_request(req_id)

    # TODO(woosuk): Change SchedulerOutput.
    # Accumulators for one batched block-table update; one entry per request
    # that received new blocks this step.
    req_indices: list[int] = []
    # Per KV cache group: cumulative new-block counts [0, c1, c1+c2, ...].
    cu_num_new_blocks = tuple(
        [0] for _ in range(self.block_tables.num_kv_cache_groups))
    # Per KV cache group: flattened new block ids across requests.
    new_block_ids = tuple(
        [] for _ in range(self.block_tables.num_kv_cache_groups))
    # True -> replace the request's block row; False -> append to it.
    overwrite: list[bool] = []

    # Add new requests to the cached states.
    for new_req_data in scheduler_output.scheduled_new_reqs:
        req_id = new_req_data.req_id
        self.req_states.add_request(
            req_id=req_id,
            prompt_token_ids=new_req_data.prompt_token_ids,
            num_computed_tokens=new_req_data.num_computed_tokens,
            sampling_params=new_req_data.sampling_params,
        )
        req_index = self.req_states.req_id_to_index[req_id]
        req_indices.append(req_index)
        for i, block_ids in enumerate(new_req_data.block_ids):
            x = cu_num_new_blocks[i][-1]
            cu_num_new_blocks[i].append(x + len(block_ids))
            new_block_ids[i].extend(block_ids)
        # New requests always start from a fresh block-table row.
        overwrite.append(True)

    # Update the states of the running/resumed requests.
    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        req_index = self.req_states.req_id_to_index[req_id]
        # None means no new blocks were allocated for this request.
        req_new_block_ids = cached_reqs.new_block_ids[i]
        if req_new_block_ids is not None:
            req_indices.append(req_index)
            for group_id, block_ids in enumerate(req_new_block_ids):
                x = cu_num_new_blocks[group_id][-1]
                cu_num_new_blocks[group_id].append(x + len(block_ids))
                new_block_ids[group_id].extend(block_ids)
            # NOTE(review): CachedRequestData documents that requests resumed
            # from preemption should REPLACE their block ids, but overwrite is
            # always False here — confirm resumed_from_preemption handling.
            overwrite.append(False)
        self.req_states.num_computed_tokens[req_index] = (
            cached_reqs.num_computed_tokens[i])

    # Apply all block-table updates in one batched call.
    if req_indices:
        self.block_tables.append_block_ids(
            req_indices=req_indices,
            cu_num_new_blocks=cu_num_new_blocks,
            new_block_ids=new_block_ids,
            overwrite=overwrite,
        )
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch:
    """Build the flattened InputBatch (token ids, positions, block tables,
    slot mappings, logits indices) for this step's scheduled requests."""
    num_tokens = scheduler_output.total_num_scheduled_tokens
    assert num_tokens > 0
    num_reqs = len(scheduler_output.num_scheduled_tokens)

    # Decode first, then prefill.
    # batch_idx -> req_id
    req_ids = sorted(scheduler_output.num_scheduled_tokens,
                     key=scheduler_output.num_scheduled_tokens.get)
    num_scheduled_tokens = np.array(
        [scheduler_output.num_scheduled_tokens[i] for i in req_ids],
        dtype=np.int32)
    # batch_idx -> req_state_idx; the CPU view is kept for numpy indexing.
    idx_mapping_list = [
        self.req_states.req_id_to_index[req_id] for req_id in req_ids
    ]
    self.input_buffers.idx_mapping.np[:num_reqs] = idx_mapping_list
    idx_mapping_np = self.input_buffers.idx_mapping.np[:num_reqs]
    idx_mapping = self.input_buffers.idx_mapping.copy_to_gpu(num_reqs)

    # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
    block_tables = self.block_tables.gather_block_tables(idx_mapping)

    input_ids = self.input_buffers.input_ids
    positions = self.input_buffers.positions
    query_start_loc = self.input_buffers.query_start_loc
    seq_lens = self.input_buffers.seq_lens
    # Fill the CPU-side buffers (module-level numba-jitted helper).
    prepare_inputs(
        idx_mapping_np,
        self.req_states.token_ids,
        self.req_states.num_computed_tokens,
        num_scheduled_tokens,
        input_ids.np,
        positions.np,
        query_start_loc.np,
        seq_lens.np,
    )
    input_ids.copy_to_gpu(num_tokens)
    positions.copy_to_gpu(num_tokens)
    # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
    # tensors from CPU to GPU, because they may include paddings needed
    # for full CUDA graph mode.
    query_start_loc.copy_to_gpu()
    query_start_loc = query_start_loc.gpu[:num_reqs + 1]
    max_query_len = int(num_scheduled_tokens.max())
    seq_lens.copy_to_gpu()
    seq_lens_np = seq_lens.np[:num_reqs]
    max_seq_len = int(seq_lens_np.max())
    seq_lens = seq_lens.gpu[:num_reqs]

    # Slot mappings: [num_kv_cache_groups, num_tokens]
    slot_mappings = self.block_tables.compute_slot_mappings(
        query_start_loc, positions.gpu[:num_tokens])

    # Sample only the last scheduled token of each request.
    logits_indices = query_start_loc[1:] - 1

    attn_metadata: dict[str, Any] = {}
    # NOTE(review): this loop looks like an unfinished placeholder —
    # block_table / slot_mapping are computed but never used and
    # attn_metadata is never populated; max_query_len / max_seq_len above
    # are likewise unused. Confirm before relying on attention metadata.
    for i, kv_cache_spec in enumerate(
            self.kv_cache_config.kv_cache_groups):
        block_table = block_tables[i]
        slot_mapping = slot_mappings[i]
    return InputBatch(
        req_ids=req_ids,
        num_reqs=num_reqs,
        idx_mapping=idx_mapping,
        idx_mapping_np=idx_mapping_np,
        num_scheduled_tokens=num_scheduled_tokens,
        num_tokens=num_tokens,
        # NOTE(review): input_ids/positions are CpuGpuBuffer objects while
        # InputBatch annotates torch.Tensor — confirm callers expect .gpu.
        input_ids=input_ids,
        positions=positions,
        attn_metadata=attn_metadata,
        logits_indices=logits_indices,
    )
def sample(
    self,
    input_batch: InputBatch,
    logits: torch.Tensor,
) -> SamplerOutput:
    """Sample the next token for each request in the batch."""
    # Gather the batch's sampling parameters from the cached request states.
    metadata = self.req_states.make_sampling_metadata(
        input_batch.idx_mapping_np)
    return self.sampler(logits=logits, sampling_metadata=metadata)
def postprocess(
    self,
    input_batch: InputBatch,
    sampler_output: SamplerOutput,
) -> np.ndarray:
    """Count sampled tokens per request and bump the token counters.

    A request still mid-way through a chunked prefill contributes 0 sampled
    tokens this step; otherwise it contributes exactly 1.
    """
    batch_to_state = input_batch.idx_mapping_np
    computed_after_step = (
        self.req_states.num_computed_tokens[batch_to_state] +
        input_batch.num_scheduled_tokens)
    total_tokens = self.req_states.num_tokens[batch_to_state]
    # 1 where the request is NOT chunked-prefilling, else 0.
    num_sampled_tokens = (computed_after_step >= total_tokens).astype(np.int32)
    # Increment the number of tokens.
    self.req_states.num_tokens[batch_to_state] += num_sampled_tokens
    return num_sampled_tokens
def execute_model(
    self,
    scheduler_output: SchedulerOutput,
):
    """Run one model step for the given scheduler output.

    Returns None when no tokens are scheduled; otherwise returns the
    SamplerOutput for the batch.
    """
    self.update_states(scheduler_output)
    if scheduler_output.total_num_scheduled_tokens == 0:
        return None
    input_batch = self.prepare_inputs(scheduler_output)
    with set_forward_context(
        input_batch.attn_metadata,
        self.vllm_config,
    ):
        hidden_states = self.model(
            input_ids=input_batch.input_ids,
            positions=input_batch.positions,
        )
    # Compute logits only at each request's last scheduled token.
    sampling_hidden_states = hidden_states[input_batch.logits_indices]
    logits = self.model.compute_logits(sampling_hidden_states, None)
    sampler_output = self.sample(input_batch, logits)
    # Updates per-request token counts; the per-request sampled-token counts
    # are not consumed further here yet.
    num_sampled_tokens = self.postprocess(input_batch, sampler_output)
    # BUG FIX: the original ended with `return output`, an undefined name
    # (NameError at runtime). Return the sampler output instead.
    # TODO(review): wrap into the runner's output type once it is defined.
    return sampler_output

View File

@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """Samples next tokens from logits via temperature scaling, top-k/top-p
    filtering, and Gumbel-trick sampling; optionally gathers logprobs."""

    def __init__(
        self,
        logprobs_mode: LogprobsMode = LogprobsMode.PROCESSED_LOGPROBS,
    ):
        super().__init__()
        # Only processed logprobs are supported by this sampler.
        assert logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # Temperature scaling (FP32), then top-k/top-p filtering.
        scaled = apply_temperature(logits, sampling_metadata.temperature)
        filtered = apply_top_k_top_p(
            scaled,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )
        probs = torch.softmax(filtered, dim=-1, dtype=torch.float32)
        # Next-token ids (int64).
        sampled = gumbel_sample(
            probs,
            sampling_metadata.temperature,
            None,  # seeds
            None,  # pos
        )

        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is None:
            logprobs_tensors = None
        else:
            assert num_logprobs >= 0
            # Logprobs of the processed (post-filtering) logits.
            logprobs = torch.log_softmax(filtered, dim=-1,
                                         dtype=torch.float32)
            logprobs_tensors = self.gather_logprobs(logprobs, num_logprobs,
                                                    sampled)

        # These are GPU tensors. The sampled tokens are expanded to a 2D
        # [num_requests, 1] tensor: one generated token per request.
        return SamplerOutput(
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        sampled: torch.Tensor,
    ) -> LogprobsTensors:
        """Gather logprob/rank of the sampled token, plus the top-k entries
        when num_logprobs > 0."""
        sampled_col = sampled.unsqueeze(-1)
        sampled_logprobs = logprobs.gather(-1, sampled_col)
        # Rank = number of tokens with strictly greater logprob.
        sampled_ranks = (logprobs > sampled_logprobs).sum(-1)
        if num_logprobs == 0:
            # Only the sampled token's logprob is requested.
            return LogprobsTensors(sampled_col, sampled_logprobs,
                                   sampled_ranks)
        # Return (num_logprobs + 1) logprobs: sampled token first.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs,
                                                 dim=-1)
        return LogprobsTensors(
            torch.cat((sampled_col, topk_indices), dim=1),
            torch.cat((sampled_logprobs, topk_logprobs), dim=1),
            sampled_ranks,
        )
@triton.jit
def _apply_temp_kernel(
    logits,  # bf16[batch_size, vocab_size]
    logits_stride,
    output,  # fp32[batch_size, vocab_size]
    output_stride,
    temperature,  # per-request temperature
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    EPSILON: tl.constexpr,
):
    # One program per (request, vocab tile): casts a BLOCK_SIZE slice of the
    # request's logits row to fp32 and divides it by that request's
    # temperature, writing to `output`.
    batch_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    temp = tl.load(temperature + batch_idx)
    if temp < EPSILON:
        # Greedy sampling. Don't apply temperature.
        # NOTE(woosuk): In this case, we assume that its logprobs are not used.
        temp = tl.ones([1], dtype=tl.float32)
    offset = tl.arange(0, BLOCK_SIZE)
    block = block_idx * BLOCK_SIZE + offset
    # Load the logits.
    x = tl.load(logits + batch_idx * logits_stride + block,
                mask=block < vocab_size)
    x = x.to(tl.float32)
    x = x / temp
    tl.store(output + batch_idx * output_stride + block,
             x,
             mask=block < vocab_size)
def apply_temperature(
    logits: torch.Tensor,
    temperature: torch.Tensor,
) -> torch.Tensor:
    """Divide logits by per-request temperatures, producing an FP32 tensor.

    Rows with a near-zero (greedy) temperature are left unscaled by the
    kernel.
    """
    batch_size, vocab_size = logits.shape
    out = torch.empty_like(logits, dtype=torch.float32)
    BLOCK_SIZE = 8192
    grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
    _apply_temp_kernel[grid](
        logits,
        logits.stride(0),
        out,
        out.stride(0),
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        EPSILON=_SAMPLING_EPS,
    )
    return out
@triton.jit
def _apply_gumbel_kernel(
    probs_ptr,  # fp32[num_reqs, vocab_size], modified in place
    probs_stride,
    seeds_ptr,  # per-request RNG seed
    temp_ptr,  # per-request temperature
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    EPSILON: tl.constexpr,
):
    # Divides each probability by an exponential random variate (negated log
    # of a uniform draw) so that a later row argmax yields a random sample;
    # rows with near-zero (greedy) temperature are left untouched.
    req_idx = tl.program_id(0)
    seed = tl.load(seeds_ptr + req_idx)
    temp = tl.load(temp_ptr + req_idx)
    if temp < EPSILON:
        # Greedy sampling. Don't apply gumbel noise.
        return
    block_id = tl.program_id(1)
    r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    q = tl.rand(seed, r_offset)
    # NOTE(woosuk): This logic makes sure q is not 0.
    RMAX = 0.9999999403953552
    RMAX_LOG = -5.960464477539063e-08
    q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
    q = -1.0 * q
    p = tl.load(probs_ptr + req_idx * probs_stride + r_offset,
                mask=r_offset < vocab_size)
    p = p / q
    tl.store(probs_ptr + req_idx * probs_stride + r_offset,
             p,
             mask=r_offset < vocab_size)
def gumbel_sample(
    # fp32[num_reqs, vocab_size]
    probs: torch.Tensor,
    # fp32[num_reqs]
    temperature: torch.Tensor,
    # int64[num_reqs]
    seeds: Optional[torch.Tensor],
    # int64[num_reqs]
    pos: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sample one token id per request; greedy rows reduce to plain argmax.

    NOTE: `probs` is perturbed in place by the kernel.
    """
    num_reqs, vocab_size = probs.shape
    if seeds is None:
        # Global seed: draw one fresh random seed per request.
        assert pos is None
        info = torch.iinfo(torch.int64)
        gumbel_seeds = torch.randint(
            info.min,
            info.max,
            (num_reqs, ),
            dtype=torch.int64,
            device=probs.device,
        )
    else:
        # Per-request seed, offset by position.
        assert pos is not None
        gumbel_seeds = seeds + pos
    BLOCK_SIZE = 8192
    grid = (num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))
    _apply_gumbel_kernel[grid](
        probs,
        probs.stride(0),
        gumbel_seeds,
        temperature,
        vocab_size,
        BLOCK_SIZE,
        EPSILON=_SAMPLING_EPS,
    )
    # Sample the next token.
    return probs.argmax(dim=-1).view(-1)

View File

@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.metadata import SamplingMetadata
class RequestState:
    """CPU-side bookkeeping for the state of active requests.

    Each request occupies a fixed row index in a set of pre-allocated
    arrays; row indices are recycled through a free list as requests are
    added and removed. Token ids live on CPU (the buffer can be large),
    while sampling parameters are gathered and copied to `device` on
    demand when building sampling metadata.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.pin_memory = pin_memory

        # Row-index bookkeeping: req_id <-> row, plus a free list of rows.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))

        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self.token_ids = np.zeros(
            (self.max_num_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)

        # Last sampled token ids.
        self.last_token = torch.zeros(
            self.max_num_reqs,
            dtype=torch.int32,
            device=self.device,
        )

        # Sampling parameters.
        self.temperature = np.zeros(self.max_num_reqs, dtype=np.float32)
        self.top_p = np.zeros(self.max_num_reqs, dtype=np.float32)
        self.top_k = np.zeros(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs = np.full(self.max_num_reqs, -1, dtype=np.int32)

    @property
    def num_reqs(self) -> int:
        """Number of requests currently tracked."""
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Claim a free row for `req_id` and populate its state."""
        assert self.free_indices
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        num_prompt_tokens = len(prompt_token_ids)
        self.num_tokens[req_idx] = num_prompt_tokens
        self.token_ids[req_idx, :num_prompt_tokens] = prompt_token_ids
        self.num_computed_tokens[req_idx] = num_computed_tokens

        self.temperature[req_idx] = sampling_params.temperature
        self.top_p[req_idx] = sampling_params.top_p
        top_k = sampling_params.top_k
        if not (0 < top_k < self.vocab_size):
            # top_k == vocab_size disables top-k filtering.
            top_k = self.vocab_size
        self.top_k[req_idx] = top_k
        num_logprobs = sampling_params.num_logprobs
        self.num_logprobs[req_idx] = (-1 if num_logprobs is None
                                      else num_logprobs)

    def append_token_ids(
        self,
        req_idx: int,
        token_ids: Union[list[int], np.ndarray],
    ) -> None:
        """Append newly generated token ids to the request's token row."""
        start = self.num_tokens[req_idx]
        end = start + len(token_ids)
        self.token_ids[req_idx, start:end] = token_ids
        self.num_tokens[req_idx] = end

    def remove_request(self, req_id: str) -> None:
        """Release the row owned by `req_id`; no-op if unknown."""
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def make_sampling_metadata(
        self,
        idx_mapping: np.ndarray,
    ) -> SamplingMetadata:
        """Gather sampling params for the rows in `idx_mapping`, copy them
        to `device`, and skip tensors whose values are all defaults."""
        temperature = self._copy_np_to_gpu(self.temperature[idx_mapping])

        top_p = self.top_p[idx_mapping]
        top_p = None if np.all(top_p == 1.0) else self._copy_np_to_gpu(top_p)

        top_k = self.top_k[idx_mapping]
        top_k = (None if np.all(top_k == self.vocab_size) else
                 self._copy_np_to_gpu(top_k))

        max_num_logprobs = np.max(self.num_logprobs[idx_mapping])
        if max_num_logprobs == -1:
            # No request in the batch asked for logprobs.
            max_num_logprobs = None
        return SamplingMetadata(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_num_logprobs=max_num_logprobs,
        )

    def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
        """Async-copy a numpy array to `device`, staging through pinned
        memory when enabled."""
        staging = torch.from_numpy(src)
        if self.pin_memory:
            staging = staging.pin_memory()
        return staging.to(device=self.device, non_blocking=True)

File diff suppressed because it is too large Load Diff

View File

@ -1,291 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
# Upper bound on the number of speculative draft tokens per request; used
# to size the Triton block in _prepare_spec_decode_kernel
# (BLOCK_SIZE = next_power_of_2(_MAX_SPEC_LEN + 1), the +1 for the bonus token).
_MAX_SPEC_LEN = 32
@dataclass
class RequestData:
    """Per-request metadata held as plain Python objects, as opposed to
    the array-based per-row state kept in RequestState."""
    # Multimodal inputs and the placeholder ranges they occupy in the prompt.
    mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
    # Exactly one of these is expected per request kind —
    # presumably sampling for generation and pooling for pooling models;
    # confirm against the scheduler.
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    # Hashes identifying the multimodal items (e.g. for caching) — confirm.
    mm_hashes: list[str]
    # M-RoPE (only for Qwen2/2.5-VL)
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None
    lora_request: Optional[LoRARequest] = None
class RequestState:
    """Array-based state for all requests cached on this worker.

    Each request owns a fixed row index in a set of pre-allocated
    arrays/tensors, recycled through a free list, so per-step batches can
    be assembled by fancy-indexing with an index mapping instead of
    per-request Python loops. Token ids are kept on CPU because the
    (max_num_cached_reqs, max_model_len) buffer can be very large.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        max_num_cached_reqs: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
    ):
        # NOTE(review): `logitsprocs` and `is_pooling_model` are accepted
        # for interface compatibility but are not used in this class yet.
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_cached_reqs = max_num_cached_reqs
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        self.is_spec_decode = is_spec_decode
        self.pooling_params = None
        # req_idx -> number of prompt logprobs (not populated in this class).
        self.num_prompt_logprobs: dict[int, int] = {}

        # Row-index bookkeeping: req_id <-> row, plus a free list of rows.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_cached_reqs))

        # Request states.
        self.req_data: dict[int, RequestData] = {}

        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self.token_ids = np.zeros(
            (self.max_num_cached_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs,
                                          dtype=np.int32)
        self.num_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
        self.num_computed_tokens = np.zeros(self.max_num_cached_reqs,
                                            dtype=np.int32)

        # Last sampled token ids.
        self.last_sampled_token = torch.zeros(
            self.max_num_cached_reqs,
            dtype=torch.int32,
            device=self.device,
        )

        # Sampling parameters (CPU-side; gathered and copied per step).
        self.temperature = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
        self.top_p = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
        self.top_k = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
        self.frequency_penalties = np.zeros(self.max_num_cached_reqs,
                                            dtype=np.float32)
        self.presence_penalties = np.zeros(self.max_num_cached_reqs,
                                           dtype=np.float32)
        self.repetition_penalties = np.zeros(self.max_num_cached_reqs,
                                             dtype=np.float32)
        # req_idx -> seeded generator (only for RANDOM_SEED requests).
        self.generators: dict[int, torch.Generator] = {}

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Claim a free row for `req_id` and populate its state."""
        assert len(self.free_indices) > 0, "No free space in GPU worker states"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        prompt_len = len(prompt_token_ids)
        self.num_prompt_tokens[req_idx] = prompt_len
        self.num_tokens[req_idx] = prompt_len
        self.token_ids[req_idx, :prompt_len] = prompt_token_ids
        self.num_computed_tokens[req_idx] = num_computed_tokens

        self.temperature[req_idx] = sampling_params.temperature
        self.top_p[req_idx] = sampling_params.top_p
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            # top_k == vocab_size disables top-k filtering.
            top_k = self.vocab_size
        self.top_k[req_idx] = top_k
        self.frequency_penalties[req_idx] = sampling_params.frequency_penalty
        self.presence_penalties[req_idx] = sampling_params.presence_penalty
        self.repetition_penalties[req_idx] = sampling_params.repetition_penalty
        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
            # Per-request seeded RNG for reproducible sampling.
            generator = torch.Generator(device=self.device)
            generator.manual_seed(sampling_params.seed)
            self.generators[req_idx] = generator

    @property
    def num_cached_reqs(self) -> int:
        """Number of requests currently cached."""
        return len(self.req_id_to_index)

    def append_token_ids(
        self,
        req_idx: int,
        token_ids: Union[list[int], np.ndarray],
    ) -> None:
        """Append newly generated token ids to the request's token row."""
        # Fixed: these are plain numpy arrays; the previous
        # `self.num_tokens.np[...]` / `self.token_ids.np[...]` accesses
        # raised AttributeError on every call.
        start_idx = self.num_tokens[req_idx]
        end_idx = start_idx + len(token_ids)
        self.token_ids[req_idx, start_idx:end_idx] = token_ids
        self.num_tokens[req_idx] = end_idx

    def remove_request(self, req_id: str) -> None:
        """Release the row owned by `req_id`; no-op if unknown."""
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        # Drop per-request Python state so a recycled row does not
        # inherit a stale seeded generator or request data.
        self.req_data.pop(req_idx, None)
        self.generators.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def make_sampling_metadata(
        self,
        idx_mapping: np.ndarray,
    ) -> SamplingMetadata:
        """Gather sampling params for the rows in `idx_mapping`, copy them
        to `device`, and skip tensors whose values are all defaults."""
        temperature = self.temperature[idx_mapping]
        all_greedy = np.all(temperature == 0.0)
        all_random = np.all(temperature != 0.0)
        temperature = self._copy_np_to_gpu(temperature)

        top_p = self.top_p[idx_mapping]
        no_top_p = np.all(top_p == 1.0)
        top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None
        top_k = self.top_k[idx_mapping]
        no_top_k = np.all(top_k == self.vocab_size)
        top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None

        frequency_penalties = self.frequency_penalties[idx_mapping]
        presence_penalties = self.presence_penalties[idx_mapping]
        repetition_penalties = self.repetition_penalties[idx_mapping]
        no_penalties = (np.all(frequency_penalties == 0.0)
                        and np.all(presence_penalties == 0.0)
                        and np.all(repetition_penalties == 1.0))
        if no_penalties:
            frequency_penalties = None
            presence_penalties = None
            repetition_penalties = None
        else:
            frequency_penalties = self._copy_np_to_gpu(frequency_penalties)
            presence_penalties = self._copy_np_to_gpu(presence_penalties)
            repetition_penalties = self._copy_np_to_gpu(repetition_penalties)

        if self.generators:
            # Only the seeded requests that are in this batch.
            generators = {
                req_idx: self.generators[req_idx]
                for req_idx in idx_mapping if req_idx in self.generators
            }
        else:
            generators = {}
        return SamplingMetadata(
            temperature=temperature,
            all_greedy=all_greedy,
            all_random=all_random,
            top_p=top_p,
            top_k=top_k,
            frequency_penalties=frequency_penalties,
            presence_penalties=presence_penalties,
            repetition_penalties=repetition_penalties,
            no_penalties=no_penalties,
            generators=generators,
            token_ids=None,
            num_tokens=None,
            num_prompt_tokens=None,
            max_num_logprobs=None,
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=None,
        )

    def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
        """Async-copy a numpy array to `device`.

        Stages through pinned memory when enabled (consistent with the
        other RequestState implementation) so the non_blocking copy is
        actually asynchronous on CUDA.
        """
        cpu_tensor = torch.from_numpy(src)
        if self.pin_memory:
            cpu_tensor = cpu_tensor.pin_memory()
        return cpu_tensor.to(device=self.device, non_blocking=True)

    def make_spec_decode_metadata(
        self,
        query_start_loc: torch.Tensor,
        cu_num_draft_tokens: torch.Tensor,
        cu_num_draft_tokens_np: np.ndarray,
        input_ids: torch.Tensor,
    ) -> SpecDecodeMetadata:
        """Build the index tensors needed to score speculative drafts.

        Args:
            query_start_loc: [B + 1] cumulative query lengths (device).
            cu_num_draft_tokens: [B] inclusive cumulative draft-token
                counts (device).
            cu_num_draft_tokens_np: the same values on CPU.
            input_ids: flattened input token ids for this step.
        """
        batch_size = query_start_loc.shape[0] - 1
        # Total drafts = last entry of the inclusive cumulative sum.
        total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1]
        logits_indices = torch.empty(total_num_draft_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=self.device)
        target_logits_indices = torch.empty(total_num_draft_tokens,
                                            dtype=torch.int32,
                                            device=self.device)
        bonus_logits_indices = torch.empty(batch_size,
                                           dtype=torch.int32,
                                           device=self.device)
        _prepare_spec_decode_kernel[(batch_size, )](
            query_start_loc,
            cu_num_draft_tokens,
            logits_indices,
            target_logits_indices,
            bonus_logits_indices,
            BLOCK_SIZE=triton.next_power_of_2(_MAX_SPEC_LEN + 1),
        )
        # The draft token fed at sampled position i is the next input id,
        # hence the +1 shift through logits_indices.
        draft_token_ids = input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]
        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            # NOTE(review): this passes the *cumulative* counts; if
            # SpecDecodeMetadata expects per-request draft counts, this
            # should be the per-request differences — confirm against
            # SpecDecodeMetadata's definition.
            num_draft_tokens=cu_num_draft_tokens_np.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
# Builds, per batch element, the index lists used to slice logits during
# speculative decoding. For request b with d drafted tokens:
#   - logits_indices receives the last (d + 1) query positions of b,
#   - target_logits_indices points at the d entries of logits_indices
#     that correspond to draft tokens,
#   - bonus_logits_indices points at the final (bonus) entry.
# Launch grid: (B,); BLOCK_SIZE must be >= max draft count + 1.
@triton.jit
def _prepare_spec_decode_kernel(
    query_start_loc,  # [B + 1]
    cu_num_draft_tokens,  # [B]
    logits_indices,  # [N + B]
    target_logits_indices,  # [N]
    bonus_logits_indices,  # [B]
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    # Draft-token range for this request; cu_num_draft_tokens is an
    # inclusive cumulative sum over the batch.
    if batch_idx == 0:
        draft_start_idx = 0
    else:
        draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1)
    draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx)
    draft_len = draft_end_idx - draft_start_idx
    # One extra "bonus" token is sampled after the drafts.
    sample_len = draft_len + 1
    q_end_idx = tl.load(query_start_loc + batch_idx + 1)
    # This request's slice within logits_indices: each request owns
    # (its draft count + 1) consecutive entries.
    sample_start_idx = draft_start_idx + batch_idx
    sample_end_idx = sample_start_idx + sample_len
    offset = tl.arange(0, BLOCK_SIZE)
    # The last `sample_len` query positions of this request.
    tl.store(logits_indices + sample_start_idx + offset,
             q_end_idx - sample_len + offset,
             mask=offset < sample_len)
    # Positions (within logits_indices) holding the draft tokens' logits.
    tl.store(target_logits_indices + draft_start_idx + offset,
             sample_start_idx + offset,
             mask=offset < draft_len)
    # Position of the bonus token's logits.
    tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)