mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-14 13:47:07 +08:00
wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
9a6fcca030
commit
8b3c13c485
@ -30,7 +30,6 @@ class NewRequestData:
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
block_ids: tuple[list[int], ...]
|
||||
num_computed_tokens: int
|
||||
lora_request: Optional[LoRARequest]
|
||||
|
||||
@ -46,7 +45,6 @@ class NewRequestData:
|
||||
mm_features=request.mm_features,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
block_ids=block_ids,
|
||||
num_computed_tokens=request.num_computed_tokens,
|
||||
lora_request=request.lora_request,
|
||||
)
|
||||
@ -57,7 +55,6 @@ class NewRequestData:
|
||||
f"prompt_token_ids={self.prompt_token_ids},"
|
||||
f"mm_features={self.mm_features},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request}"
|
||||
")")
|
||||
@ -69,7 +66,6 @@ class NewRequestData:
|
||||
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
|
||||
f"mm_features={self.mm_features},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request}"
|
||||
")")
|
||||
@ -77,52 +73,17 @@ class NewRequestData:
|
||||
|
||||
@bc_linter_include
|
||||
@dataclass
|
||||
class CachedRequestData:
|
||||
class SchedulerOutput:
|
||||
|
||||
req_ids: list[str]
|
||||
# If resumed_from_preemption is False, new_block_ids will be appended to
|
||||
# the request's block IDs. If True, new_block_ids will be used as the
|
||||
# request's block IDs instead of appending to the existing block IDs.
|
||||
resumed_from_preemption: list[bool]
|
||||
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
|
||||
# When PP is not used, new_token_ids will be empty.
|
||||
new_token_ids: list[list[int]]
|
||||
new_block_ids: list[Optional[tuple[list[int], ...]]]
|
||||
num_computed_tokens: list[int]
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_ids)
|
||||
|
||||
@classmethod
|
||||
def make_empty(cls) -> CachedRequestData:
|
||||
return cls(
|
||||
req_ids=[],
|
||||
resumed_from_preemption=[],
|
||||
new_token_ids=[],
|
||||
new_block_ids=[],
|
||||
num_computed_tokens=[],
|
||||
)
|
||||
|
||||
|
||||
@bc_linter_include
|
||||
@dataclass
|
||||
class SchedulerOutput:
|
||||
cu_new_block_ids: tuple[np.ndarray, ...]
|
||||
|
||||
# list of the requests that are scheduled for the first time.
|
||||
# We cache the request's data in each worker process, so that we don't
|
||||
# need to re-send it every scheduling step.
|
||||
scheduled_new_reqs: list[NewRequestData]
|
||||
# list of the requests that have been scheduled before.
|
||||
# Since the request's data is already cached in the worker processes,
|
||||
# we only send the diff to minimize the communication cost.
|
||||
scheduled_cached_reqs: CachedRequestData
|
||||
|
||||
# req_id -> num_scheduled_tokens
|
||||
# Number of tokens scheduled for each request.
|
||||
num_scheduled_tokens: dict[str, int]
|
||||
# Total number of tokens scheduled for all requests.
|
||||
# Equal to sum(num_scheduled_tokens.values())
|
||||
total_num_scheduled_tokens: int
|
||||
# req_id -> spec_token_ids
|
||||
# If a request does not have any spec decode tokens, it will not be
|
||||
@ -136,13 +97,11 @@ class SchedulerOutput:
|
||||
# This can be used for cascade attention.
|
||||
num_common_prefix_blocks: list[int]
|
||||
|
||||
preempted_req_ids: set[str]
|
||||
# Request IDs that are finished in between the previous and the current
|
||||
# steps. This is used to notify the workers about the finished requests
|
||||
# so that they can free the cached states for those requests.
|
||||
finished_req_ids: set[str]
|
||||
# list of mm_hash strings associated with the encoder outputs to be
|
||||
# freed from the encoder cache.
|
||||
free_encoder_mm_hashes: list[str]
|
||||
|
||||
# Dict of request ids to their index within the batch
|
||||
# for filling the next token bitmask
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
|
||||
@ -273,3 +274,10 @@ class KVCacheConfig:
|
||||
see `_get_kv_cache_config_uniform_page_size` for more details.
|
||||
"""
|
||||
kv_cache_groups: list[KVCacheGroupSpec]
|
||||
|
||||
@cached_property
|
||||
def block_sizes(self) -> list[int]:
|
||||
return [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in self.kv_cache_groups
|
||||
]
|
||||
|
||||
@ -6,39 +6,14 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
|
||||
temperature: torch.Tensor
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
|
||||
generators: dict[int, torch.Generator]
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: Optional[int]
|
||||
|
||||
no_penalties: bool
|
||||
frequency_penalties: torch.Tensor
|
||||
presence_penalties: torch.Tensor
|
||||
repetition_penalties: torch.Tensor
|
||||
|
||||
token_ids: Optional[torch.Tensor]
|
||||
num_tokens: Optional[torch.Tensor]
|
||||
num_prompt_tokens: Optional[torch.Tensor]
|
||||
|
||||
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
|
||||
# vocab size).
|
||||
allowed_token_ids_mask: Optional[torch.Tensor]
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
bad_words_token_ids: dict[int, list[list[int]]]
|
||||
|
||||
# Loaded logits processors
|
||||
logitsprocs: LogitsProcessors
|
||||
|
||||
0
vllm/v1/worker/gpu/__init__.py
Normal file
0
vllm/v1/worker/gpu/__init__.py
Normal file
80
vllm/v1/worker/gpu/attn_utils.py
Normal file
80
vllm/v1/worker/gpu/attn_utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
SlidingWindowSpec)
|
||||
|
||||
|
||||
def get_kv_cache_spec(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: torch.dtype,
|
||||
) -> dict[str, KVCacheSpec]:
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
assert attn_module.attn_type == AttentionType.DECODER
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
return kv_cache_spec
|
||||
|
||||
|
||||
def init_attn_backend(vllm_config: VllmConfig):
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
|
||||
|
||||
def _allocate_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()
|
||||
), "Some layers are not correctly initialized"
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
|
||||
def _reshape_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def init_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors)
|
||||
@ -18,7 +18,6 @@ class BlockTables:
|
||||
self,
|
||||
block_sizes: list[int],
|
||||
max_num_reqs: int,
|
||||
max_num_cached_reqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
@ -26,21 +25,17 @@ class BlockTables:
|
||||
):
|
||||
self.block_sizes = block_sizes
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_cached_reqs = max_num_cached_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.num_kv_cache_groups = len(self.block_sizes)
|
||||
# [num_kv_cache_groups, max_num_reqs, max_num_blocks]
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.block_tables: list[torch.Tensor] = []
|
||||
# [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
|
||||
self.block_table_buffers: list[torch.Tensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size)
|
||||
|
||||
block_table = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
max_num_blocks,
|
||||
@ -48,17 +43,16 @@ class BlockTables:
|
||||
device=self.device,
|
||||
)
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
block_table_buffer = torch.zeros(
|
||||
self.max_num_cached_reqs,
|
||||
max_num_blocks,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.block_table_buffers.append(block_table_buffer)
|
||||
|
||||
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
|
||||
self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
|
||||
|
||||
# Block tables used for model's forward pass.
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.input_block_tables: list[torch.Tensor] = [
|
||||
torch.zeros_like(block_table) for block_table in self.block_tables
|
||||
]
|
||||
self.input_block_table_ptrs = self._make_ptr_tensor(
|
||||
self.input_block_tables)
|
||||
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.stride(0) for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
@ -67,7 +61,7 @@ class BlockTables:
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
|
||||
self.max_num_cached_reqs,
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
|
||||
@ -107,7 +101,6 @@ class BlockTables:
|
||||
# [num_reqs]
|
||||
overwrite: list[bool],
|
||||
) -> None:
|
||||
# TODO(woosuk): Optimize & simplify this.
|
||||
num_reqs = len(req_indices)
|
||||
self.req_indices.np[:num_reqs] = req_indices
|
||||
self.overwrite.np[:num_reqs] = overwrite
|
||||
@ -115,19 +108,21 @@ class BlockTables:
|
||||
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
|
||||
|
||||
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
|
||||
# no clear upper bound on the number of new blocks.
|
||||
new_block_ids_cpu = torch.empty(
|
||||
# no clear upper bound to the number of new blocks in a single step.
|
||||
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
|
||||
# guarantee that the buffer is not freed before the copy is completed.
|
||||
self.new_block_ids_cpu = torch.empty(
|
||||
self.num_kv_cache_groups,
|
||||
max(len(x) for x in new_block_ids),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
new_block_ids_np = new_block_ids_cpu.numpy()
|
||||
new_block_ids_np = self.new_block_ids_cpu.numpy()
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
|
||||
new_block_ids_gpu = new_block_ids_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
self.req_indices.copy_to_gpu(num_reqs),
|
||||
@ -137,33 +132,34 @@ class BlockTables:
|
||||
new_block_ids_gpu.stride(0),
|
||||
self.overwrite.copy_to_gpu(num_reqs),
|
||||
self.block_table_strides,
|
||||
self.buffer_ptrs,
|
||||
self.block_table_ptrs,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
def compute_block_tables(
|
||||
def gather_block_tables(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
batch_size = idx_mapping.shape[0]
|
||||
_compute_block_tables_kernel[(self.num_kv_cache_groups, batch_size)](
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
idx_mapping,
|
||||
self.buffer_ptrs,
|
||||
self.block_table_ptrs,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return tuple(b[:batch_size] for b in self.block_tables)
|
||||
return tuple(block_table[:num_reqs]
|
||||
for block_table in self.input_block_tables)
|
||||
|
||||
def compute_slot_mappings(
|
||||
self,
|
||||
query_start_loc: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
) -> torch.Tensor:
|
||||
num_reqs = query_start_loc.shape[0] - 1
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
@ -172,7 +168,7 @@ class BlockTables:
|
||||
self.max_num_batched_tokens,
|
||||
query_start_loc,
|
||||
positions,
|
||||
self.block_table_ptrs,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
@ -180,7 +176,7 @@ class BlockTables:
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return tuple(x[:num_tokens] for x in self.slot_mappings)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -194,8 +190,8 @@ def _append_block_ids_kernel(
|
||||
overwrite, # [num_reqs]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
# Outputs
|
||||
block_table_buffer_ptrs, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
# Constants
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
@ -220,10 +216,9 @@ def _append_block_ids_kernel(
|
||||
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
|
||||
|
||||
# Destination
|
||||
block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
|
||||
tl.int32)
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
|
||||
row_ptr = block_table_ptr + req_idx * block_table_stride
|
||||
|
||||
group_new_block_ids_ptr = (new_block_ids_ptr +
|
||||
group_id * new_block_ids_stride)
|
||||
@ -231,18 +226,18 @@ def _append_block_ids_kernel(
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
|
||||
mask=offset < num_new_blocks)
|
||||
tl.store(buffer_row_ptr + dst_start_idx + offset,
|
||||
tl.store(row_ptr + dst_start_idx + offset,
|
||||
block_ids,
|
||||
mask=offset < num_new_blocks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_block_tables_kernel(
|
||||
def _gather_block_tables_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_block_table_ptrs, # [num_kv_cache_groups]
|
||||
dst_block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
26
vllm/v1/worker/gpu/init_utils.py
Normal file
26
vllm/v1/worker/gpu/init_utils.py
Normal file
@ -0,0 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.utils import DeviceMemoryProfiler, GiB_bytes
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def load_model(vllm_config: VllmConfig):
|
||||
time_before_load = time.perf_counter()
|
||||
|
||||
with DeviceMemoryProfiler() as m:
|
||||
model_loader = get_model_loader(vllm_config.load_config)
|
||||
logger.info("Loading model from scratch...")
|
||||
model = model_loader.load_model(vllm_config=vllm_config,
|
||||
model_config=vllm_config.model_config)
|
||||
|
||||
time_after_load = time.perf_counter()
|
||||
logger.info("Model loading took %.4f GiB and %.6f seconds",
|
||||
m.consumed_memory / GiB_bytes,
|
||||
time_after_load - time_before_load)
|
||||
return model
|
||||
@ -1,14 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch
|
||||
from numba import types
|
||||
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class InputBuffers:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
|
||||
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1,
|
||||
dtype=torch.int32)
|
||||
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*args,
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory,
|
||||
device=self.device)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -16,9 +44,7 @@ class InputBatch:
|
||||
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
|
||||
# req_id -> batch_idx
|
||||
req_id_to_batch_idx: dict[str, int]
|
||||
num_reqs: int
|
||||
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
@ -26,14 +52,18 @@ class InputBatch:
|
||||
|
||||
# batch_idx -> num_scheduled_tokens
|
||||
num_scheduled_tokens: np.ndarray
|
||||
total_num_tokens: int
|
||||
max_query_len: int
|
||||
num_reqs: int
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
|
||||
# [max_num_batched_tokens]
|
||||
input_ids: torch.Tensor
|
||||
# [max_num_batched_tokens]
|
||||
positions: torch.Tensor
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
spec_decode_common_attn_metadata: Optional[Any]
|
||||
spec_decode_metadata: Optional[SpecDecodeMetadata]
|
||||
|
||||
# [num_reqs]
|
||||
logits_indices: torch.Tensor
|
||||
|
||||
|
||||
@ -47,9 +77,9 @@ class InputBatch:
|
||||
types.int32[:], # num_computed_tokens
|
||||
types.int32[:], # num_scheduled_tokens
|
||||
types.int32[:], # input_ids
|
||||
types.int64[:], # positions
|
||||
types.int32[:], # query_start_loc
|
||||
types.int32[:], # seq_lens
|
||||
types.int64[:], # positions
|
||||
)
|
||||
],
|
||||
nopython=True,
|
||||
@ -61,9 +91,9 @@ def prepare_inputs(
|
||||
num_computed_tokens: np.ndarray, # [N]
|
||||
num_scheduled_tokens: np.ndarray, # [B]
|
||||
input_ids: np.ndarray, # [num_input_tokens]
|
||||
positions: np.ndarray, # [num_input_tokens]
|
||||
query_start_loc: np.ndarray, # [B + 1]
|
||||
seq_lens: np.ndarray, # [B]
|
||||
positions: np.ndarray, # [num_input_tokens]
|
||||
) -> None:
|
||||
num_reqs = num_scheduled_tokens.shape[0]
|
||||
query_start_loc[0] = 0
|
||||
290
vllm/v1/worker/gpu/model_runner.py
Normal file
290
vllm/v1/worker/gpu/model_runner.py
Normal file
@ -0,0 +1,290 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.sample.sampler import SamplerOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import get_kv_cache_spec, init_attn_backend
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.init_utils import load_model
|
||||
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
|
||||
prepare_inputs)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.kv_cache_dtype = self.dtype
|
||||
else:
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = load_model(self.vllm_config)
|
||||
|
||||
def get_kv_cache_spec(self):
|
||||
return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
block_sizes = kv_cache_config.block_sizes
|
||||
|
||||
self.block_tables = BlockTables(
|
||||
block_sizes=block_sizes,
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.attn_metadata_builders = init_attn_backend(self.vllm_config)
|
||||
|
||||
def update_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
for req_id in scheduler_output.preempted_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
|
||||
# TODO(woosuk): Change SchedulerOutput.
|
||||
req_indices: list[int] = []
|
||||
cu_num_new_blocks = tuple(
|
||||
[0] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
new_block_ids = tuple(
|
||||
[] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
overwrite: list[bool] = []
|
||||
|
||||
# Add new requests to the cached states.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
self.req_states.add_request(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
sampling_params=new_req_data.sampling_params,
|
||||
)
|
||||
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
req_indices.append(req_index)
|
||||
for i, block_ids in enumerate(new_req_data.block_ids):
|
||||
x = cu_num_new_blocks[i][-1]
|
||||
cu_num_new_blocks[i].append(x + len(block_ids))
|
||||
new_block_ids[i].extend(block_ids)
|
||||
overwrite.append(True)
|
||||
|
||||
# Update the states of the running/resumed requests.
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
|
||||
req_new_block_ids = cached_reqs.new_block_ids[i]
|
||||
if req_new_block_ids is not None:
|
||||
req_indices.append(req_index)
|
||||
for group_id, block_ids in enumerate(req_new_block_ids):
|
||||
x = cu_num_new_blocks[group_id][-1]
|
||||
cu_num_new_blocks[group_id].append(x + len(block_ids))
|
||||
new_block_ids[group_id].extend(block_ids)
|
||||
overwrite.append(False)
|
||||
|
||||
self.req_states.num_computed_tokens[req_index] = (
|
||||
cached_reqs.num_computed_tokens[i])
|
||||
|
||||
if req_indices:
|
||||
self.block_tables.append_block_ids(
|
||||
req_indices=req_indices,
|
||||
cu_num_new_blocks=cu_num_new_blocks,
|
||||
new_block_ids=new_block_ids,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch:
|
||||
num_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert num_tokens > 0
|
||||
num_reqs = len(scheduler_output.num_scheduled_tokens)
|
||||
|
||||
# Decode first, then prefill.
|
||||
# batch_idx -> req_id
|
||||
req_ids = sorted(scheduler_output.num_scheduled_tokens,
|
||||
key=scheduler_output.num_scheduled_tokens.get)
|
||||
num_scheduled_tokens = np.array(
|
||||
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
|
||||
dtype=np.int32)
|
||||
|
||||
idx_mapping_list = [
|
||||
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
||||
]
|
||||
self.input_buffers.idx_mapping.np[:num_reqs] = idx_mapping_list
|
||||
idx_mapping_np = self.input_buffers.idx_mapping.np[:num_reqs]
|
||||
idx_mapping = self.input_buffers.idx_mapping.copy_to_gpu(num_reqs)
|
||||
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
input_ids = self.input_buffers.input_ids
|
||||
positions = self.input_buffers.positions
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
seq_lens = self.input_buffers.seq_lens
|
||||
|
||||
prepare_inputs(
|
||||
idx_mapping_np,
|
||||
self.req_states.token_ids,
|
||||
self.req_states.num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
input_ids.np,
|
||||
positions.np,
|
||||
query_start_loc.np,
|
||||
seq_lens.np,
|
||||
)
|
||||
input_ids.copy_to_gpu(num_tokens)
|
||||
positions.copy_to_gpu(num_tokens)
|
||||
|
||||
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
|
||||
# tensors from CPU to GPU, because they may include paddings needed
|
||||
# for full CUDA graph mode.
|
||||
query_start_loc.copy_to_gpu()
|
||||
query_start_loc = query_start_loc.gpu[:num_reqs + 1]
|
||||
max_query_len = int(num_scheduled_tokens.max())
|
||||
|
||||
seq_lens.copy_to_gpu()
|
||||
seq_lens_np = seq_lens.np[:num_reqs]
|
||||
max_seq_len = int(seq_lens_np.max())
|
||||
seq_lens = seq_lens.gpu[:num_reqs]
|
||||
|
||||
# Slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc, positions.gpu[:num_tokens])
|
||||
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
for i, kv_cache_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
logits: torch.Tensor,
|
||||
) -> SamplerOutput:
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
input_batch.idx_mapping_np)
|
||||
sampler_output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
return sampler_output
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
sampler_output: SamplerOutput,
|
||||
) -> np.ndarray:
|
||||
# Get the number of sampled tokens.
|
||||
# Handle requests that are chunked-prefilling.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
num_computed_tokens = self.req_states.num_computed_tokens[
|
||||
idx_mapping_np]
|
||||
post_num_computed_tokens = (num_computed_tokens +
|
||||
input_batch.num_scheduled_tokens)
|
||||
num_tokens = self.req_states.num_tokens[idx_mapping_np]
|
||||
|
||||
is_chunked_prefilling = post_num_computed_tokens < num_tokens
|
||||
# 0 if chunked-prefilling, 1 if not.
|
||||
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
|
||||
# Increment the number of tokens.
|
||||
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
|
||||
return num_sampled_tokens
|
||||
|
||||
def execute_model(
    self,
    scheduler_output: SchedulerOutput,
):
    """Run one forward + sampling step for the scheduled requests.

    Returns the SamplerOutput for this step, or None if nothing was
    scheduled.
    """
    self.update_states(scheduler_output)
    if scheduler_output.total_num_scheduled_tokens == 0:
        # Nothing to do for this step.
        return None

    input_batch = self.prepare_inputs(scheduler_output)

    with set_forward_context(
            input_batch.attn_metadata,
            self.vllm_config,
    ):
        hidden_states = self.model(
            input_ids=input_batch.input_ids,
            positions=input_batch.positions,
        )

    # Compute logits only at the positions where sampling happens.
    sampling_hidden_states = hidden_states[input_batch.logits_indices]
    logits = self.model.compute_logits(sampling_hidden_states, None)
    sampler_output = self.sample(input_batch, logits)

    # Side effect: advances per-request token counts for requests that
    # actually sampled a token this step.
    num_sampled_tokens = self.postprocess(input_batch, sampler_output)
    # BUG FIX: the original ended with `return output`, but `output` was
    # never assigned (NameError on every non-empty step). Return the
    # sampler output instead.
    return sampler_output
|
||||
238
vllm/v1/worker/gpu/sampler.py
Normal file
238
vllm/v1/worker/gpu/sampler.py
Normal file
@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.config import LogprobsMode
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
    """Samples next tokens from logits.

    Applies temperature and top-k/top-p filtering, samples via the
    Gumbel-max trick, and optionally gathers logprobs for the sampled
    and top-k tokens.
    """

    def __init__(
        self,
        logprobs_mode: LogprobsMode = LogprobsMode.PROCESSED_LOGPROBS,
    ):
        super().__init__()
        # Only processed logprobs (after temperature/top-k/top-p) are
        # supported by this sampler.
        assert logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Return sampled token ids (and optional logprobs) for `logits`."""
        # Temperature scaling is performed in FP32.
        scaled = apply_temperature(logits, sampling_metadata.temperature)
        # Top-k and/or top-p filtering.
        scaled = apply_top_k_top_p(
            scaled,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        probs = torch.softmax(scaled, dim=-1, dtype=torch.float32)
        # Gumbel-max sampling; greedy requests are handled inside
        # gumbel_sample via their near-zero temperature. Result is int64.
        next_tokens = gumbel_sample(
            probs,
            sampling_metadata.temperature,
            None,  # seeds
            None,  # pos
        )

        max_logprobs = sampling_metadata.max_num_logprobs
        if max_logprobs is None:
            logprobs_tensors = None
        else:
            assert max_logprobs >= 0
            logprobs = torch.log_softmax(scaled, dim=-1, dtype=torch.float32)
            # Gather the logprobs of the top-k and the sampled token.
            logprobs_tensors = self.gather_logprobs(
                logprobs,
                max_logprobs,
                next_tokens,
            )

        # These are GPU tensors. The sampled tokens are expanded to a 2D
        # [num_requests, 1] tensor: one generated token per request.
        return SamplerOutput(
            sampled_token_ids=next_tokens.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        sampled: torch.Tensor,
    ) -> LogprobsTensors:
        """Collect logprobs and ranks for the sampled (+ top-k) tokens."""
        token_col = sampled.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_col)
        # Rank = number of tokens with a strictly higher logprob.
        token_ranks = (logprobs > token_logprobs).sum(-1)
        if num_logprobs == 0:
            # Only the sampled token's logprob is requested.
            return LogprobsTensors(token_col, token_logprobs, token_ranks)
        # Return (num_logprobs + 1) entries: sampled token first,
        # followed by the top-k tokens.
        topk_vals, topk_ids = torch.topk(logprobs, num_logprobs, dim=-1)
        return LogprobsTensors(
            torch.cat((token_col, topk_ids), dim=1),
            torch.cat((token_logprobs, topk_vals), dim=1),
            token_ranks,
        )
|
||||
|
||||
|
||||
@triton.jit
def _apply_temp_kernel(
    logits,  # bf16[batch_size, vocab_size]
    logits_stride,
    output,  # fp32[batch_size, vocab_size]
    output_stride,
    temperature,  # per-request temperature, one float per batch row
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    EPSILON: tl.constexpr,
):
    # One program per (request, vocab block): divides a BLOCK_SIZE-wide
    # slice of the request's logits by its temperature, writing FP32.
    batch_idx = tl.program_id(0)
    block_idx = tl.program_id(1)

    temp = tl.load(temperature + batch_idx)
    if temp < EPSILON:
        # Greedy sampling. Don't apply temperature.
        # NOTE(woosuk): In this case, we assume that its logprobs are not used.
        temp = tl.ones([1], dtype=tl.float32)

    offset = tl.arange(0, BLOCK_SIZE)
    block = block_idx * BLOCK_SIZE + offset

    # Load the logits.
    # Mask keeps the last partial block within the vocab bounds.
    x = tl.load(logits + batch_idx * logits_stride + block,
                mask=block < vocab_size)
    # Scale in FP32 for numerical stability before softmax.
    x = x.to(tl.float32)
    x = x / temp
    tl.store(output + batch_idx * output_stride + block,
             x,
             mask=block < vocab_size)
|
||||
|
||||
|
||||
def apply_temperature(
    logits: torch.Tensor,
    temperature: torch.Tensor,
) -> torch.Tensor:
    """Divide each request's logits by its temperature, returning FP32.

    Requests whose temperature is below _SAMPLING_EPS (greedy) are left
    unscaled by the kernel.
    """
    num_reqs, vocab_size = logits.shape
    scaled = torch.empty_like(logits, dtype=torch.float32)
    BLOCK_SIZE = 8192
    grid = (num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))
    _apply_temp_kernel[grid](
        logits,
        logits.stride(0),
        scaled,
        scaled.stride(0),
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        EPSILON=_SAMPLING_EPS,
    )
    return scaled
|
||||
|
||||
|
||||
@triton.jit
def _apply_gumbel_kernel(
    probs_ptr,  # fp32[num_reqs, vocab_size], updated in place
    probs_stride,
    seeds_ptr,  # one RNG seed per request
    temp_ptr,  # one temperature per request
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    EPSILON: tl.constexpr,
):
    # One program per (request, vocab block): divides each probability by
    # an Exp(1) sample (-log u), so a later argmax over probs draws from
    # the categorical distribution (Gumbel-max / exponential-race trick).
    req_idx = tl.program_id(0)
    seed = tl.load(seeds_ptr + req_idx)
    temp = tl.load(temp_ptr + req_idx)

    if temp < EPSILON:
        # Greedy sampling. Don't apply gumbel noise.
        return

    block_id = tl.program_id(1)
    r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Per-element uniform randoms, deterministic given (seed, offset).
    q = tl.rand(seed, r_offset)

    # NOTE(woosuk): This logic makes sure q is not 0.
    RMAX = 0.9999999403953552
    RMAX_LOG = -5.960464477539063e-08
    q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
    q = -1.0 * q

    p = tl.load(probs_ptr + req_idx * probs_stride + r_offset,
                mask=r_offset < vocab_size)
    p = p / q

    tl.store(probs_ptr + req_idx * probs_stride + r_offset,
             p,
             mask=r_offset < vocab_size)
|
||||
|
||||
|
||||
def gumbel_sample(
    # fp32[num_reqs, vocab_size]
    probs: torch.Tensor,
    # fp32[num_reqs]
    temperature: torch.Tensor,
    # int64[num_reqs]
    seeds: Optional[torch.Tensor],
    # int64[num_reqs]
    pos: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sample one token id per request via the Gumbel-max trick.

    Divides `probs` in place by per-element noise (greedy requests are
    skipped inside the kernel) and returns the argmax over the vocab.
    """
    num_reqs, vocab_size = probs.shape[0], probs.shape[1]
    if seeds is None:
        # Global seed: draw a fresh random seed for every request.
        assert pos is None
        seed_dtype = torch.int64
        info = torch.iinfo(seed_dtype)
        gumbel_seeds = torch.randint(
            info.min,
            info.max,
            (num_reqs, ),
            dtype=seed_dtype,
            device=probs.device,
        )
    else:
        # Per-request seed, offset by position so each decoding step
        # draws different noise.
        assert pos is not None
        gumbel_seeds = seeds + pos

    # Update the probs in-place.
    BLOCK_SIZE = 8192
    grid = (num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))
    _apply_gumbel_kernel[grid](
        probs,
        probs.stride(0),
        gumbel_seeds,
        temperature,
        vocab_size,
        BLOCK_SIZE,
        EPSILON=_SAMPLING_EPS,
    )
    # Sample the next token (int64 ids).
    return probs.argmax(dim=-1).view(-1)
|
||||
143
vllm/v1/worker/gpu/states.py
Normal file
143
vllm/v1/worker/gpu/states.py
Normal file
@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
|
||||
class RequestState:
    """Fixed-capacity per-request state kept by the GPU worker.

    Every active request owns one row index in the preallocated arrays
    below; `free_indices` tracks the currently unused rows. Token data
    and sampling parameters live in CPU numpy arrays and are copied to
    the GPU on demand.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.pin_memory = pin_memory

        # Bidirectional req_id <-> row-index mapping.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))

        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self.token_ids = np.zeros(
            (self.max_num_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)

        # Last sampled token ids (kept on the GPU).
        self.last_token = torch.zeros(
            self.max_num_reqs,
            dtype=torch.int32,
            device=self.device,
        )

        # Per-request sampling parameters.
        self.temperature = np.zeros(self.max_num_reqs, dtype=np.float32)
        self.top_p = np.zeros(self.max_num_reqs, dtype=np.float32)
        self.top_k = np.zeros(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs = np.full(self.max_num_reqs, -1, dtype=np.int32)

    @property
    def num_reqs(self) -> int:
        """Number of requests currently tracked."""
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Assign a free row to `req_id` and populate its state."""
        assert len(self.free_indices) > 0
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        num_prompt = len(prompt_token_ids)
        self.num_tokens[req_idx] = num_prompt
        self.token_ids[req_idx, :num_prompt] = prompt_token_ids
        self.num_computed_tokens[req_idx] = num_computed_tokens

        self.temperature[req_idx] = sampling_params.temperature
        self.top_p[req_idx] = sampling_params.top_p
        # A top_k outside (0, vocab_size) disables top-k filtering.
        self.top_k[req_idx] = (sampling_params.top_k
                               if 0 < sampling_params.top_k < self.vocab_size
                               else self.vocab_size)

        requested = sampling_params.num_logprobs
        self.num_logprobs[req_idx] = -1 if requested is None else requested

    def append_token_ids(
        self,
        req_idx: int,
        token_ids: Union[list[int], np.ndarray],
    ) -> None:
        """Append newly generated token ids to the request's buffer."""
        begin = self.num_tokens[req_idx]
        end = begin + len(token_ids)
        self.token_ids[req_idx, begin:end] = token_ids
        self.num_tokens[req_idx] = end

    def remove_request(self, req_id: str) -> None:
        """Release the row owned by `req_id` (no-op if unknown)."""
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def make_sampling_metadata(
        self,
        idx_mapping: np.ndarray,
    ) -> SamplingMetadata:
        """Gather sampling params for the given rows and move to GPU.

        `top_p`/`top_k` are passed as None when they are a no-op for the
        whole batch, so the transfer (and the filtering) can be skipped.
        """
        temperature = self._copy_np_to_gpu(self.temperature[idx_mapping])

        top_p_np = self.top_p[idx_mapping]
        top_p = (None if np.all(top_p_np == 1.0) else
                 self._copy_np_to_gpu(top_p_np))

        top_k_np = self.top_k[idx_mapping]
        top_k = (None if np.all(top_k_np == self.vocab_size) else
                 self._copy_np_to_gpu(top_k_np))

        max_num_logprobs = np.max(self.num_logprobs[idx_mapping])
        if max_num_logprobs == -1:
            # No request in the batch asked for logprobs.
            max_num_logprobs = None

        return SamplingMetadata(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_num_logprobs=max_num_logprobs,
        )

    def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
        """Async host-to-device copy of a numpy array (pinned if enabled)."""
        staging = torch.from_numpy(src)
        if self.pin_memory:
            staging = staging.pin_memory()
        return staging.to(device=self.device, non_blocking=True)
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,291 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
_MAX_SPEC_LEN = 32
|
||||
|
||||
|
||||
@dataclass
class RequestData:
    """Per-request auxiliary data held by the worker.

    Multimodal inputs, sampling/pooling parameters, and optional
    M-RoPE / LoRA state for one request.
    """

    # Multimodal inputs and their placeholder positions in the prompt.
    mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
    # Exactly one of these is expected to be set, depending on whether
    # the request is generative or pooling — TODO confirm with callers.
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]

    # Hashes identifying the multimodal items (e.g. for caching).
    mm_hashes: list[str]
    # M-RoPE (only for Qwen2/2.5-VL)
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    # LoRA adapter for this request, if any.
    lora_request: Optional[LoRARequest] = None
|
||||
|
||||
|
||||
class RequestState:
    """Pool of per-request state for all cached requests on this worker.

    Each cached request owns a fixed row index into the preallocated
    arrays below; `free_indices` tracks the unused rows. Token data and
    sampling parameters live in CPU numpy arrays; the last sampled token
    ids live on the GPU.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        max_num_cached_reqs: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_cached_reqs = max_num_cached_reqs
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        self.is_spec_decode = is_spec_decode
        # NOTE(review): `logitsprocs` and `is_pooling_model` are accepted
        # but not currently stored or used.
        self.pooling_params = None
        self.num_prompt_logprobs: dict[int, int] = {}

        # Bidirectional req_id <-> row-index mapping.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_cached_reqs))

        # Request states.
        self.req_data: dict[int, RequestData] = {}
        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self.token_ids = np.zeros(
            (self.max_num_cached_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs,
                                          dtype=np.int32)
        self.num_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
        self.num_computed_tokens = np.zeros(self.max_num_cached_reqs,
                                            dtype=np.int32)

        # Last sampled token ids.
        self.last_sampled_token = torch.zeros(
            self.max_num_cached_reqs,
            dtype=torch.int32,
            device=self.device,
        )

        # Per-request sampling parameters.
        self.temperature = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
        self.top_p = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
        self.top_k = np.zeros(self.max_num_cached_reqs, dtype=np.int32)

        self.frequency_penalties = np.zeros(self.max_num_cached_reqs,
                                            dtype=np.float32)
        self.presence_penalties = np.zeros(self.max_num_cached_reqs,
                                           dtype=np.float32)
        self.repetition_penalties = np.zeros(self.max_num_cached_reqs,
                                             dtype=np.float32)

        # Per-request generators for seeded random sampling.
        self.generators: dict[int, torch.Generator] = {}

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Assign a free row to `req_id` and populate its state."""
        assert len(self.free_indices) > 0, "No free space in GPU worker states"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        prompt_len = len(prompt_token_ids)
        self.num_prompt_tokens[req_idx] = prompt_len
        self.num_tokens[req_idx] = prompt_len
        self.token_ids[req_idx, :prompt_len] = prompt_token_ids
        self.num_computed_tokens[req_idx] = num_computed_tokens

        self.temperature[req_idx] = sampling_params.temperature
        self.top_p[req_idx] = sampling_params.top_p
        # A top_k outside (0, vocab_size) disables top-k filtering.
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            top_k = self.vocab_size
        self.top_k[req_idx] = top_k
        self.frequency_penalties[req_idx] = sampling_params.frequency_penalty
        self.presence_penalties[req_idx] = sampling_params.presence_penalty
        self.repetition_penalties[req_idx] = sampling_params.repetition_penalty

        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
            # Seeded request: dedicated generator for reproducibility.
            generator = torch.Generator(device=self.device)
            generator.manual_seed(sampling_params.seed)
            self.generators[req_idx] = generator

    @property
    def num_cached_reqs(self) -> int:
        """Number of requests currently cached."""
        return len(self.req_id_to_index)

    def append_token_ids(
        self,
        req_idx: int,
        token_ids: Union[list[int], np.ndarray],
    ) -> None:
        """Append generated token ids to the request's token buffer.

        BUG FIX: the original indexed through a nonexistent `.np`
        attribute (`self.num_tokens.np[req_idx]`, `self.token_ids.np[...]`),
        which raises AttributeError on plain numpy arrays. Index the
        arrays directly, consistent with the rest of this class.
        """
        start_idx = self.num_tokens[req_idx]
        end_idx = start_idx + len(token_ids)
        self.token_ids[req_idx, start_idx:end_idx] = token_ids
        self.num_tokens[req_idx] = end_idx

    def remove_request(self, req_id: str) -> None:
        """Release the row owned by `req_id` (no-op if unknown)."""
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def make_sampling_metadata(
        self,
        idx_mapping: np.ndarray,
    ) -> SamplingMetadata:
        """Gather sampling params for the given rows and move to GPU.

        Tensors that are a no-op for the whole batch (top_p == 1,
        top_k == vocab_size, all penalties disabled) are passed as None
        so downstream can skip the work.
        """
        temperature = self.temperature[idx_mapping]
        all_greedy = np.all(temperature == 0.0)
        all_random = np.all(temperature != 0.0)
        temperature = self._copy_np_to_gpu(temperature)

        top_p = self.top_p[idx_mapping]
        no_top_p = np.all(top_p == 1.0)
        top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None
        top_k = self.top_k[idx_mapping]
        no_top_k = np.all(top_k == self.vocab_size)
        top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None

        frequency_penalties = self.frequency_penalties[idx_mapping]
        presence_penalties = self.presence_penalties[idx_mapping]
        repetition_penalties = self.repetition_penalties[idx_mapping]
        no_penalties = (np.all(frequency_penalties == 0.0)
                        and np.all(presence_penalties == 0.0)
                        and np.all(repetition_penalties == 1.0))
        if no_penalties:
            frequency_penalties = None
            presence_penalties = None
            repetition_penalties = None
        else:
            frequency_penalties = self._copy_np_to_gpu(frequency_penalties)
            presence_penalties = self._copy_np_to_gpu(presence_penalties)
            repetition_penalties = self._copy_np_to_gpu(repetition_penalties)

        if self.generators:
            # Only pass generators for rows present in this batch.
            generators = {
                req_idx: self.generators[req_idx]
                for req_idx in idx_mapping if req_idx in self.generators
            }
        else:
            generators = {}

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=all_greedy,
            all_random=all_random,
            top_p=top_p,
            top_k=top_k,
            frequency_penalties=frequency_penalties,
            presence_penalties=presence_penalties,
            repetition_penalties=repetition_penalties,
            no_penalties=no_penalties,
            generators=generators,
            token_ids=None,
            num_tokens=None,
            num_prompt_tokens=None,
            max_num_logprobs=None,
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=None,
        )

    def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
        """Async host-to-device copy of a numpy array."""
        cpu_tensor = torch.from_numpy(src)
        return cpu_tensor.to(device=self.device, non_blocking=True)

    def make_spec_decode_metadata(
        self,
        query_start_loc: torch.Tensor,
        cu_num_draft_tokens: torch.Tensor,
        cu_num_draft_tokens_np: np.ndarray,
        input_ids: torch.Tensor,
    ) -> SpecDecodeMetadata:
        """Build logits-index tensors for speculative-decoding verification.

        For each request, the last (draft_len + 1) query positions need
        logits: the draft tokens to verify plus one bonus token.
        """
        batch_size = query_start_loc.shape[0] - 1
        total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1]
        logits_indices = torch.empty(total_num_draft_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=self.device)
        target_logits_indices = torch.empty(total_num_draft_tokens,
                                            dtype=torch.int32,
                                            device=self.device)
        bonus_logits_indices = torch.empty(batch_size,
                                           dtype=torch.int32,
                                           device=self.device)
        _prepare_spec_decode_kernel[(batch_size, )](
            query_start_loc,
            cu_num_draft_tokens,
            logits_indices,
            target_logits_indices,
            bonus_logits_indices,
            BLOCK_SIZE=triton.next_power_of_2(_MAX_SPEC_LEN + 1),
        )

        draft_token_ids = input_ids[logits_indices]
        # NOTE(review): indexing with `target_logits_indices + 1` assumes
        # each draft token immediately follows its target position in
        # `logits_indices` — verify against the kernel's layout.
        draft_token_ids = draft_token_ids[target_logits_indices + 1]
        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            # NOTE(review): this passes the *cumulative* counts under the
            # name `num_draft_tokens` — confirm the consumer's expectation.
            num_draft_tokens=cu_num_draft_tokens_np.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
|
||||
|
||||
|
||||
@triton.jit
def _prepare_spec_decode_kernel(
    query_start_loc,  # [B + 1]
    cu_num_draft_tokens,  # [B]
    logits_indices,  # [N + B]
    target_logits_indices,  # [N]
    bonus_logits_indices,  # [B]
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request. For request i it writes:
    #  - logits_indices: the query positions of the last
    #    (draft_len + 1) tokens (draft tokens to verify + bonus token),
    #  - target_logits_indices: positions within logits_indices that
    #    correspond to the draft tokens,
    #  - bonus_logits_indices: position of the bonus token's logits.
    batch_idx = tl.program_id(0)

    # Recover this request's draft span from the cumulative counts.
    if batch_idx == 0:
        draft_start_idx = 0
    else:
        draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1)
    draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx)
    draft_len = draft_end_idx - draft_start_idx
    # One bonus token is sampled in addition to the draft tokens.
    sample_len = draft_len + 1

    q_end_idx = tl.load(query_start_loc + batch_idx + 1)

    # Each request contributes (draft_len + 1) entries, hence the
    # +batch_idx offset into the flattened output.
    sample_start_idx = draft_start_idx + batch_idx
    sample_end_idx = sample_start_idx + sample_len
    offset = tl.arange(0, BLOCK_SIZE)
    tl.store(logits_indices + sample_start_idx + offset,
             q_end_idx - sample_len + offset,
             mask=offset < sample_len)
    tl.store(target_logits_indices + draft_start_idx + offset,
             sample_start_idx + offset,
             mask=offset < draft_len)
    tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
|
||||
Loading…
x
Reference in New Issue
Block a user