mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 23:43:09 +08:00
[V0 Deprecation] Remove V0 core (#25321)
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
86fdd686be
commit
7cdd90211b
@ -148,7 +148,6 @@ steps:
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/
|
||||
- vllm/core/
|
||||
- tests/distributed/test_utils
|
||||
- tests/distributed/test_pynccl
|
||||
- tests/distributed/test_events
|
||||
@ -867,8 +866,6 @@ steps:
|
||||
- tests/distributed/
|
||||
- vllm/compilation
|
||||
- vllm/worker/worker_base.py
|
||||
- vllm/worker/worker.py
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
|
||||
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
@ -4,10 +4,8 @@
|
||||
# This lists cover the "core" components of vLLM that require careful review
|
||||
/vllm/attention @LucasWilkinson
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/fused_moe @mgoin
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
|
||||
|
||||
@ -70,7 +70,6 @@ line-length = 80
|
||||
"vllm/_version.py" = ["ALL"]
|
||||
# Python 3.8 typing - skip V0 code
|
||||
"vllm/attention/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/core/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/worker/**/*.py" = ["UP006", "UP035"]
|
||||
@ -117,7 +116,6 @@ files = [
|
||||
"vllm/*.py",
|
||||
"vllm/assets",
|
||||
"vllm/entrypoints",
|
||||
"vllm/core",
|
||||
"vllm/inputs",
|
||||
"vllm/logging_utils",
|
||||
"vllm/multimodal",
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@ -34,9 +34,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -329,7 +326,7 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata):
|
||||
class DifferentialFlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
def __init__(self, input_builder):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
@ -350,9 +347,8 @@ class DifferentialFlashAttentionMetadataBuilder(
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -22,9 +22,6 @@ from vllm.utils import async_tensor_h2d
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache, sparse_attn_func)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -224,9 +221,8 @@ class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
|
||||
super().prepare()
|
||||
self.orig_seq_lens: List[int] = []
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
super()._add_seq_group(inter_data, chunked_prefill_enabled,
|
||||
prefix_cache_hit)
|
||||
for prompt_len, seq_len in zip(inter_data.prompt_lens,
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -31,9 +31,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -312,7 +309,7 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
def __init__(self, input_builder):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
@ -332,9 +329,8 @@ class FlashAttentionMetadataBuilder(
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
|
||||
@ -193,8 +193,7 @@ from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
|
||||
Type, TypeVar)
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
@ -233,9 +232,6 @@ except ImportError:
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
is_hip = current_platform.is_rocm()
|
||||
|
||||
|
||||
@ -638,7 +634,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
||||
"""
|
||||
BLOCK_TABLE_EXTENDER: list[list[int]] = []
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
def __init__(self, input_builder):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
@ -668,9 +664,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -13,9 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadataBuilder)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder)
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
# Placeholder attention backend for models like Mamba and pooling models that
|
||||
@ -204,7 +201,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
|
||||
class PlaceholderAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
def __init__(self, input_builder):
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
@ -220,9 +217,7 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
"""
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Type, Union
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -19,9 +19,6 @@ from vllm.attention.backends.utils import (compute_slot_mapping,
|
||||
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
|
||||
get_aiter_mla_metadata)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
|
||||
def is_aiter_mla_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER \
|
||||
@ -110,7 +107,7 @@ class AiterMLAMetadata(MLACommonMetadata):
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
def __init__(self, input_builder):
|
||||
super().__init__(input_builder)
|
||||
assert self.block_size == 1, "AITER MLA requires only block size 1."
|
||||
|
||||
|
||||
@ -35,9 +35,6 @@ PAD_SLOT_ID = -1
|
||||
# if we have at least this many elements. Could be tuned further.
|
||||
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
|
||||
def is_block_tables_empty(block_tables: Union[None, Dict]):
|
||||
"""
|
||||
@ -129,7 +126,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
|
||||
_metadata_cls: Type[TAttentionMetadata]
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
def __init__(self, input_builder):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
@ -149,9 +146,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
|
||||
@ -1,399 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.core.block.common import BlockList
|
||||
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
|
||||
from vllm.utils import Device, cdiv, chunk_list
|
||||
|
||||
|
||||
class BlockTable:
|
||||
"""A class to manage blocks for a specific sequence.
|
||||
|
||||
The BlockTable maps a sequence of tokens to a list of blocks, where each
|
||||
block represents a contiguous memory allocation for a portion of the
|
||||
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
|
||||
responsible for allocating and freeing memory for the blocks.
|
||||
|
||||
Args:
|
||||
block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]], optional): An optional list of existing
|
||||
blocks to initialize the BlockTable with. If not provided, an empty
|
||||
BlockTable is created.
|
||||
max_block_sliding_window (Optional[int], optional): The number of
|
||||
blocks to keep around for each sequence. If None, all blocks
|
||||
are kept (eg., when sliding window is not used).
|
||||
It should at least fit the sliding window size of the model.
|
||||
|
||||
Attributes:
|
||||
_block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]]): The list of blocks managed by this
|
||||
BlockTable.
|
||||
_num_full_slots (int): The number of tokens currently stored in the
|
||||
blocks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
block_allocator: DeviceAwareBlockAllocator,
|
||||
_blocks: Optional[List[Block]] = None,
|
||||
max_block_sliding_window: Optional[int] = None,
|
||||
):
|
||||
self._block_size = block_size
|
||||
self._allocator = block_allocator
|
||||
if _blocks is None:
|
||||
_blocks = []
|
||||
self._blocks: BlockList = BlockList(_blocks)
|
||||
|
||||
self._max_block_sliding_window = max_block_sliding_window
|
||||
self._num_full_slots = self._get_num_token_ids()
|
||||
|
||||
@staticmethod
|
||||
def get_num_required_blocks(token_ids: List[int],
|
||||
block_size: int,
|
||||
num_lookahead_slots: int = 0) -> int:
|
||||
"""Calculates the minimum number of blocks required to store a given
|
||||
sequence of token IDs along with any look-ahead slots that may be
|
||||
required (like in multi-step + chunked-prefill).
|
||||
|
||||
This assumes worst-case scenario, where every block requires a new
|
||||
allocation (e.g. ignoring prefix caching).
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be stored.
|
||||
block_size (int): The maximum number of tokens that can be stored in
|
||||
a single block.
|
||||
num_lookahead_slots (int): look-ahead slots that the sequence may
|
||||
require.
|
||||
|
||||
Returns:
|
||||
int: The minimum number of blocks required to store the given
|
||||
sequence of token IDs along with any required look-ahead slots.
|
||||
"""
|
||||
return cdiv(len(token_ids) + num_lookahead_slots, block_size)
|
||||
|
||||
def allocate(self,
|
||||
token_ids: List[int],
|
||||
device: Device = Device.GPU,
|
||||
extra_hash: Optional[int] = None) -> None:
|
||||
"""Allocates memory blocks for storing the given sequence of token IDs.
|
||||
|
||||
This method allocates the required number of blocks to store the given
|
||||
sequence of token IDs.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be stored.
|
||||
device (Device, optional): The device on which the blocks should be
|
||||
allocated. Defaults to Device.GPU.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefixcaching block.
|
||||
"""
|
||||
assert not self._is_allocated
|
||||
assert token_ids
|
||||
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
|
||||
token_ids=token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash)
|
||||
self.update(blocks)
|
||||
self._num_full_slots = len(token_ids)
|
||||
|
||||
def update(self, blocks: List[Block]) -> None:
|
||||
"""Resets the table to the newly provided blocks
|
||||
(with their corresponding block ids)
|
||||
"""
|
||||
self._blocks.update(blocks)
|
||||
|
||||
def append_token_ids(self,
|
||||
token_ids: List[int],
|
||||
num_lookahead_slots: int = 0,
|
||||
num_computed_slots: Optional[int] = None,
|
||||
extra_hash: Optional[int] = None) -> None:
|
||||
"""Appends a sequence of token IDs to the existing blocks in the
|
||||
BlockTable.
|
||||
|
||||
This method appends the given sequence of token IDs to the existing
|
||||
blocks in the BlockTable. If there is not enough space in the existing
|
||||
blocks, new blocks are allocated using the `ensure_num_empty_slots`
|
||||
method to accommodate the additional tokens.
|
||||
|
||||
The token IDs are divided into chunks of size `block_size` (except for
|
||||
the first chunk, which may be smaller), and each chunk is appended to a
|
||||
separate block.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be appended.
|
||||
num_computed_slots (Optional[int]): The number of KV cache slots
|
||||
that are already filled (computed).
|
||||
When sliding window is enabled, this is used to compute how many
|
||||
blocks to drop at the front of the sequence.
|
||||
Without sliding window, None can be passed.
|
||||
Without chunked prefill, it should be the same as
|
||||
_num_full_slots.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors such as adapters that influence the block, apart
|
||||
from the token_ids.
|
||||
"""
|
||||
assert self._is_allocated, "no blocks have been allocated"
|
||||
assert len(self._blocks) > 0
|
||||
|
||||
# Drop blocks that are no longer needed due to sliding window
|
||||
if self._max_block_sliding_window is not None:
|
||||
null_block = self._allocator.allocate_or_get_null_block()
|
||||
assert num_computed_slots is not None
|
||||
end_block_idx = (num_computed_slots //
|
||||
self._block_size) - self._max_block_sliding_window
|
||||
for idx in range(0, end_block_idx):
|
||||
b = self._blocks[idx]
|
||||
if b is not null_block:
|
||||
self._allocator.free(b)
|
||||
self._blocks[idx] = null_block
|
||||
|
||||
# Ensure there are enough empty slots for the new tokens plus
|
||||
# lookahead slots
|
||||
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
|
||||
num_lookahead_slots,
|
||||
extra_hash=extra_hash)
|
||||
|
||||
# Update the blocks with the new tokens
|
||||
first_block_idx = self._num_full_slots // self._block_size
|
||||
token_blocks = self._chunk_token_blocks_for_append(token_ids)
|
||||
|
||||
for i, token_block in enumerate(token_blocks):
|
||||
self._blocks.append_token_ids(first_block_idx + i, token_block)
|
||||
|
||||
self._num_full_slots += len(token_ids)
|
||||
|
||||
def ensure_num_empty_slots(self,
|
||||
num_empty_slots: int,
|
||||
extra_hash: Optional[int] = None) -> None:
|
||||
"""Ensures that the BlockTable has at least the specified number of
|
||||
empty slots available.
|
||||
|
||||
This method checks if the BlockTable has enough empty slots (i.e.,
|
||||
available space) to accommodate the requested number of tokens. If not,
|
||||
it allocates additional blocks on the GPU to ensure that the required
|
||||
number of empty slots is available.
|
||||
|
||||
Args:
|
||||
num_empty_slots (int): The minimum number of empty slots required.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors such as adapters that influence the block, apart
|
||||
from the token_ids.
|
||||
"""
|
||||
# Currently the block table only supports
|
||||
# appending tokens to GPU blocks.
|
||||
device = Device.GPU
|
||||
assert self._is_allocated
|
||||
|
||||
if self._num_empty_slots >= num_empty_slots:
|
||||
return
|
||||
|
||||
slots_to_allocate = num_empty_slots - self._num_empty_slots
|
||||
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
|
||||
|
||||
for _ in range(blocks_to_allocate):
|
||||
assert len(self._blocks) > 0
|
||||
self._blocks.append(
|
||||
self._allocator.allocate_mutable_block(
|
||||
prev_block=self._blocks[-1],
|
||||
device=device,
|
||||
extra_hash=extra_hash))
|
||||
|
||||
def fork(self) -> "BlockTable":
|
||||
"""Creates a new BlockTable instance with a copy of the blocks from the
|
||||
current instance.
|
||||
|
||||
This method creates a new BlockTable instance with the same block size,
|
||||
block allocator, and a copy of the blocks from the current instance. The
|
||||
new BlockTable has its own independent set of blocks, but shares the
|
||||
same underlying memory allocation with the original BlockTable.
|
||||
|
||||
Returns:
|
||||
BlockTable: A new BlockTable instance with a copy of the blocks from
|
||||
the current instance.
|
||||
"""
|
||||
assert self._is_allocated
|
||||
assert len(self._blocks) > 0
|
||||
forked_blocks = self._allocator.fork(self._blocks[-1])
|
||||
return BlockTable(
|
||||
block_size=self._block_size,
|
||||
block_allocator=self._allocator,
|
||||
_blocks=forked_blocks,
|
||||
max_block_sliding_window=self._max_block_sliding_window,
|
||||
)
|
||||
|
||||
def free(self) -> None:
|
||||
"""Frees the memory occupied by the blocks in the BlockTable.
|
||||
|
||||
This method iterates over all the blocks in the `_blocks` list and calls
|
||||
the `free` method of the `_allocator` object to release the memory
|
||||
occupied by each block. After freeing all the blocks, the `_blocks` list
|
||||
is set to `None`.
|
||||
"""
|
||||
for block in self.blocks:
|
||||
self._allocator.free(block)
|
||||
self._blocks.reset()
|
||||
|
||||
@property
|
||||
def physical_block_ids(self) -> List[int]:
|
||||
"""Returns a list of physical block indices for the blocks in the
|
||||
BlockTable.
|
||||
|
||||
This property returns a list of integers, where each integer represents
|
||||
the physical block index of a corresponding block in the `_blocks` list.
|
||||
The physical block index is a unique identifier for the memory location
|
||||
occupied by the block.
|
||||
|
||||
Returns:
|
||||
List[int]: A list of physical block indices for the blocks in the
|
||||
BlockTable.
|
||||
"""
|
||||
return self._blocks.ids()
|
||||
|
||||
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
|
||||
"""Get the number of "unseen" tokens in the sequence.
|
||||
|
||||
Unseen tokens are tokens in the sequence corresponding to this block
|
||||
table, but are not yet appended to this block table.
|
||||
|
||||
Args:
|
||||
sequence_token_ids (List[int]): The list of token ids in the
|
||||
sequence.
|
||||
|
||||
Returns:
|
||||
List[int]: The postfix of sequence_token_ids that has not yet been
|
||||
appended to the block table.
|
||||
"""
|
||||
|
||||
# Since the block table is append-only, the unseen token ids are the
|
||||
# ones after the appended ones.
|
||||
return sequence_token_ids[self.num_full_slots:]
|
||||
|
||||
def _allocate_blocks_for_token_ids(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> List[Block]:
|
||||
blocks: List[Block] = []
|
||||
|
||||
block_token_ids = []
|
||||
tail_token_ids = []
|
||||
for cur_token_ids in chunk_list(token_ids, self._block_size):
|
||||
if len(cur_token_ids) == self._block_size:
|
||||
block_token_ids.append(cur_token_ids)
|
||||
else:
|
||||
tail_token_ids.append(cur_token_ids)
|
||||
|
||||
if block_token_ids:
|
||||
blocks.extend(
|
||||
self._allocator.allocate_immutable_blocks(
|
||||
prev_block,
|
||||
block_token_ids=block_token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash))
|
||||
prev_block = blocks[-1]
|
||||
|
||||
if tail_token_ids:
|
||||
assert len(tail_token_ids) == 1
|
||||
cur_token_ids = tail_token_ids[0]
|
||||
|
||||
block = self._allocator.allocate_mutable_block(
|
||||
prev_block=prev_block, device=device, extra_hash=extra_hash)
|
||||
block.append_token_ids(cur_token_ids)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
return blocks
|
||||
|
||||
def _get_all_token_ids(self) -> List[int]:
|
||||
# NOTE: This function is O(seq_len); use sparingly.
|
||||
token_ids: List[int] = []
|
||||
|
||||
if not self._is_allocated:
|
||||
return token_ids
|
||||
|
||||
for block in self.blocks:
|
||||
token_ids.extend(block.token_ids)
|
||||
|
||||
return token_ids
|
||||
|
||||
def _get_num_token_ids(self) -> int:
|
||||
res = 0
|
||||
for block in self.blocks:
|
||||
res += len(block.token_ids)
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def _is_allocated(self) -> bool:
|
||||
return len(self._blocks) > 0
|
||||
|
||||
@property
|
||||
def blocks(self) -> List[Block]:
|
||||
return self._blocks.list()
|
||||
|
||||
@property
|
||||
def _num_empty_slots(self) -> int:
|
||||
assert self._is_allocated
|
||||
return len(self._blocks) * self._block_size - self._num_full_slots
|
||||
|
||||
@property
|
||||
def num_full_slots(self) -> int:
|
||||
"""Returns the total number of tokens currently stored in the
|
||||
BlockTable.
|
||||
|
||||
Returns:
|
||||
int: The total number of tokens currently stored in the BlockTable.
|
||||
"""
|
||||
return self._num_full_slots
|
||||
|
||||
def get_num_blocks_touched_by_append_slots(
|
||||
self, token_ids: List[int], num_lookahead_slots: int) -> int:
|
||||
"""Determine how many blocks will be "touched" by appending the token
|
||||
ids.
|
||||
|
||||
This is required for the scheduler to determine whether a sequence can
|
||||
continue generation, or if it must be preempted.
|
||||
"""
|
||||
# Math below is equivalent to:
|
||||
# all_token_ids = token_ids + [-1] * num_lookahead_slots
|
||||
# token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
|
||||
# return len(token_blocks)
|
||||
|
||||
num_token_ids = len(token_ids) + num_lookahead_slots
|
||||
first_chunk_size = self._block_size - (self._num_full_slots %
|
||||
self._block_size)
|
||||
num_token_blocks = (1 + math.ceil(
|
||||
(num_token_ids - first_chunk_size) / self._block_size))
|
||||
return num_token_blocks
|
||||
|
||||
def _chunk_token_blocks_for_append(
|
||||
self, token_ids: List[int]) -> List[List[int]]:
|
||||
"""Split the token ids into block-sized chunks so they can be easily
|
||||
appended to blocks. The first such "token block" may have less token ids
|
||||
than the block size, since the last allocated block may be partially
|
||||
full.
|
||||
|
||||
If no token ids are provided, then no chunks are returned.
|
||||
"""
|
||||
|
||||
if not token_ids:
|
||||
return []
|
||||
|
||||
first_chunk_size = self._block_size - (self._num_full_slots %
|
||||
self._block_size)
|
||||
token_blocks = [token_ids[:first_chunk_size]]
|
||||
token_blocks.extend(
|
||||
chunk_list(token_ids[first_chunk_size:], self._block_size))
|
||||
return token_blocks
|
||||
@ -1,371 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
|
||||
BlockId = int
|
||||
RefCount = int
|
||||
|
||||
|
||||
class RefCounterProtocol(Protocol):
|
||||
|
||||
def incr(self, block_id: BlockId) -> RefCount:
|
||||
raise NotImplementedError
|
||||
|
||||
def decr(self, block_id: BlockId) -> RefCount:
|
||||
raise NotImplementedError
|
||||
|
||||
def get(self, block_id: BlockId) -> RefCount:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RefCounter(RefCounterProtocol):
|
||||
"""A class for managing reference counts for a set of block indices.
|
||||
|
||||
The RefCounter class maintains a dictionary that maps block indices to their
|
||||
corresponding reference counts. It provides methods to increment, decrement,
|
||||
and retrieve the reference count for a given block index.
|
||||
|
||||
Args:
|
||||
all_block_indices (Iterable[BlockId]): An iterable of block indices
|
||||
to initialize the reference counter with.
|
||||
"""
|
||||
|
||||
def __init__(self, all_block_indices: Iterable[BlockId]):
|
||||
deduped = set(all_block_indices)
|
||||
self._refcounts: Dict[BlockId, RefCount] = {
|
||||
index: 0
|
||||
for index in deduped
|
||||
}
|
||||
|
||||
def incr(self, block_id: BlockId) -> RefCount:
|
||||
assert block_id in self._refcounts
|
||||
pre_incr_refcount = self._refcounts[block_id]
|
||||
|
||||
assert pre_incr_refcount >= 0
|
||||
|
||||
post_incr_refcount = pre_incr_refcount + 1
|
||||
self._refcounts[block_id] = post_incr_refcount
|
||||
return post_incr_refcount
|
||||
|
||||
def decr(self, block_id: BlockId) -> RefCount:
|
||||
assert block_id in self._refcounts
|
||||
refcount = self._refcounts[block_id]
|
||||
|
||||
assert refcount > 0
|
||||
refcount -= 1
|
||||
|
||||
self._refcounts[block_id] = refcount
|
||||
|
||||
return refcount
|
||||
|
||||
def get(self, block_id: BlockId) -> RefCount:
|
||||
assert block_id in self._refcounts
|
||||
return self._refcounts[block_id]
|
||||
|
||||
def as_readonly(self) -> "ReadOnlyRefCounter":
|
||||
return ReadOnlyRefCounter(self)
|
||||
|
||||
|
||||
class ReadOnlyRefCounter(RefCounterProtocol):
|
||||
"""A read-only view of the RefCounter class.
|
||||
|
||||
The ReadOnlyRefCounter class provides a read-only interface to access the
|
||||
reference counts maintained by a RefCounter instance. It does not allow
|
||||
modifications to the reference counts.
|
||||
|
||||
Args:
|
||||
refcounter (RefCounter): The RefCounter instance to create a read-only
|
||||
view for.
|
||||
"""
|
||||
|
||||
def __init__(self, refcounter: RefCounter):
|
||||
self._refcounter = refcounter
|
||||
|
||||
def incr(self, block_id: BlockId) -> RefCount:
|
||||
raise ValueError("Incr not allowed")
|
||||
|
||||
def decr(self, block_id: BlockId) -> RefCount:
|
||||
raise ValueError("Decr not allowed")
|
||||
|
||||
def get(self, block_id: BlockId) -> RefCount:
|
||||
return self._refcounter.get(block_id)
|
||||
|
||||
|
||||
class CopyOnWriteTracker:
|
||||
"""A class for tracking and managing copy-on-write operations for blocks.
|
||||
|
||||
The CopyOnWriteTracker class maintains a mapping of source block indices to
|
||||
their corresponding copy-on-write destination block indices. It works in
|
||||
conjunction with a RefCounter.
|
||||
|
||||
Args:
|
||||
refcounter (RefCounter): The reference counter used to track block
|
||||
reference counts.
|
||||
"""
|
||||
|
||||
def __init__(self, refcounter: RefCounterProtocol):
|
||||
self._copy_on_writes: List[Tuple[BlockId, BlockId]] = []
|
||||
self._refcounter = refcounter
|
||||
|
||||
def is_appendable(self, block: Block) -> bool:
|
||||
"""Checks if the block is shared or not. If shared, then it cannot
|
||||
be appended and needs to be duplicated via copy-on-write
|
||||
"""
|
||||
block_id = block.block_id
|
||||
if block_id is None:
|
||||
return True
|
||||
|
||||
refcount = self._refcounter.get(block_id)
|
||||
return refcount <= 1
|
||||
|
||||
def record_cow(self, src_block_id: Optional[BlockId],
|
||||
trg_block_id: Optional[BlockId]) -> None:
|
||||
"""Records a copy-on-write operation from source to target block id
|
||||
Args:
|
||||
src_block_id (BlockId): The source block id from which to copy
|
||||
the data
|
||||
trg_block_id (BlockId): The target block id to which the data
|
||||
is copied
|
||||
"""
|
||||
assert src_block_id is not None
|
||||
assert trg_block_id is not None
|
||||
self._copy_on_writes.append((src_block_id, trg_block_id))
|
||||
|
||||
def clear_cows(self) -> List[Tuple[BlockId, BlockId]]:
|
||||
"""Clears the copy-on-write tracking information and returns the current
|
||||
state.
|
||||
|
||||
This method returns a list mapping source block indices to
|
||||
destination block indices for the current copy-on-write operations.
|
||||
It then clears the internal tracking information.
|
||||
|
||||
Returns:
|
||||
List[Tuple[BlockId, BlockId]]: A list mapping source
|
||||
block indices to destination block indices for the
|
||||
current copy-on-write operations.
|
||||
"""
|
||||
cows = self._copy_on_writes
|
||||
self._copy_on_writes = []
|
||||
return cows
|
||||
|
||||
|
||||
class BlockPool:
|
||||
"""Used to pre-allocate block objects, in order to avoid excessive python
|
||||
object allocations/deallocations.
|
||||
The pool starts from "pool_size" objects and will increase to more objects
|
||||
if necessary
|
||||
|
||||
Note that multiple block objects may point to the same physical block id,
|
||||
which is why this pool is needed, so that it will be easier to support
|
||||
prefix caching and more complicated sharing of physical blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, block_size: int, create_block: Block.Factory,
|
||||
allocator: BlockAllocator, pool_size: int):
|
||||
self._block_size = block_size
|
||||
self._create_block = create_block
|
||||
self._allocator = allocator
|
||||
self._pool_size = pool_size
|
||||
assert self._pool_size >= 0
|
||||
|
||||
self._free_ids: Deque[int] = deque(range(self._pool_size))
|
||||
self._pool = []
|
||||
for i in range(self._pool_size):
|
||||
self._pool.append(
|
||||
self._create_block(prev_block=None,
|
||||
token_ids=[],
|
||||
block_size=self._block_size,
|
||||
allocator=self._allocator,
|
||||
block_id=None,
|
||||
extra_hash=None))
|
||||
|
||||
def increase_pool(self):
|
||||
"""Doubles the internal pool size
|
||||
"""
|
||||
cur_pool_size = self._pool_size
|
||||
new_pool_size = cur_pool_size * 2
|
||||
self._pool_size = new_pool_size
|
||||
|
||||
self._free_ids += deque(range(cur_pool_size, new_pool_size))
|
||||
|
||||
for i in range(cur_pool_size, new_pool_size):
|
||||
self._pool.append(
|
||||
self._create_block(prev_block=None,
|
||||
token_ids=[],
|
||||
block_size=self._block_size,
|
||||
allocator=self._allocator,
|
||||
block_id=None,
|
||||
extra_hash=None))
|
||||
|
||||
def init_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
block_size: int,
|
||||
physical_block_id: Optional[int],
|
||||
extra_hash: Optional[int] = None) -> Block:
|
||||
if len(self._free_ids) == 0:
|
||||
self.increase_pool()
|
||||
assert len(self._free_ids) > 0
|
||||
|
||||
pool_id = self._free_ids.popleft()
|
||||
|
||||
block = self._pool[pool_id]
|
||||
block.__init__( # type: ignore[misc]
|
||||
prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
allocator=block._allocator, # type: ignore[attr-defined]
|
||||
block_id=physical_block_id,
|
||||
extra_hash=extra_hash)
|
||||
block.pool_id = pool_id # type: ignore[attr-defined]
|
||||
return block
|
||||
|
||||
def free_block(self, block: Block) -> None:
|
||||
self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class BlockList:
|
||||
"""This class is an optimization to allow fast-access to physical
|
||||
block ids. It maintains a block id list that is updated with the
|
||||
block list and this avoids the need to reconstruct the block id
|
||||
list on every iteration of the block manager
|
||||
"""
|
||||
|
||||
def __init__(self, blocks: List[Block]):
|
||||
self._blocks: List[Block] = []
|
||||
self._block_ids: List[int] = []
|
||||
|
||||
self.update(blocks)
|
||||
|
||||
def _add_block_id(self, block_id: Optional[BlockId]) -> None:
|
||||
assert block_id is not None
|
||||
self._block_ids.append(block_id)
|
||||
|
||||
def _update_block_id(self, block_index: int,
|
||||
new_block_id: Optional[BlockId]) -> None:
|
||||
assert new_block_id is not None
|
||||
self._block_ids[block_index] = new_block_id
|
||||
|
||||
def update(self, blocks: List[Block]):
|
||||
self._blocks = blocks
|
||||
|
||||
# Cache block ids for fast query
|
||||
self._block_ids = []
|
||||
for block in self._blocks:
|
||||
self._add_block_id(block.block_id)
|
||||
|
||||
def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
|
||||
block = self._blocks[block_index]
|
||||
prev_block_id = block.block_id
|
||||
|
||||
block.append_token_ids(token_ids)
|
||||
|
||||
# CoW or promotion may update the internal block_id
|
||||
if prev_block_id != block.block_id:
|
||||
self._update_block_id(block_index, block.block_id)
|
||||
|
||||
def append(self, new_block: Block):
|
||||
self._blocks.append(new_block)
|
||||
self._add_block_id(new_block.block_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._blocks)
|
||||
|
||||
def __getitem__(self, block_index: int) -> Block:
|
||||
return self._blocks[block_index]
|
||||
|
||||
def __setitem__(self, block_index: int, new_block: Block) -> None:
|
||||
self._blocks[block_index] = new_block
|
||||
self._update_block_id(block_index, new_block.block_id)
|
||||
|
||||
def reset(self):
|
||||
self._blocks = []
|
||||
self._block_ids = []
|
||||
|
||||
def list(self) -> List[Block]:
|
||||
return self._blocks
|
||||
|
||||
def ids(self) -> List[int]:
|
||||
return self._block_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheMetricData:
|
||||
"""A utility dataclass to maintain cache metric.
|
||||
To avoid overflow, we maintain the hit rate in block granularity, so that
|
||||
we can maintain a single hit rate for n_completed_block x block_size,
|
||||
and calculate the real time hit rate by the following:
|
||||
BS = The number of queries per block.
|
||||
nB = The number of completed blocks.
|
||||
HR = hit rate of (nB x BS) queries.
|
||||
Q = current number of queries (< BS).
|
||||
H = current number of hits (< BS).
|
||||
hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
|
||||
"""
|
||||
num_completed_blocks: int = 0
|
||||
completed_block_cache_hit_rate: float = 0.0
|
||||
num_incompleted_block_queries: int = 0
|
||||
num_incompleted_block_hit: int = 0
|
||||
block_size: int = 1000
|
||||
|
||||
def query(self, hit: bool):
|
||||
self.num_incompleted_block_queries += 1
|
||||
self.num_incompleted_block_hit += 1 if hit else 0
|
||||
|
||||
# When a block is completed, update the cache hit rate
|
||||
# and reset the incomplete numbers.
|
||||
if self.num_incompleted_block_queries == self.block_size:
|
||||
hit_rate = (self.num_incompleted_block_hit /
|
||||
self.num_incompleted_block_queries)
|
||||
self.completed_block_cache_hit_rate = (
|
||||
self.completed_block_cache_hit_rate * self.num_completed_blocks
|
||||
+ hit_rate) / (self.num_completed_blocks + 1)
|
||||
self.num_incompleted_block_queries = 0
|
||||
self.num_incompleted_block_hit = 0
|
||||
self.num_completed_blocks += 1
|
||||
|
||||
def get_hit_rate(self):
|
||||
incomplete_ratio = self.num_incompleted_block_queries / self.block_size
|
||||
total_blocks = self.num_completed_blocks + incomplete_ratio
|
||||
if total_blocks == 0:
|
||||
return 0.0
|
||||
|
||||
completed_block_hit, incompleted_block_hit = 0.0, 0.0
|
||||
if self.num_completed_blocks > 0:
|
||||
completed_block_hit = (self.completed_block_cache_hit_rate *
|
||||
self.num_completed_blocks)
|
||||
if self.num_incompleted_block_queries > 0:
|
||||
incompleted_hit_rate = (self.num_incompleted_block_hit /
|
||||
self.num_incompleted_block_queries)
|
||||
incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
|
||||
return (completed_block_hit + incompleted_block_hit) / total_blocks
|
||||
|
||||
|
||||
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
|
||||
"""Retrieves all the blocks in a sequence starting from the last block.
|
||||
|
||||
This function recursively traverses the sequence of blocks in reverse order,
|
||||
starting from the given last block, and returns a list of all the blocks in
|
||||
the sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: A list of all the blocks in the sequence, in the order they
|
||||
appear.
|
||||
"""
|
||||
|
||||
def recurse(block: Block, lst: List[Block]) -> None:
|
||||
if block.prev_block is not None:
|
||||
recurse(block.prev_block, lst)
|
||||
lst.append(block)
|
||||
|
||||
all_blocks: List[Block] = []
|
||||
recurse(last_block, all_blocks)
|
||||
return all_blocks
|
||||
@ -1,439 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Dict, FrozenSet, List, Optional, Tuple
|
||||
|
||||
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
|
||||
DeviceAwareBlockAllocator)
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
"""A block allocator that can allocate blocks on both CPU and GPU memory.
|
||||
|
||||
This class implements the `DeviceAwareBlockAllocator` interface and provides
|
||||
functionality for allocating and managing blocks of memory on both CPU and
|
||||
GPU devices.
|
||||
|
||||
The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
|
||||
blocks, and allows for allocation, deallocation, forking, and swapping of
|
||||
blocks across these memory pools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
allocator_type: str,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
) -> DeviceAwareBlockAllocator:
|
||||
"""Creates a CpuGpuBlockAllocator instance with the specified
|
||||
configuration.
|
||||
|
||||
This static method creates and returns a CpuGpuBlockAllocator instance
|
||||
based on the provided parameters. It initializes the CPU and GPU block
|
||||
allocators with the specified number of blocks, block size, and
|
||||
allocator type.
|
||||
|
||||
Args:
|
||||
allocator_type (str): The type of block allocator to use for CPU
|
||||
and GPU blocks. Currently supported values are "naive" and
|
||||
"prefix_caching".
|
||||
num_gpu_blocks (int): The number of blocks to allocate for GPU
|
||||
memory.
|
||||
num_cpu_blocks (int): The number of blocks to allocate for CPU
|
||||
memory.
|
||||
block_size (int): The size of each block in number of tokens.
|
||||
|
||||
Returns:
|
||||
DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
|
||||
specified configuration.
|
||||
|
||||
Notes:
|
||||
- The block IDs are assigned contiguously, with GPU block IDs coming
|
||||
before CPU block IDs.
|
||||
"""
|
||||
reserved_blocks = 0
|
||||
block_ids = list(
|
||||
range(reserved_blocks, num_gpu_blocks + num_cpu_blocks))
|
||||
num_gpu_blocks -= reserved_blocks
|
||||
gpu_block_ids = block_ids[:num_gpu_blocks]
|
||||
cpu_block_ids = block_ids[num_gpu_blocks:]
|
||||
|
||||
if allocator_type == "naive":
|
||||
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
|
||||
create_block=NaiveBlock, # type: ignore
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=gpu_block_ids,
|
||||
)
|
||||
|
||||
cpu_allocator: BlockAllocator = NaiveBlockAllocator(
|
||||
create_block=NaiveBlock, # type: ignore
|
||||
num_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=cpu_block_ids,
|
||||
)
|
||||
elif allocator_type == "prefix_caching":
|
||||
gpu_allocator = PrefixCachingBlockAllocator(
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=gpu_block_ids,
|
||||
)
|
||||
|
||||
cpu_allocator = PrefixCachingBlockAllocator(
|
||||
num_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=cpu_block_ids,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown allocator type {allocator_type=}")
|
||||
|
||||
return CpuGpuBlockAllocator(
|
||||
cpu_block_allocator=cpu_allocator,
|
||||
gpu_block_allocator=gpu_allocator,
|
||||
)
|
||||
|
||||
def __init__(self, cpu_block_allocator: BlockAllocator,
|
||||
gpu_block_allocator: BlockAllocator):
|
||||
assert not (
|
||||
cpu_block_allocator.all_block_ids
|
||||
& gpu_block_allocator.all_block_ids
|
||||
), "cpu and gpu block allocators can't have intersection of block ids"
|
||||
|
||||
self._allocators = {
|
||||
Device.CPU: cpu_block_allocator,
|
||||
Device.GPU: gpu_block_allocator,
|
||||
}
|
||||
|
||||
self._swap_mapping: Dict[int, int] = {}
|
||||
self._null_block: Optional[Block] = None
|
||||
|
||||
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
|
||||
for _, allocator in self._allocators.items():
|
||||
for block_id in allocator.all_block_ids:
|
||||
self._block_ids_to_allocator[block_id] = allocator
|
||||
|
||||
def allocate_or_get_null_block(self) -> Block:
|
||||
if self._null_block is None:
|
||||
self._null_block = NullBlock(
|
||||
self.allocate_mutable_block(None, Device.GPU))
|
||||
return self._null_block
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> Block:
|
||||
"""Allocates a new mutable block on the specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block to in the sequence.
|
||||
Used for prefix hashing.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated mutable block.
|
||||
"""
|
||||
return self._allocators[device].allocate_mutable_block(
|
||||
prev_block, extra_hash=extra_hash)
|
||||
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> List[Block]:
|
||||
"""Allocates a new group of immutable blocks with the provided block
|
||||
token IDs on the specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
block_token_ids (List[int]): The list of block token IDs to be
|
||||
stored in the new blocks.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
List[Block]: The newly allocated list of immutable blocks
|
||||
containing the provided block token IDs.
|
||||
"""
|
||||
return self._allocators[device].allocate_immutable_blocks(
|
||||
prev_block, block_token_ids, extra_hash=extra_hash)
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> Block:
|
||||
"""Allocates a new immutable block with the provided token IDs on the
|
||||
specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
token_ids (List[int]): The list of token IDs to be stored in the new
|
||||
block.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated immutable block containing the provided
|
||||
token IDs.
|
||||
"""
|
||||
return self._allocators[device].allocate_immutable_block(
|
||||
prev_block, token_ids, extra_hash=extra_hash)
|
||||
|
||||
def free(self, block: Block) -> None:
|
||||
"""Frees the memory occupied by the given block.
|
||||
|
||||
Args:
|
||||
block (Block): The block to be freed.
|
||||
"""
|
||||
# Null block should never be freed
|
||||
if isinstance(block, NullBlock):
|
||||
return
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
allocator = self._block_ids_to_allocator[block_id]
|
||||
allocator.free(block)
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: A new list of blocks that shares the same memory as the
|
||||
original sequence.
|
||||
"""
|
||||
# do not attempt to fork the null block
|
||||
assert not isinstance(last_block, NullBlock)
|
||||
block_id = last_block.block_id
|
||||
assert block_id is not None
|
||||
allocator = self._block_ids_to_allocator[block_id]
|
||||
return allocator.fork(last_block)
|
||||
|
||||
def get_num_free_blocks(self, device: Device) -> int:
|
||||
"""Returns the number of free blocks available on the specified device.
|
||||
|
||||
Args:
|
||||
device (Device): The device for which to query the number of free
|
||||
blocks. AssertionError is raised if None is passed.
|
||||
|
||||
Returns:
|
||||
int: The number of free blocks available on the specified device.
|
||||
"""
|
||||
return self._allocators[device].get_num_free_blocks()
|
||||
|
||||
def get_num_total_blocks(self, device: Device) -> int:
|
||||
return self._allocators[device].get_num_total_blocks()
|
||||
|
||||
def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain device given the
|
||||
absolute block id.
|
||||
|
||||
Args:
|
||||
device (Device): The device for which to query relative block id.
|
||||
absolute_id (int): The absolute block id for the block in
|
||||
whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on certain device.
|
||||
"""
|
||||
return self._allocators[device].get_physical_block_id(absolute_id)
|
||||
|
||||
def swap(self, blocks: List[Block], src_device: Device,
|
||||
dst_device: Device) -> Dict[int, int]:
|
||||
"""Execute the swap for the given blocks from source_device
|
||||
on to dest_device, save the current swap mapping and append
|
||||
them to the accumulated `self._swap_mapping` for each
|
||||
scheduling move.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
src_device (Device): Device to swap the 'blocks' from.
|
||||
dst_device (Device): Device to swap the 'blocks' to.
|
||||
|
||||
Returns:
|
||||
Dict[int, int]: Swap mapping from source_device
|
||||
on to dest_device.
|
||||
"""
|
||||
src_block_ids = [block.block_id for block in blocks]
|
||||
self._allocators[src_device].swap_out(blocks)
|
||||
self._allocators[dst_device].swap_in(blocks)
|
||||
dst_block_ids = [block.block_id for block in blocks]
|
||||
|
||||
current_swap_mapping: Dict[int, int] = {}
|
||||
for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
|
||||
if src_block_id is not None and dst_block_id is not None:
|
||||
self._swap_mapping[src_block_id] = dst_block_id
|
||||
current_swap_mapping[src_block_id] = dst_block_id
|
||||
return current_swap_mapping
|
||||
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block],
|
||||
device: Device) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks on to the 'device'.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
device (Device): Device to swap the 'blocks' on.
|
||||
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks on to the 'device'.
|
||||
Non full blocks are ignored when deciding the number
|
||||
of blocks to touch.
|
||||
"""
|
||||
return self._allocators[device].get_num_full_blocks_touched(blocks)
|
||||
|
||||
def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
|
||||
"""Clears the copy-on-write (CoW) state and returns the mapping of
|
||||
source to destination block IDs.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: A list mapping source block IDs to
|
||||
destination block IDs.
|
||||
"""
|
||||
# CoW only supported on GPU
|
||||
device = Device.GPU
|
||||
return self._allocators[device].clear_copy_on_writes()
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, only use for prefix caching."""
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
"""Mark blocks as accessed, only use for prefix caching."""
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].mark_blocks_as_computed(block_ids)
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].get_common_computed_block_ids(
|
||||
computed_seq_block_ids)
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return frozenset(self._block_ids_to_allocator.keys())
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
assert device in self._allocators
|
||||
return self._allocators[device].get_prefix_cache_hit_rate()
|
||||
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
"""Reset prefix cache for specified or all devices."""
|
||||
if device:
|
||||
return self._allocators[device].reset_prefix_cache()
|
||||
success = True
|
||||
for allocator in self._allocators.values():
|
||||
success = success and allocator.reset_prefix_cache()
|
||||
return success
|
||||
|
||||
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
|
||||
"""Returns and clears the mapping of source to destination block IDs.
|
||||
Will be called after every swapping operations for now, and after every
|
||||
schedule when BlockManagerV2 become default. Currently not useful.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: A mapping of source to destination block IDs.
|
||||
"""
|
||||
mapping = self._swap_mapping.copy()
|
||||
self._swap_mapping.clear()
|
||||
return list(mapping.items())
|
||||
|
||||
def find_cached_blocks_prefix(
|
||||
self,
|
||||
block_hashes: List[int],
|
||||
device: Device = Device.GPU,
|
||||
) -> List[int]:
|
||||
return self._allocators[device].find_cached_blocks_prefix(block_hashes)
|
||||
|
||||
|
||||
class NullBlock(Block):
|
||||
"""
|
||||
Null blocks are used as a placeholders for KV cache blocks that have
|
||||
been dropped due to sliding window.
|
||||
This implementation just wraps an ordinary block and prevents it from
|
||||
being modified. It also allows for testing if a block is NullBlock
|
||||
via isinstance().
|
||||
"""
|
||||
|
||||
def __init__(self, proxy: Block):
|
||||
super().__init__()
|
||||
self._proxy = proxy
|
||||
|
||||
def append_token_ids(self, token_ids: List[BlockId]):
|
||||
raise ValueError("null block should not be modified")
|
||||
|
||||
@property
|
||||
def block_id(self):
|
||||
return self._proxy.block_id
|
||||
|
||||
@block_id.setter
|
||||
def block_id(self, value: Optional[BlockId]):
|
||||
raise ValueError("null block should not be modified")
|
||||
|
||||
@property
|
||||
def token_ids(self) -> List[BlockId]:
|
||||
return self._proxy.token_ids
|
||||
|
||||
@property
|
||||
def num_tokens_total(self) -> int:
|
||||
raise NotImplementedError(
|
||||
"num_tokens_total is not used for null block")
|
||||
|
||||
@property
|
||||
def num_empty_slots(self) -> BlockId:
|
||||
return self._proxy.num_empty_slots
|
||||
|
||||
@property
|
||||
def is_full(self):
|
||||
return self._proxy.is_full
|
||||
|
||||
@property
|
||||
def prev_block(self):
|
||||
return self._proxy.prev_block
|
||||
|
||||
@property
|
||||
def extra_hash(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def computed(self):
|
||||
return self._proxy.computed
|
||||
|
||||
@computed.setter
|
||||
def computed(self, value):
|
||||
self._proxy.computed = value
|
||||
|
||||
@property
|
||||
def last_accessed(self) -> float:
|
||||
return self._proxy.last_accessed
|
||||
|
||||
@last_accessed.setter
|
||||
def last_accessed(self, last_accessed_ts: float):
|
||||
self._proxy.last_accessed = last_accessed_ts
|
||||
|
||||
@property
|
||||
def content_hash(self):
|
||||
return self._proxy.content_hash
|
||||
@ -1,319 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple
|
||||
|
||||
from vllm.utils import Device
|
||||
|
||||
BlockId = int
|
||||
|
||||
|
||||
class Block(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def append_token_ids(self, token_ids: List[int]) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def block_id(self) -> Optional[int]:
|
||||
pass
|
||||
|
||||
@block_id.setter
|
||||
@abstractmethod
|
||||
def block_id(self, value: Optional[int]) -> None:
|
||||
"""NOTE: Do not use this API outside Block."""
|
||||
self._block_id = value
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def token_ids(self) -> List[int]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def num_tokens_total(self) -> int:
|
||||
"""The number of tokens till the current block (inclusive)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def num_empty_slots(self) -> int:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_full(self) -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def prev_block(self) -> Optional["Block"]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def extra_hash(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def computed(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@computed.setter
|
||||
@abstractmethod
|
||||
def computed(self, value) -> bool:
|
||||
"""Should be only used by PrefixCacingAllocator"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def last_accessed(self) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
@last_accessed.setter
|
||||
@abstractmethod
|
||||
def last_accessed(self, last_accessed_ts: float):
|
||||
raise NotImplementedError
|
||||
|
||||
class Factory(Protocol):
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
prev_block: Optional["Block"],
|
||||
token_ids: List[int],
|
||||
block_size: int,
|
||||
allocator: "BlockAllocator",
|
||||
block_id: Optional[int] = None,
|
||||
computed: bool = False,
|
||||
extra_hash: Optional[int] = None,
|
||||
) -> "Block":
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def content_hash(self) -> Optional[int]:
|
||||
"""Return the content-based hash of the current block, or None if it is
|
||||
not yet defined or not supported.
|
||||
|
||||
For the content-based hash to be defined, the current block must be
|
||||
full.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
class BlockAllocator(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def allocate_mutable_block(self, prev_block: Optional[Block],
|
||||
extra_hash: Optional[int]) -> Block:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_immutable_block(self, prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
extra_hash: Optional[int]) -> Block:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_immutable_blocks(self, prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
extra_hash: Optional[int]) -> List[Block]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def free(self, block: Block) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_total_blocks(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_free_blocks(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_physical_block_id(self, absolute_id: int) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def swap_out(self, blocks: List[Block]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def swap_in(self, blocks: List[Block]) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
|
||||
"""NOTE: This should not be used besides Block"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def promote_to_immutable_block(self, block: Block) -> BlockId:
|
||||
"""NOTE: This should not be used besides Block"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache."""
|
||||
pass
|
||||
|
||||
class NoFreeBlocksError(ValueError):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def find_cached_blocks_prefix(
|
||||
self,
|
||||
block_hashes: List[int],
|
||||
) -> List[int]:
|
||||
pass
|
||||
|
||||
|
||||
class DeviceAwareBlockAllocator(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> Block:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> Block:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None,
|
||||
) -> List[Block]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_free_blocks(self, device: Device) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_total_blocks(self, device: Device) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def free(self, block: Block) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block],
|
||||
device: Device) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def swap(self, blocks: List[Block], src_device: Device,
|
||||
dst_device: Device) -> Dict[int, int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_or_get_null_block(self) -> Block:
|
||||
"""
|
||||
Null blocks are used as a placeholders for KV cache blocks that have
|
||||
been dropped due to sliding window.
|
||||
There is at most one null block per allocator.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
"""Reset prefix cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def find_cached_blocks_prefix(
|
||||
self,
|
||||
block_hashes: List[int],
|
||||
device: Device = Device.GPU,
|
||||
) -> List[int]:
|
||||
pass
|
||||
@ -1,466 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import deque
|
||||
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
|
||||
|
||||
Refcount = int
|
||||
|
||||
|
||||
class NaiveBlockAllocator(BlockAllocator):
|
||||
"""A simple block allocator that manages blocks of memory without prefix
|
||||
caching.
|
||||
|
||||
Args:
|
||||
create_block (Block.Factory): A factory function for creating new
|
||||
blocks. This is used when a NaiveBlockAllocator is composed within
|
||||
a prefix caching allocator -- the naive block allocator must
|
||||
construct prefix caching blocks (but shouldn't know anything else
|
||||
about them).
|
||||
num_blocks (int): The total number of blocks to manage.
|
||||
block_size (int): The size of each block in tokens.
|
||||
block_ids (Optional[Iterable[int]], optional): An optional iterable of
|
||||
block IDs. If not provided, block IDs will be assigned sequentially
|
||||
from 0 to num_blocks - 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
create_block: Block.Factory,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
block_ids: Optional[Iterable[int]] = None,
|
||||
block_pool: Optional[BlockPool] = None,
|
||||
):
|
||||
if block_ids is None:
|
||||
block_ids = range(num_blocks)
|
||||
|
||||
self._free_block_indices: Deque[BlockId] = deque(block_ids)
|
||||
self._all_block_indices = frozenset(block_ids)
|
||||
assert len(self._all_block_indices) == num_blocks
|
||||
|
||||
self._refcounter = RefCounter(
|
||||
all_block_indices=self._free_block_indices)
|
||||
self._block_size = block_size
|
||||
|
||||
self._cow_tracker = CopyOnWriteTracker(
|
||||
refcounter=self._refcounter.as_readonly())
|
||||
|
||||
if block_pool is None:
|
||||
extra_factor = 4
|
||||
# Pre-allocate "num_blocks * extra_factor" block objects.
|
||||
# The "* extra_factor" is a buffer to allow more block objects
|
||||
# than physical blocks
|
||||
self._block_pool = BlockPool(self._block_size, create_block, self,
|
||||
num_blocks * extra_factor)
|
||||
else:
|
||||
# In this case, the block pool is provided by the caller,
|
||||
# which means that there is most likely a need to share
|
||||
# a block pool between allocators
|
||||
self._block_pool = block_pool
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None) -> Block:
|
||||
"""Allocates a new immutable block with the given token IDs, linked to
|
||||
the previous block.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence. If
|
||||
None, then the block to be allocated is the first block in the
|
||||
sequence.
|
||||
token_ids (List[int]): The token IDs to be stored in the new block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated immutable block.
|
||||
"""
|
||||
assert device is None
|
||||
block = self.allocate_mutable_block(prev_block=prev_block)
|
||||
block.append_token_ids(token_ids)
|
||||
return block
|
||||
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None) -> List[Block]:
|
||||
assert device is None
|
||||
num_blocks = len(block_token_ids)
|
||||
|
||||
block_ids = []
|
||||
for i in range(num_blocks):
|
||||
block_ids.append(self._allocate_block_id())
|
||||
|
||||
blocks = []
|
||||
for i in range(num_blocks):
|
||||
prev_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block_token_ids[i],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_ids[i])
|
||||
blocks.append(prev_block)
|
||||
|
||||
return blocks
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None) -> Block:
|
||||
"""Allocates a new mutable block, linked to the previous block.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence. If
|
||||
None, then the block to be allocated is the first block in the
|
||||
sequence.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated mutable block.
|
||||
"""
|
||||
assert device is None
|
||||
block_id = self._allocate_block_id()
|
||||
block = self._block_pool.init_block(prev_block=prev_block,
|
||||
token_ids=[],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_id)
|
||||
return block
|
||||
|
||||
def _allocate_block_id(self) -> BlockId:
|
||||
if not self._free_block_indices:
|
||||
raise BlockAllocator.NoFreeBlocksError()
|
||||
|
||||
block_id = self._free_block_indices.popleft()
|
||||
self._refcounter.incr(block_id)
|
||||
return block_id
|
||||
|
||||
def _free_block_id(self, block: Union[Block, BlockId]) -> None:
|
||||
if isinstance(block, Block):
|
||||
block_id = block.block_id
|
||||
block.block_id = None
|
||||
else:
|
||||
block_id = block
|
||||
assert block_id is not None
|
||||
|
||||
refcount = self._refcounter.decr(block_id)
|
||||
if refcount == 0:
|
||||
self._free_block_indices.appendleft(block_id)
|
||||
|
||||
def free(self, block: Block, keep_block_object: bool = False) -> None:
|
||||
# Release the physical block id
|
||||
self._free_block_id(block)
|
||||
|
||||
# Release the block object
|
||||
if not keep_block_object:
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
def free_block_id(self, block_id: BlockId) -> None:
|
||||
self._free_block_id(block_id)
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: The new sequence of blocks that shares the same memory
|
||||
as the original sequence.
|
||||
"""
|
||||
source_blocks = get_all_blocks_recursively(last_block)
|
||||
|
||||
forked_blocks: List[Block] = []
|
||||
prev_block = None
|
||||
for block in source_blocks:
|
||||
|
||||
# Increment refcount for each block.
|
||||
assert block.block_id is not None
|
||||
refcount = self._refcounter.incr(block.block_id)
|
||||
assert refcount != 1, "can't fork freed block"
|
||||
|
||||
forked_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block.token_ids,
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block.block_id)
|
||||
|
||||
forked_blocks.append(forked_block)
|
||||
prev_block = forked_blocks[-1]
|
||||
|
||||
return forked_blocks
|
||||
|
||||
def get_num_free_blocks(self) -> int:
|
||||
return len(self._free_block_indices)
|
||||
|
||||
def get_num_total_blocks(self) -> int:
|
||||
return len(self._all_block_indices)
|
||||
|
||||
def get_physical_block_id(self, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain block allocator
|
||||
given the absolute block id.
|
||||
|
||||
Args:
|
||||
absolute_id (int): The absolute block id for the block
|
||||
in whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on certain device.
|
||||
"""
|
||||
return sorted(self._all_block_indices).index(absolute_id)
|
||||
|
||||
@property
|
||||
def refcounter(self):
|
||||
return self._refcounter
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return self._all_block_indices
|
||||
|
||||
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
|
||||
"""Performs a copy-on-write operation on the given block if it is not
|
||||
appendable.
|
||||
|
||||
Args:
|
||||
block (Block): The block to check for copy-on-write.
|
||||
|
||||
Returns:
|
||||
BlockId: The block index of the new block if a copy-on-write
|
||||
operation was performed, or the original block index if
|
||||
no copy-on-write was necessary.
|
||||
"""
|
||||
src_block_id = block.block_id
|
||||
assert src_block_id is not None
|
||||
|
||||
if self._cow_tracker.is_appendable(block):
|
||||
return src_block_id
|
||||
|
||||
self._free_block_id(block)
|
||||
trg_block_id = self._allocate_block_id()
|
||||
|
||||
self._cow_tracker.record_cow(src_block_id, trg_block_id)
|
||||
|
||||
return trg_block_id
|
||||
|
||||
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
|
||||
"""Returns the copy-on-write source->destination mapping and clears it.
|
||||
|
||||
Returns:
|
||||
List[Tuple[BlockId, BlockId]]: A list mapping source
|
||||
block indices to destination block indices.
|
||||
"""
|
||||
return self._cow_tracker.clear_cows()
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, used in prefix caching.
|
||||
|
||||
Since the naive allocator does not implement prefix caching, we do
|
||||
nothing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
"""Mark blocks as computed, used in prefix caching.
|
||||
|
||||
Since the naive allocator does not implement prefix caching, we do
|
||||
nothing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
"""Determine blocks that can be skipped in prefill.
|
||||
|
||||
Since the naive allocator does not support prefix caching, always return
|
||||
an empty list.
|
||||
"""
|
||||
return []
|
||||
|
||||
def promote_to_immutable_block(self, block: Block) -> BlockId:
|
||||
raise NotImplementedError("There is no promotion for naive blocks")
|
||||
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks. Non full blocks are ignored
|
||||
when deciding the number of blocks to touch.
|
||||
"""
|
||||
# NOTE: for naive block, we use set to eliminate common blocks among
|
||||
# seqs, also we compare the empty slots in the mutable blocks with
|
||||
# lookahead slots to get the number of unique new block that are
|
||||
# needed.
|
||||
old_block_set = set()
|
||||
for block in blocks:
|
||||
if block.is_full:
|
||||
old_block_set.add(block)
|
||||
return len(old_block_set)
|
||||
|
||||
def swap_out(self, blocks: List[Block]) -> None:
|
||||
for block in blocks:
|
||||
self._free_block_id(block)
|
||||
|
||||
def swap_in(self, blocks: List[Block]) -> None:
|
||||
for block in blocks:
|
||||
# Here we allocate either immutable or mutable block and then
|
||||
# extract its block_id. Note that the block object is released
|
||||
# and the block_id is assigned to "block" to allow reusing the
|
||||
# existing "block" object
|
||||
if block.is_full:
|
||||
tmp_block = self.allocate_immutable_block(
|
||||
prev_block=block.prev_block, token_ids=block.token_ids)
|
||||
else:
|
||||
tmp_block = self.allocate_mutable_block(
|
||||
prev_block=block.prev_block)
|
||||
tmp_block.append_token_ids(block.token_ids)
|
||||
|
||||
block_id = tmp_block.block_id
|
||||
tmp_block.block_id = None
|
||||
self._block_pool.free_block(tmp_block)
|
||||
|
||||
block.block_id = block_id # Assign block_id
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
return -1
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""No prefix cache for naive block allocator."""
|
||||
return True
|
||||
|
||||
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
|
||||
# Not applicable for naive block allocator.
|
||||
return []
|
||||
|
||||
|
||||
class NaiveBlock(Block):
|
||||
"""An implementation of the Block class that does not support prefix
|
||||
caching.
|
||||
|
||||
The NaiveBlock class represents a block of token IDs with a fixed size. It
|
||||
provides methods for appending token IDs to the block and manages copy-on
|
||||
-write operations when necessary.
|
||||
|
||||
Args:
|
||||
prev_block (Block): The previous block in the sequence.
|
||||
token_ids (List[int]): The initial token IDs to be stored in the block.
|
||||
block_size (int): The maximum number of token IDs that can be stored in
|
||||
the block.
|
||||
allocator (BlockAllocator): The block allocator associated with this
|
||||
block.
|
||||
block_id (Optional[int], optional): The physical block index
|
||||
of this block. Defaults to None, which means no allocation has been
|
||||
made.
|
||||
_cow_target (Optional[Block], optional): The copy-on-write target block.
|
||||
If not provided, it defaults to self.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
block_size: int,
|
||||
allocator: BlockAllocator,
|
||||
block_id: Optional[int] = None,
|
||||
_cow_target: Optional[Block] = None,
|
||||
extra_hash: Optional[int] = None):
|
||||
self._token_ids: List[int] = []
|
||||
self._block_size = block_size
|
||||
self._prev_block = prev_block
|
||||
self._block_id = block_id
|
||||
self._allocator = allocator
|
||||
self._cow_target = _cow_target if _cow_target is not None else self
|
||||
|
||||
self._append_token_ids_no_cow(token_ids)
|
||||
|
||||
def append_token_ids(self, token_ids: List[int]) -> None:
|
||||
"""Appends the given token IDs to the block and performs a
|
||||
copy-on-write if necessary.
|
||||
|
||||
Args:
|
||||
token_ids (Optional[List[int]]): The token IDs to be appended
|
||||
to the block.
|
||||
"""
|
||||
self._append_token_ids_no_cow(token_ids)
|
||||
|
||||
if self._block_id is not None:
|
||||
self._block_id = (self._allocator.cow_block_if_not_appendable(
|
||||
self._cow_target))
|
||||
|
||||
def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
|
||||
"""Appends the given token IDs to the block
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The token IDs to be appended to the block.
|
||||
"""
|
||||
if len(token_ids) == 0:
|
||||
return
|
||||
|
||||
assert len(token_ids) <= self.num_empty_slots
|
||||
|
||||
self._token_ids.extend(token_ids)
|
||||
|
||||
@property
|
||||
def computed(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@computed.setter
|
||||
def computed(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def last_accessed(self) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
@last_accessed.setter
|
||||
def last_accessed(self, last_accessed_ts: float):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def block_id(self) -> Optional[int]:
|
||||
return self._block_id
|
||||
|
||||
@block_id.setter
|
||||
def block_id(self, value: Optional[int]) -> None:
|
||||
self._block_id = value
|
||||
|
||||
@property
|
||||
def is_full(self) -> bool:
|
||||
return self.num_empty_slots == 0
|
||||
|
||||
@property
|
||||
def num_empty_slots(self) -> int:
|
||||
return self._block_size - len(self.token_ids)
|
||||
|
||||
@property
|
||||
def token_ids(self) -> List[int]:
|
||||
return self._token_ids
|
||||
|
||||
@property
|
||||
def num_tokens_total(self) -> int:
|
||||
raise NotImplementedError(
|
||||
"num_tokens_total is not used for naive block")
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self._block_size
|
||||
|
||||
@property
|
||||
def prev_block(self) -> Optional["Block"]:
|
||||
return self._prev_block
|
||||
|
||||
@property
|
||||
def extra_hash(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def content_hash(self) -> Optional[int]:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,28 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Block manager utils."""
|
||||
from vllm.sequence import SequenceGroup
|
||||
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
|
||||
STR_NOT_IMPL_ENC_DEC_SWA)
|
||||
|
||||
|
||||
def check_no_caching_or_swa_for_blockmgr_encdec(
|
||||
block_mgr, seq_group: SequenceGroup) -> None:
|
||||
'''
|
||||
Enforce that prefix caching & sliding-window attention (SWA)
|
||||
are currently unsupported *specifically* for encoder/decoder models.
|
||||
|
||||
Raises NotImplementedError if unsupported scenario is detected.
|
||||
|
||||
Arguments:
|
||||
|
||||
* block_mgr: BlockSpaceManager instance
|
||||
* seq_group: SequenceGroup passed to block_mgr
|
||||
'''
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
if block_mgr.max_block_sliding_window is not None:
|
||||
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA)
|
||||
|
||||
if block_mgr.enable_caching:
|
||||
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE)
|
||||
@ -1,523 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A block manager that manages token blocks."""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.core.block.interfaces import Block
|
||||
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
|
||||
LastAccessBlocksTracker)
|
||||
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
SeqId = int
|
||||
EncoderSeqId = str
|
||||
|
||||
|
||||
class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
"""BlockSpaceManager which manages the allocation of KV cache.
|
||||
|
||||
It owns responsibility for allocation, swapping, allocating memory for
|
||||
autoregressively-generated tokens, and other advanced features such as
|
||||
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
|
||||
|
||||
This class implements the design described in
|
||||
https://github.com/vllm-project/vllm/pull/3492.
|
||||
|
||||
Lookahead slots
|
||||
The block manager has the notion of a "lookahead slot". These are slots
|
||||
in the KV cache that are allocated for a sequence. Unlike the other
|
||||
allocated slots, the content of these slots is undefined -- the worker
|
||||
may use the memory allocations in any way.
|
||||
|
||||
In practice, a worker could use these lookahead slots to run multiple
|
||||
forward passes for a single scheduler invocation. Each successive
|
||||
forward pass would write KV activations to the corresponding lookahead
|
||||
slot. This allows low inter-token latency use-cases, where the overhead
|
||||
of continuous batching scheduling is amortized over >1 generated tokens.
|
||||
|
||||
Speculative decoding uses lookahead slots to store KV activations of
|
||||
proposal tokens.
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/3250 for more information
|
||||
on lookahead scheduling.
|
||||
|
||||
Args:
|
||||
block_size (int): The size of each memory block.
|
||||
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
|
||||
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
|
||||
watermark (float, optional): The threshold used for memory swapping.
|
||||
Defaults to 0.01.
|
||||
sliding_window (Optional[int], optional): The size of the sliding
|
||||
window. Defaults to None.
|
||||
enable_caching (bool, optional): Flag indicating whether caching is
|
||||
enabled. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
watermark: float = 0.01,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_caching: bool = False,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.num_total_gpu_blocks = num_gpu_blocks
|
||||
self.num_total_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
# max_block_sliding_window is the max number of blocks that need to be
|
||||
# allocated
|
||||
self.max_block_sliding_window = None
|
||||
if sliding_window is not None:
|
||||
# +1 here because // rounds down
|
||||
num_blocks = sliding_window // block_size + 1
|
||||
# +1 here because the last block may not be full,
|
||||
# and so the sequence stretches one more block at the beginning
|
||||
# For example, if sliding_window is 3 and block_size is 4,
|
||||
# we may need 2 blocks when the second block only holds 1 token.
|
||||
self.max_block_sliding_window = num_blocks + 1
|
||||
|
||||
self.watermark = watermark
|
||||
assert watermark >= 0.0
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
self.watermark_blocks = int(watermark * num_gpu_blocks)
|
||||
|
||||
self.block_allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type="prefix_caching" if enable_caching else "naive",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
self.block_tables: Dict[SeqId, BlockTable] = {}
|
||||
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
|
||||
|
||||
self._computed_blocks_tracker = ComputedBlocksTracker(
|
||||
self.block_allocator, self.block_size, self.enable_caching)
|
||||
self._last_access_blocks_tracker = LastAccessBlocksTracker(
|
||||
self.block_allocator)
|
||||
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
# FIXME(woosuk): Here we assume that all sequences in the group share
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
|
||||
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
||||
|
||||
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
||||
num_required_blocks = BlockTable.get_num_required_blocks(
|
||||
seq.get_token_ids(),
|
||||
block_size=self.block_size,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
)
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
encoder_seq = seq_group.get_encoder_seq()
|
||||
assert encoder_seq is not None
|
||||
num_required_blocks += BlockTable.get_num_required_blocks(
|
||||
encoder_seq.get_token_ids(),
|
||||
block_size=self.block_size,
|
||||
)
|
||||
|
||||
if self.max_block_sliding_window is not None:
|
||||
num_required_blocks = min(num_required_blocks,
|
||||
self.max_block_sliding_window)
|
||||
|
||||
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
||||
device=Device.GPU)
|
||||
|
||||
# Use watermark to avoid frequent cache eviction.
|
||||
if (self.num_total_gpu_blocks - num_required_blocks
|
||||
< self.watermark_blocks):
|
||||
return AllocStatus.NEVER
|
||||
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
|
||||
return AllocStatus.OK
|
||||
else:
|
||||
return AllocStatus.LATER
|
||||
|
||||
def _allocate_sequence(self, seq: Sequence) -> BlockTable:
|
||||
block_table = BlockTable(
|
||||
block_size=self.block_size,
|
||||
block_allocator=self.block_allocator,
|
||||
max_block_sliding_window=self.max_block_sliding_window,
|
||||
)
|
||||
if seq.get_token_ids():
|
||||
# NOTE: If there are any factors affecting the block besides
|
||||
# token_ids, they should be added as input to extra_hash.
|
||||
extra_hash = seq.extra_hash()
|
||||
|
||||
# Add blocks to the block table only if the sequence is non empty.
|
||||
block_table.allocate(token_ids=seq.get_token_ids(),
|
||||
extra_hash=extra_hash)
|
||||
|
||||
return block_table
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
|
||||
# Allocate self-attention block tables for decoder sequences
|
||||
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
||||
assert not (set(seq.seq_id for seq in waiting_seqs)
|
||||
& self.block_tables.keys()), "block table already exists"
|
||||
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# prompt.
|
||||
seq = waiting_seqs[0]
|
||||
block_table: BlockTable = self._allocate_sequence(seq)
|
||||
self.block_tables[seq.seq_id] = block_table
|
||||
|
||||
# Track seq
|
||||
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
||||
|
||||
# Assign the block table for each sequence.
|
||||
for seq in waiting_seqs[1:]:
|
||||
self.block_tables[seq.seq_id] = block_table.fork()
|
||||
|
||||
# Track seq
|
||||
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
||||
|
||||
# Allocate cross-attention block table for encoder sequence
|
||||
#
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# encoder prompt.
|
||||
request_id = seq_group.request_id
|
||||
|
||||
assert (request_id
|
||||
not in self.cross_block_tables), \
|
||||
"block table already exists"
|
||||
|
||||
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
encoder_seq = seq_group.get_encoder_seq()
|
||||
assert encoder_seq is not None
|
||||
block_table = self._allocate_sequence(encoder_seq)
|
||||
self.cross_block_tables[request_id] = block_table
|
||||
|
||||
def can_append_slots(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> bool:
|
||||
"""Determine if there is enough space in the GPU KV cache to continue
|
||||
generation of the specified sequence group.
|
||||
|
||||
We use a worst-case heuristic: assume each touched block will require a
|
||||
new allocation (either via CoW or new block). We can append slots if the
|
||||
number of touched blocks is less than the number of free blocks.
|
||||
|
||||
"Lookahead slots" are slots that are allocated in addition to the slots
|
||||
for known tokens. The contents of the lookahead slots are not defined.
|
||||
This is used by speculative decoding when speculating future tokens.
|
||||
"""
|
||||
|
||||
num_touched_blocks = 0
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
num_touched_blocks += (
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
token_ids=block_table.get_unseen_token_ids(
|
||||
seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
))
|
||||
|
||||
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
||||
Device.GPU)
|
||||
return num_touched_blocks <= num_free_gpu_blocks
|
||||
|
||||
def append_slots(
|
||||
self,
|
||||
seq: Sequence,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
block_table.append_token_ids(
|
||||
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
num_computed_slots=seq.data.get_num_computed_tokens(),
|
||||
extra_hash=seq.extra_hash(),
|
||||
)
|
||||
# Return any new copy-on-writes.
|
||||
new_cows = self.block_allocator.clear_copy_on_writes()
|
||||
return new_cows
|
||||
|
||||
def free(self, seq: Sequence) -> None:
|
||||
seq_id = seq.seq_id
|
||||
|
||||
if seq_id not in self.block_tables:
|
||||
# Already freed or haven't been scheduled yet.
|
||||
return
|
||||
|
||||
# Update seq block ids with the latest access time
|
||||
self._last_access_blocks_tracker.update_seq_blocks_last_access(
|
||||
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
|
||||
|
||||
# Untrack seq
|
||||
self._last_access_blocks_tracker.remove_seq(seq_id)
|
||||
self._computed_blocks_tracker.remove_seq(seq_id)
|
||||
|
||||
# Free table/blocks
|
||||
self.block_tables[seq_id].free()
|
||||
del self.block_tables[seq_id]
|
||||
|
||||
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
|
||||
seq_id = seq.seq_id
|
||||
self._computed_blocks_tracker.remove_seq(seq_id)
|
||||
|
||||
def free_cross(self, seq_group: SequenceGroup) -> None:
|
||||
request_id = seq_group.request_id
|
||||
if request_id not in self.cross_block_tables:
|
||||
# Already freed or hasn't been scheduled yet.
|
||||
return
|
||||
self.cross_block_tables[request_id].free()
|
||||
del self.cross_block_tables[request_id]
|
||||
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
|
||||
block_ids = self.block_tables[seq.seq_id].physical_block_ids
|
||||
return block_ids # type: ignore
|
||||
|
||||
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
|
||||
request_id = seq_group.request_id
|
||||
assert request_id in self.cross_block_tables
|
||||
block_ids = self.cross_block_tables[request_id].physical_block_ids
|
||||
assert all(b is not None for b in block_ids)
|
||||
return block_ids # type: ignore
|
||||
|
||||
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
|
||||
if self.enable_caching:
|
||||
# Record the latest access time for the sequence. The actual update
|
||||
# of the block ids is deferred to the sequence free(..) call, since
|
||||
# only during freeing of block ids, the blocks are actually added to
|
||||
# the evictor (which is when the most updated time is required)
|
||||
# (This avoids expensive calls to mark_blocks_as_accessed(..))
|
||||
self._last_access_blocks_tracker.update_last_access(
|
||||
seq.seq_id, now)
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
|
||||
token_chunk_size: int):
|
||||
# If prefix caching is enabled, mark immutable blocks as computed
|
||||
# right after they have been scheduled (for prefill). This assumes
|
||||
# the scheduler is synchronous so blocks are actually computed when
|
||||
# scheduling the next batch.
|
||||
self.block_allocator.mark_blocks_as_computed([])
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
"""Determine which blocks for which we skip prefill.
|
||||
|
||||
With prefix caching we can skip prefill for previously-generated blocks.
|
||||
Currently, the attention implementation only supports skipping cached
|
||||
blocks if they are a contiguous prefix of cached blocks.
|
||||
|
||||
This method determines which blocks can be safely skipped for all
|
||||
sequences in the sequence group.
|
||||
"""
|
||||
computed_seq_block_ids = []
|
||||
for seq in seqs:
|
||||
all_blocks = self.block_tables[seq.seq_id].physical_block_ids
|
||||
num_cached_tokens = (
|
||||
self._computed_blocks_tracker.get_num_cached_tokens(seq))
|
||||
assert num_cached_tokens % self.block_size == 0
|
||||
num_cached_blocks = num_cached_tokens // self.block_size
|
||||
computed_block_ids = all_blocks[:num_cached_blocks]
|
||||
computed_seq_block_ids.append(computed_block_ids)
|
||||
|
||||
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
|
||||
return self.block_allocator.get_common_computed_block_ids(
|
||||
computed_seq_block_ids) # type: ignore
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
if parent_seq.seq_id not in self.block_tables:
|
||||
# Parent sequence has either been freed or never existed.
|
||||
return
|
||||
src_block_table = self.block_tables[parent_seq.seq_id]
|
||||
self.block_tables[child_seq.seq_id] = src_block_table.fork()
|
||||
|
||||
# Track child seq
|
||||
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> AllocStatus:
|
||||
"""Returns the AllocStatus for the given sequence_group
|
||||
with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap in.
|
||||
num_lookahead_slots (int): Number of lookahead slots used in
|
||||
speculative decoding, default to 0.
|
||||
|
||||
Returns:
|
||||
AllocStatus: The AllocStatus for the given sequence group.
|
||||
"""
|
||||
return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
|
||||
num_lookahead_slots)
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
"""Returns the block id mapping (from CPU to GPU) generated by
|
||||
swapping in the given seq_group with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap in.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: The mapping of swapping block from CPU
|
||||
to GPU.
|
||||
"""
|
||||
physical_block_id_mapping = []
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||
blocks = self.block_tables[seq.seq_id].blocks
|
||||
if len(blocks) == 0:
|
||||
continue
|
||||
|
||||
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
||||
src_device=Device.CPU,
|
||||
dst_device=Device.GPU)
|
||||
|
||||
# Refresh the block ids of the table (post-swap)
|
||||
self.block_tables[seq.seq_id].update(blocks)
|
||||
|
||||
seq_physical_block_id_mapping = {
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.CPU, cpu_block_id):
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.GPU, gpu_block_id)
|
||||
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
|
||||
}
|
||||
|
||||
physical_block_id_mapping.extend(
|
||||
list(seq_physical_block_id_mapping.items()))
|
||||
|
||||
return physical_block_id_mapping
|
||||
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
||||
"""Returns whether we can swap out the given sequence_group
|
||||
with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap out.
|
||||
|
||||
Returns:
|
||||
bool: Whether it's possible to swap out current sequence group.
|
||||
"""
|
||||
alloc_status = self._can_swap(seq_group, Device.CPU,
|
||||
SequenceStatus.RUNNING)
|
||||
return alloc_status == AllocStatus.OK
|
||||
|
||||
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
"""Returns the block id mapping (from GPU to CPU) generated by
|
||||
swapping out the given sequence_group with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap out.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: The mapping of swapping block from
|
||||
GPU to CPU.
|
||||
"""
|
||||
physical_block_id_mapping = []
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
blocks = self.block_tables[seq.seq_id].blocks
|
||||
if len(blocks) == 0:
|
||||
continue
|
||||
|
||||
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
||||
src_device=Device.GPU,
|
||||
dst_device=Device.CPU)
|
||||
|
||||
# Refresh the block ids of the table (post-swap)
|
||||
self.block_tables[seq.seq_id].update(blocks)
|
||||
|
||||
seq_physical_block_id_mapping = {
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.GPU, gpu_block_id):
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.CPU, cpu_block_id)
|
||||
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
|
||||
}
|
||||
|
||||
physical_block_id_mapping.extend(
|
||||
list(seq_physical_block_id_mapping.items()))
|
||||
|
||||
return physical_block_id_mapping
|
||||
|
||||
def get_num_free_gpu_blocks(self) -> int:
|
||||
return self.block_allocator.get_num_free_blocks(Device.GPU)
|
||||
|
||||
def get_num_free_cpu_blocks(self) -> int:
|
||||
return self.block_allocator.get_num_free_blocks(Device.CPU)
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return self.block_allocator.get_prefix_cache_hit_rate(device)
|
||||
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
return self.block_allocator.reset_prefix_cache(device)
|
||||
|
||||
def _can_swap(self,
|
||||
seq_group: SequenceGroup,
|
||||
device: Device,
|
||||
status: SequenceStatus,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
"""Returns the AllocStatus for swapping in/out the given sequence_group
|
||||
on to the 'device'.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap in/out.
|
||||
device (Device): device to swap the 'seq_group' on.
|
||||
status (SequenceStatus): The status of sequence which is needed
|
||||
for action. RUNNING for swap out and SWAPPED for swap in
|
||||
num_lookahead_slots (int): Number of lookahead slots used in
|
||||
speculative decoding, default to 0.
|
||||
|
||||
Returns:
|
||||
AllocStatus: The AllocStatus for swapping in/out the given
|
||||
sequence_group on to the 'device'.
|
||||
"""
|
||||
# First determine the number of blocks that will be touched by this
|
||||
# swap. Then verify if there are available blocks in the device
|
||||
# to perform the swap.
|
||||
num_blocks_touched = 0
|
||||
blocks: List[Block] = []
|
||||
for seq in seq_group.get_seqs(status=status):
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
if block_table.blocks is not None:
|
||||
# Compute the number blocks to touch for the tokens to be
|
||||
# appended. This does NOT include the full blocks that need
|
||||
# to be touched for the swap.
|
||||
num_blocks_touched += \
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
block_table.get_unseen_token_ids(seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots)
|
||||
blocks.extend(block_table.blocks)
|
||||
# Compute the number of full blocks to touch and add it to the
|
||||
# existing count of blocks to touch.
|
||||
num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
|
||||
blocks, device=device)
|
||||
|
||||
watermark_blocks = 0
|
||||
if device == Device.GPU:
|
||||
watermark_blocks = self.watermark_blocks
|
||||
|
||||
if self.block_allocator.get_num_total_blocks(
|
||||
device) < num_blocks_touched:
|
||||
return AllocStatus.NEVER
|
||||
elif self.block_allocator.get_num_free_blocks(
|
||||
device) - num_blocks_touched >= watermark_blocks:
|
||||
return AllocStatus.OK
|
||||
else:
|
||||
return AllocStatus.LATER
|
||||
|
||||
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
||||
"""Get the number of tokens in blocks that are already computed and
|
||||
cached in the block manager for the sequence.
|
||||
"""
|
||||
return self._computed_blocks_tracker.get_num_cached_tokens(seq)
|
||||
@ -1,157 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import heapq
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class EvictionPolicy(enum.Enum):
|
||||
"""Enum for eviction policy used by make_evictor to instantiate the correct
|
||||
Evictor subclass.
|
||||
"""
|
||||
LRU = enum.auto()
|
||||
|
||||
|
||||
class Evictor(ABC):
|
||||
"""The Evictor subclasses should be used by the BlockAllocator class to
|
||||
handle eviction of freed Blocks.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __contains__(self, block_id: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict(self) -> Tuple[int, int]:
|
||||
"""Runs the eviction algorithm and returns the evicted block's
|
||||
content hash along with physical block id along with physical block id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
|
||||
last_accessed: float):
|
||||
"""Adds block to the evictor, making it a candidate for eviction"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, block_id: int, last_accessed: float):
|
||||
"""Update corresponding block's access time in metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove(self, block_id: int):
|
||||
"""Remove a given block id from the cache."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def num_blocks(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
class BlockMetaData:
|
||||
"""Data structure for storing key data describe cached block, so that
|
||||
evictor could use to make its decision which one to choose for eviction
|
||||
|
||||
Here we use physical block id as the dict key, as there maybe several
|
||||
blocks with the same content hash, but their physical id is unique.
|
||||
"""
|
||||
|
||||
def __init__(self, content_hash: int, num_hashed_tokens: int,
|
||||
last_accessed: float):
|
||||
self.content_hash = content_hash
|
||||
self.num_hashed_tokens = num_hashed_tokens
|
||||
self.last_accessed = last_accessed
|
||||
|
||||
|
||||
class LRUEvictor(Evictor):
|
||||
"""Evicts in a least-recently-used order using the last_accessed timestamp
|
||||
that's recorded in the Block. If there are multiple blocks with
|
||||
the same last_accessed time, then the one with the largest num_hashed_tokens
|
||||
will be evicted. If two blocks each have the lowest last_accessed time and
|
||||
highest num_hashed_tokens value, then one will be chosen arbitrarily
|
||||
"""
|
||||
|
||||
# CLEANUP_THRESHOLD determines the maximum allowable size of the priority
|
||||
# queue relative to the free table size. When this threshold is exceeded,
|
||||
# a cleanup operation is triggered to reduce memory usage.
|
||||
CLEANUP_THRESHOLD = 50
|
||||
|
||||
def __init__(self):
|
||||
self.free_table: Dict[int, BlockMetaData] = {}
|
||||
self.priority_queue = []
|
||||
|
||||
def __contains__(self, block_id: int) -> bool:
|
||||
return block_id in self.free_table
|
||||
|
||||
def evict(self) -> Tuple[int, int]:
|
||||
if len(self.free_table) == 0:
|
||||
raise ValueError("No usable cache memory left")
|
||||
|
||||
while self.priority_queue:
|
||||
# We do not remove outdated entries from the priority queue at the
|
||||
# time of updating the last_accessed timestamp. Instead, outdated
|
||||
# entries are filtered out here during eviction. Outdated entries
|
||||
# would either not in the free table, or have older last accessed
|
||||
# time.
|
||||
last_accessed, _, block_id, content_hash = heapq.heappop(
|
||||
self.priority_queue)
|
||||
if (block_id in self.free_table and
|
||||
self.free_table[block_id].last_accessed == last_accessed):
|
||||
self.free_table.pop(block_id)
|
||||
return block_id, content_hash
|
||||
|
||||
raise ValueError("No usable cache memory left")
|
||||
|
||||
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
|
||||
last_accessed: float):
|
||||
self.free_table[block_id] = BlockMetaData(content_hash,
|
||||
num_hashed_tokens,
|
||||
last_accessed)
|
||||
heapq.heappush(
|
||||
self.priority_queue,
|
||||
(last_accessed, -num_hashed_tokens, block_id, content_hash))
|
||||
self._cleanup_if_necessary()
|
||||
|
||||
def update(self, block_id: int, last_accessed: float):
|
||||
self.free_table[block_id].last_accessed = last_accessed
|
||||
|
||||
def _cleanup_if_necessary(self):
|
||||
if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len(
|
||||
self.free_table):
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self):
|
||||
new_priority_queue: List[Tuple[float, int, int, int]] = []
|
||||
|
||||
for block_id, block in self.free_table.items():
|
||||
new_priority_queue.append(
|
||||
(block.last_accessed, -block.num_hashed_tokens, block_id,
|
||||
block.content_hash))
|
||||
heapq.heapify(new_priority_queue)
|
||||
|
||||
self.priority_queue = new_priority_queue
|
||||
|
||||
def remove(self, block_id: int):
|
||||
if block_id not in self.free_table:
|
||||
raise ValueError(
|
||||
"Attempting to remove block that's not in the evictor")
|
||||
self.free_table.pop(block_id)
|
||||
|
||||
@property
|
||||
def num_blocks(self) -> int:
|
||||
return len(self.free_table)
|
||||
|
||||
|
||||
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
|
||||
if eviction_policy == EvictionPolicy.LRU:
|
||||
return LRUEvictor()
|
||||
else:
|
||||
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
|
||||
@ -1,139 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple
|
||||
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class AllocStatus(enum.Enum):
|
||||
"""Result for BlockSpaceManager.can_allocate
|
||||
|
||||
1. Ok: seq_group can be allocated now.
|
||||
2. Later: seq_group cannot be allocated.
|
||||
The capacity of allocator is larger than seq_group required.
|
||||
3. Never: seq_group can never be allocated.
|
||||
The seq_group is too large to allocated in GPU.
|
||||
"""
|
||||
OK = enum.auto()
|
||||
LATER = enum.auto()
|
||||
NEVER = enum.auto()
|
||||
|
||||
|
||||
class BlockSpaceManager(ABC):
|
||||
|
||||
@staticmethod
|
||||
def get_block_space_manager_class(version: str):
|
||||
version = version.lower()
|
||||
|
||||
if version == "selfattn":
|
||||
from vllm.core.block_manager import SelfAttnBlockSpaceManager
|
||||
return SelfAttnBlockSpaceManager
|
||||
|
||||
if version == "placeholder":
|
||||
from vllm.core.placeholder_block_space_manager import (
|
||||
PlaceholderBlockSpaceManager)
|
||||
return PlaceholderBlockSpaceManager
|
||||
|
||||
raise ValueError(f"Unknown version {version=}")
|
||||
|
||||
@abstractmethod
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def can_append_slots(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def append_slots(
|
||||
self,
|
||||
seq: Sequence,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def can_swap_in(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> AllocStatus:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def free(self, seq: Sequence) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_free_gpu_blocks(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_free_cpu_blocks(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def access_all_blocks_in_seq(
|
||||
self,
|
||||
seq: Sequence,
|
||||
access_time: float,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
|
||||
token_chunk_size: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
"""Reset prefix cache for specified or all devices."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
|
||||
pass
|
||||
@ -1,103 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class PlaceholderBlockSpaceManager(BlockSpaceManager):
|
||||
"""A version of BlockSpaceManager for use in environments
|
||||
where block management is not required.
|
||||
For example: pooling models or attention-free models like Mamba.
|
||||
|
||||
This class provides the same interface as BlockSpaceManager, but its
|
||||
methods perform no actions or return simple values like True in specific
|
||||
actions. It's designed to be used in scenarios where the overhead of
|
||||
block management is unnecessary, such as in an embedding environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
# Always return OK for dummy purposes
|
||||
return AllocStatus.OK
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
# No actual allocation logic needed
|
||||
pass
|
||||
|
||||
def can_append_slots(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> bool:
|
||||
return True
|
||||
|
||||
def append_slots(
|
||||
self,
|
||||
seq: Sequence,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
return []
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
pass
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> AllocStatus:
|
||||
return AllocStatus.OK
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
return None # type: ignore
|
||||
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
||||
return True
|
||||
|
||||
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
return None # type: ignore
|
||||
|
||||
def free(self, seq: Sequence) -> None:
|
||||
# No operation on free
|
||||
return
|
||||
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
|
||||
return None # type: ignore
|
||||
|
||||
def get_num_free_gpu_blocks(self) -> int:
|
||||
return 1
|
||||
|
||||
def get_num_free_cpu_blocks(self) -> int:
|
||||
return 1
|
||||
|
||||
def access_all_blocks_in_seq(
|
||||
self,
|
||||
seq: Sequence,
|
||||
access_time: float,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def get_common_computed_block_ids(self,
|
||||
seq_group: List[Sequence]) -> List[int]:
|
||||
return []
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
|
||||
token_chunk_size: int):
|
||||
pass
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return -1
|
||||
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
return True
|
||||
|
||||
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
||||
return 0
|
||||
|
||||
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
|
||||
return
|
||||
File diff suppressed because it is too large
Load Diff
@ -7,13 +7,11 @@ from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union
|
||||
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@ -266,11 +264,7 @@ class EngineClient(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def do_log_stats(
|
||||
self,
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
||||
model_output: Optional[list[SamplerOutput]] = None,
|
||||
) -> None:
|
||||
async def do_log_stats(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -601,11 +601,7 @@ class AsyncLLM(EngineClient):
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return self.observability_config.otlp_traces_endpoint is not None
|
||||
|
||||
async def do_log_stats(
|
||||
self,
|
||||
scheduler_outputs=None,
|
||||
model_output=None,
|
||||
) -> None:
|
||||
async def do_log_stats(self) -> None:
|
||||
if self.logger_manager:
|
||||
self.logger_manager.log()
|
||||
|
||||
|
||||
@ -1,145 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CacheEngine class for managing the KV cache."""
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
|
||||
get_dtype_size, is_pin_memory_available)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CacheEngine:
|
||||
"""Manages the KV cache.
|
||||
|
||||
This class is responsible for initializing and managing the GPU and CPU KV
|
||||
caches. It also provides methods for performing KV cache operations, such
|
||||
as swapping and copying.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_config: CacheConfig,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
device_config: DeviceConfig,
|
||||
) -> None:
|
||||
self.cache_config = cache_config
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.device_config = device_config
|
||||
|
||||
self.head_size = model_config.get_head_size()
|
||||
# Models like Jamba, have mixed typed layers, E.g Mamba
|
||||
self.num_attention_layers = model_config.get_num_layers_by_block_type(
|
||||
parallel_config, LayerBlockType.attention)
|
||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
|
||||
self.block_size = cache_config.block_size
|
||||
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
||||
if self.num_gpu_blocks:
|
||||
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
|
||||
self.num_cpu_blocks = cache_config.num_cpu_blocks
|
||||
if self.num_cpu_blocks:
|
||||
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
self.dtype = model_config.dtype
|
||||
else:
|
||||
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# Get attention backend.
|
||||
self.attn_backend = get_attn_backend(self.head_size,
|
||||
model_config.dtype,
|
||||
cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
model_config.is_attention_free,
|
||||
use_mla=model_config.use_mla)
|
||||
|
||||
# Initialize the cache.
|
||||
self.gpu_cache = self._allocate_kv_cache(
|
||||
self.num_gpu_blocks, self.device_config.device_type)
|
||||
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
|
||||
|
||||
def _allocate_kv_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
device: str,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Allocates KV cache on the specified device."""
|
||||
kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||
kv_cache: List[torch.Tensor] = []
|
||||
try:
|
||||
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
|
||||
)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
|
||||
|
||||
# The allocation respects the backend-defined stride order to ensure
|
||||
# the semantic remains consistent for each backend. We first obtain the
|
||||
# generic kv cache shape and then permute it according to the stride
|
||||
# order which could result in a non-contiguous tensor.
|
||||
kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
|
||||
for _ in range(self.num_attention_layers):
|
||||
# null block in CpuGpuBlockAllocator requires at least that
|
||||
# block to be zeroed-out.
|
||||
# We zero-out everything for simplicity.
|
||||
layer_kv_cache = torch.zeros(
|
||||
kv_cache_allocation_shape,
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device).permute(*kv_cache_stride_order)
|
||||
|
||||
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
|
||||
# when entry_shape is higher than 1D
|
||||
kv_cache.append(layer_kv_cache)
|
||||
return kv_cache
|
||||
|
||||
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
||||
for i in range(self.num_attention_layers):
|
||||
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
|
||||
src_to_dst)
|
||||
|
||||
def swap_out(self, src_to_dst: torch.Tensor) -> None:
|
||||
for i in range(self.num_attention_layers):
|
||||
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
|
||||
src_to_dst)
|
||||
|
||||
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
||||
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
|
||||
|
||||
@staticmethod
|
||||
def get_cache_block_size(
|
||||
cache_config: CacheConfig,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> int:
|
||||
head_size = model_config.get_head_size()
|
||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
num_attention_layers = model_config.get_num_layers_by_block_type(
|
||||
parallel_config, LayerBlockType.attention)
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
dtype = model_config.dtype
|
||||
else:
|
||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
key_cache_entry = num_heads * head_size
|
||||
|
||||
# For MLA there is no value cache, since the latent vector
|
||||
# is joint keys and values.
|
||||
value_cache_entry = key_cache_entry if not model_config.use_mla else 0
|
||||
total = num_attention_layers * cache_config.block_size * \
|
||||
(key_cache_entry + value_cache_entry)
|
||||
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
return dtype_size * total
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,666 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
|
||||
memory_profiling)
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Worker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes (a partition of) the model on a GPU.
|
||||
|
||||
Each worker is associated with a single GPU. The worker is responsible for
|
||||
maintaining the KV cache and executing the model on the GPU. In case of
|
||||
distributed inference, each worker is assigned a partition of the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
|
||||
) -> None:
|
||||
WorkerBase.__init__(self, vllm_config)
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Return hidden states from target model if the draft model is an
|
||||
# mlp_speculator
|
||||
speculative_config = self.speculative_config
|
||||
model_config = self.model_config
|
||||
speculative_args = {} if speculative_config is None \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type ==
|
||||
model_config.hf_config.model_type) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ("medusa",
|
||||
"mlp_speculator",
|
||||
"eagle",
|
||||
"deepseek_mtp",
|
||||
"glm4_moe_mtp",
|
||||
"mimo_mtp",
|
||||
"ernie_mtp",
|
||||
"qwen3_next_mtp")) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
self.model_runner: GPUModelRunnerBase = ModelRunner(
|
||||
vllm_config=self.vllm_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
**speculative_args,
|
||||
)
|
||||
if model_runner_cls is not None:
|
||||
self.model_runner = model_runner_cls(self.model_runner)
|
||||
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine: List[CacheEngine]
|
||||
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
|
||||
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: Dict[str, torch.Tensor] = {}
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def start_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.start()
|
||||
|
||||
def stop_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
# only print profiler results on rank 0
|
||||
if self.local_rank == 0:
|
||||
print(self.profiler.key_averages().table(
|
||||
sort_by="self_cuda_time_total"))
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
||||
|
||||
# Save the buffers before level 2 sleep
|
||||
if level == 2:
|
||||
model = self.model_runner.model
|
||||
self._sleep_saved_buffers = {
|
||||
name: buffer.cpu().clone()
|
||||
for name, buffer in model.named_buffers()
|
||||
}
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
logger.info(
|
||||
"Sleep mode freed %.2f GiB memory, "
|
||||
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
||||
used_bytes / GiB_bytes)
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.wake_up(tags=tags)
|
||||
|
||||
# Restore the buffers after level 2 sleep
|
||||
if len(self._sleep_saved_buffers):
|
||||
model = self.model_runner.model
|
||||
for name, buffer in model.named_buffers():
|
||||
if name in self._sleep_saved_buffers:
|
||||
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
||||
self._sleep_saved_buffers = {}
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
# as the number of all_reduce calls increases. This env var disables
|
||||
# this behavior.
|
||||
# Related issue:
|
||||
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
torch.cuda.set_device(self.device)
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.baseline_snapshot = MemorySnapshot()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be "
|
||||
"used for one instance per process.")
|
||||
context = allocator.use_memory_pool(tag="weights")
|
||||
else:
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.load_model()
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model_runner.save_sharded_state(
|
||||
path,
|
||||
pattern=pattern,
|
||||
max_size=max_size,
|
||||
)
|
||||
|
||||
def save_tensorized_model(
|
||||
self,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
) -> None:
|
||||
self.model_runner.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_kv_cache_memory(self,
|
||||
total_gpu_memory: int) -> float:
|
||||
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
|
||||
# still need a profile run which compiles the model for
|
||||
# max_num_batched_tokens
|
||||
self.model_runner.profile_run()
|
||||
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
msg = (
|
||||
f"Initial free memory "
|
||||
f"{GiB(self.baseline_snapshot.free_memory):.2f} "
|
||||
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
|
||||
"KV Cache as specified by kv_cache_memory_bytes config and "
|
||||
"skipped memory profiling. This does does not respect the "
|
||||
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
|
||||
"config when you want manual control of KV cache memory "
|
||||
"size. If OOM'ed, check the difference of initial free "
|
||||
"memory between the current run and the previous run "
|
||||
"where kv_cache_memory_bytes is suggested and update it "
|
||||
"correspondingly.")
|
||||
logger.info(msg)
|
||||
return self.cache_config.kv_cache_memory_bytes
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
self.baseline_snapshot,
|
||||
weights_memory=self.model_runner.model_memory_usage) as result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self.non_torch_memory = result.non_torch_increase
|
||||
self.peak_activation_memory = result.torch_peak_increase
|
||||
|
||||
self._assert_memory_footprint_increased_during_profiling()
|
||||
|
||||
self.requested_memory = total_gpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
|
||||
self.available_kv_cache_memory = (self.requested_memory -
|
||||
result.non_kv_cache_memory)
|
||||
|
||||
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
|
||||
"the current vLLM instance can use "
|
||||
"total_gpu_memory "
|
||||
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
|
||||
" x gpu_memory_utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization:.2f})"
|
||||
f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n"
|
||||
"model weights take "
|
||||
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
|
||||
" non_torch_memory takes "
|
||||
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
|
||||
" PyTorch activation peak memory takes "
|
||||
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
|
||||
" the rest of the memory reserved for KV Cache is "
|
||||
f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.")
|
||||
|
||||
logger.info(msg)
|
||||
return self.available_kv_cache_memory
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculates the maximum possible number of GPU and CPU blocks
|
||||
that can be allocated with the remaining free memory.
|
||||
|
||||
Tip:
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
available_kv_cache_memory = self.determine_available_kv_cache_memory(
|
||||
total_gpu_memory)
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
if cache_block_size == 0:
|
||||
num_gpu_blocks = 0
|
||||
num_cpu_blocks = 0
|
||||
else:
|
||||
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
|
||||
# Final cleanup
|
||||
gc.collect()
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def _assert_memory_footprint_increased_during_profiling(self):
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
free_gpu_memory, total = torch.cuda.mem_get_info()
|
||||
cuda_memory = total - free_gpu_memory
|
||||
assert self.baseline_snapshot.cuda_memory < cuda_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
|
||||
f"currently used memory {cuda_memory}. "
|
||||
f"This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Allocate GPU and CPU KV cache with the specified number of blocks.
|
||||
|
||||
This also warms up the model, which may record CUDA graphs.
|
||||
"""
|
||||
raise_if_cache_size_invalid(
|
||||
num_gpu_blocks, self.cache_config.block_size,
|
||||
self.cache_config.is_attention_free,
|
||||
self.model_config.max_model_len,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self._init_cache_engine()
|
||||
self._warm_up_model()
|
||||
|
||||
def _init_cache_engine(self):
|
||||
assert self.cache_config.num_gpu_blocks is not None
|
||||
self.cache_engine = [
|
||||
CacheEngine(self.cache_config, self.model_config,
|
||||
self.parallel_config, self.device_config)
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.gpu_cache = [
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Layer pairings for cross-layer KV sharing.
|
||||
# If an Attention layer `layer_name` is in the keys of this dict, it
|
||||
# means this layer will perform attention using the keys and values
|
||||
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
||||
shared_kv_cache_layers: dict[str, str] = {}
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if (kv_tgt_layer :=
|
||||
attn_module.kv_sharing_target_layer_name) is not None:
|
||||
# The layer doesn't need its own KV cache and will use that of
|
||||
# the target layer. We skip creating a KVCacheSpec for it, so
|
||||
# that KV cache management logic will act as this layer does
|
||||
# not exist, and doesn't allocate KV cache for the layer. This
|
||||
# enables the memory saving of cross-layer kv sharing, allowing
|
||||
# a given amount of memory to accommodate longer context lengths
|
||||
# or enable more requests to be processed simultaneously.
|
||||
shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
self.gpu_cache, shared_kv_cache_layers)
|
||||
|
||||
def _warm_up_model(self) -> None:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
|
||||
if not self.model_config.enforce_eager:
|
||||
warmup_sizes = [
|
||||
x for x in warmup_sizes if x not in
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
]
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size)
|
||||
|
||||
cuda_graph_memory_bytes = 0
|
||||
if not self.model_config.enforce_eager:
|
||||
cuda_graph_memory_bytes = self.model_runner.capture_model(
|
||||
self.gpu_cache)
|
||||
|
||||
if (self.cache_config.kv_cache_memory_bytes is None
|
||||
and hasattr(self, "peak_activation_memory")):
|
||||
# Suggests optimal kv cache memory size if we rely on
|
||||
# memory_profiling to guess the kv cache memory size which
|
||||
# provides peak_activation_memory and a few other memory
|
||||
# consumption. `memory_profiling` does not consider
|
||||
# CUDAGraph memory size and may not utilize all gpu memory.
|
||||
# Users may want fine-grained control to specify kv cache
|
||||
# memory size.
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
non_kv_cache_memory = (self.model_runner.model_memory_usage +
|
||||
self.peak_activation_memory +
|
||||
self.non_torch_memory +
|
||||
cuda_graph_memory_bytes)
|
||||
|
||||
# empirically observed that the memory profiling may
|
||||
# slightly underestimate the memory consumption.
|
||||
# So leave a small buffer (=150MiB) to avoid OOM.
|
||||
redundancy_buffer_memory = 150 * (1 << 20)
|
||||
kv_cache_memory_bytes_to_gpu_limit = (
|
||||
self.baseline_snapshot.free_memory - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
kv_cache_memory_bytes_to_requested_limit = (
|
||||
int(self.requested_memory) - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
|
||||
msg = (
|
||||
f"Free memory on device "
|
||||
f"({GiB(self.baseline_snapshot.free_memory)}/"
|
||||
f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
|
||||
f"Desired GPU memory utilization is "
|
||||
f"({self.cache_config.gpu_memory_utilization}, "
|
||||
f"{GiB(self.requested_memory)} GiB). "
|
||||
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
|
||||
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
|
||||
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
|
||||
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
|
||||
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
|
||||
f"config with `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
|
||||
f"requested memory, or `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
|
||||
f"utilize gpu memory. Current kv cache memory in use is "
|
||||
f"{int(self.available_kv_cache_memory)} bytes.")
|
||||
logger.info(msg)
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
return self.gpu_cache
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_steps = execute_model_req.num_steps
|
||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
||||
# they contain parameters to launch cudamemcpyasync.
|
||||
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
# `blocks_to_copy` is a gpu tensor. The src and tgt of
|
||||
# blocks to copy are in the same device, and `blocks_to_copy`
|
||||
# can be used directly within cuda kernels.
|
||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||
device=self.device,
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
num_steps=num_steps,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
# Issue cache operations.
|
||||
if (worker_input.blocks_to_swap_in is not None
|
||||
and worker_input.blocks_to_swap_in.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_in(
|
||||
worker_input.blocks_to_swap_in)
|
||||
if (worker_input.blocks_to_swap_out is not None
|
||||
and worker_input.blocks_to_swap_out.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_out(
|
||||
worker_input.blocks_to_swap_out)
|
||||
if (worker_input.blocks_to_copy is not None
|
||||
and worker_input.blocks_to_copy.numel() > 0):
|
||||
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
|
||||
|
||||
def _get_cached_seq_group_metadata(
|
||||
self,
|
||||
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
|
||||
SequenceGroupMetadataDelta]],
|
||||
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
|
||||
"""Return a list of cached Sequence Group Metadata after updating its
|
||||
state.
|
||||
|
||||
It is used because scheduler only sends delta to workers to reduce
|
||||
the data payload size. The function also cleans up cache based on
|
||||
a given `finished_request_ids`.
|
||||
"""
|
||||
new_seq_group_metadata_list = []
|
||||
for metadata_or_delta in seq_group_metadata_list:
|
||||
request_id = metadata_or_delta.request_id
|
||||
if request_id not in self._seq_group_metadata_cache:
|
||||
# The first prefill.
|
||||
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
|
||||
self._seq_group_metadata_cache[request_id] = metadata_or_delta
|
||||
else:
|
||||
# The first prefill is already cached.
|
||||
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
|
||||
self._seq_group_metadata_cache[request_id].apply_delta(
|
||||
metadata_or_delta)
|
||||
else:
|
||||
# If metadata snapshot is sent again, it is
|
||||
# preempted. Reset the cache because we need to start
|
||||
# from scratch.
|
||||
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
|
||||
self._seq_group_metadata_cache[
|
||||
request_id] = metadata_or_delta
|
||||
|
||||
new_seq_group_metadata_list.append(
|
||||
self._seq_group_metadata_cache[request_id])
|
||||
|
||||
# Clean up finished ids
|
||||
for finished_id in finished_request_ids:
|
||||
del self._seq_group_metadata_cache[finished_id]
|
||||
|
||||
return new_seq_group_metadata_list
|
||||
|
||||
def _execute_model_spmd(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if execute_model_req is not None:
|
||||
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.finished_requests_ids)
|
||||
|
||||
execute_model_req.seq_group_metadata_list = (
|
||||
new_seq_group_metadata_list)
|
||||
output = super()._execute_model_spmd(execute_model_req,
|
||||
intermediate_tensors)
|
||||
return output
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
@property
|
||||
def max_model_len(self) -> int:
|
||||
return self.model_config.max_model_len
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_runner.vocab_size
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Get the size of the KV cache block size in bytes.
|
||||
"""
|
||||
return CacheEngine.get_cache_block_size(self.cache_config,
|
||||
self.model_config,
|
||||
self.parallel_config)
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank,
|
||||
current_platform.dist_backend)
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.decode_context_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not current_platform.has_device_capability(80):
|
||||
capability = current_platform.get_device_capability()
|
||||
gpu_name = current_platform.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
|
||||
max_model_len, pipeline_parallel_size) -> None:
|
||||
if is_attention_free and num_gpu_blocks != 0:
|
||||
raise ValueError("No memory should be allocated for the cache blocks "
|
||||
f"for an attention-free model, but {num_gpu_blocks} "
|
||||
"blocks are allocated.")
|
||||
if not is_attention_free and num_gpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
|
||||
if not is_attention_free and max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
||||
"initializing the engine.")
|
||||
Loading…
x
Reference in New Issue
Block a user