[V1] Move more control of kv cache initialization from model_executor to EngineCore (#11960)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Chen Zhang 2025-01-17 15:39:35 +08:00 committed by GitHub
parent 8027a72461
commit 69d765f5a5
12 changed files with 514 additions and 103 deletions

tests/v1/test_utils.py (new file)

@ -0,0 +1,62 @@
from typing import List
import torch
from vllm.v1.utils import bind_kv_cache
def test_bind_kv_cache():
from vllm.attention import Attention
ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = {
'layers.0.self_attn': torch.zeros((1, )),
'layers.1.self_attn': torch.zeros((1, )),
'layers.2.self_attn': torch.zeros((1, )),
'layers.3.self_attn': torch.zeros((1, )),
}
runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
'layers.0.self_attn']
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
'layers.1.self_attn']
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
'layers.2.self_attn']
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
'layers.3.self_attn']
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
def test_bind_kv_cache_non_attention():
from vllm.attention import Attention
# example from Jamba PP=2
ctx = {
'model.layers.20.attn': Attention(32, 128, 0.1),
'model.layers.28.attn': Attention(32, 128, 0.1),
}
kv_cache = {
'model.layers.20.attn': torch.zeros((1, )),
'model.layers.28.attn': torch.zeros((1, )),
}
runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
'model.layers.20.attn']
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
'model.layers.28.attn']
assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']


@ -101,7 +101,9 @@ class Attention(nn.Module):
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant


@ -3,7 +3,10 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
KVCacheTensor)
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -305,3 +308,124 @@ def hash_request_tokens(block_size: int,
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
least one request with the model's max_model_len.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.
Raises:
ValueError: If there is not enough memory available for the KV cache.
"""
if available_memory <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_model_len = vllm_config.model_config.max_model_len
needed_memory = 0
for layer_spec in kv_cache_spec.values():
needed_memory += layer_spec.bytes_for_tokens(max_model_len)
if needed_memory > available_memory:
raise ValueError(
f"To serve at least one request with the model's max seq len "
f"({max_model_len}), {needed_memory/1024/1024/1024:.2f} GB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
f"increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")
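Note: a quick worked example of the check above, with purely illustrative numbers (not taken from this commit). A 32-layer full-attention model with 8 KV heads, head size 128, an fp16 KV cache, block_size 16 and max_model_len 32768 needs about 4 GiB to hold one max-length request:

# Illustrative arithmetic only; every model parameter below is made up.
block_size = 16
num_kv_heads, head_size, dtype_size = 8, 128, 2                      # fp16 -> 2 bytes
page_size = 2 * block_size * num_kv_heads * head_size * dtype_size   # K and V planes
max_model_len = 32768
blocks_per_request = -(-max_model_len // block_size)                 # ceil division
needed_memory = 32 * blocks_per_request * page_size                  # 32 layers
print(f"{needed_memory / 1024**3:.2f} GiB")                          # -> 4.00 GiB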
def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.
Args:
kv_cache_spec: The KVCacheSpec of the model
Returns:
True if all layers have the same type, False otherwise.
"""
layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
return len(layer_keys) == 1
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divides the available memory equally among all layers.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
assert len(page_sizes) == 1
page_size = page_sizes.pop()
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
num_blocks = num_gpu_blocks_override
logger.info("# GPU blocks: %d", num_blocks)
per_layer_size = page_size * num_blocks
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
tensors={
layer_name: KVCacheTensor(size=per_layer_size)
for layer_name in kv_cache_spec
},
groups=[[layer_name for layer_name in kv_cache_spec]],
kv_cache_spec=kv_cache_spec)
return kv_cache_config
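Note: with a uniform spec the split is a plain integer division over the layers. An illustrative run of the function's arithmetic (the memory budget and page size below are assumed, not from the commit):

available_memory = 20 * 1024**3     # bytes left for KV cache after profiling (assumed)
num_layers = 32                     # len(kv_cache_spec) (assumed)
page_size = 64 * 1024               # page_size_bytes shared by every layer (assumed)
num_blocks = available_memory // page_size // num_layers    # 10240 blocks
per_layer_size = page_size * num_blocks                     # 640 MiB tensor per layer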
def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for most models.
# Allocate the same amount of memory for each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
else:
raise NotImplementedError


@ -11,11 +11,12 @@ import zmq
import zmq.asyncio
from msgspec import msgpack
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
@ -49,7 +50,7 @@ class EngineCore:
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
vllm_config.cache_config)
vllm_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
@ -65,21 +66,25 @@ class EngineCore:
vllm_config.model_config)
def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
vllm_config: VllmConfig) -> Tuple[int, int]:
start = time.time()
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
)
if cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_gpu_blocks,
num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override
# Get all kv cache needed by the model
kv_cache_spec = self.model_executor.get_kv_cache_spec()
# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
available_gpu_memory = self.model_executor.determine_available_memory()
# Get the kv cache tensor size
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_gpu_memory)
num_gpu_blocks = kv_cache_config.num_blocks
num_cpu_blocks = 0
self.model_executor.initialize(num_gpu_blocks)
# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_config)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)

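Note: the new initialization order in EngineCore, condensed from the hunk above (the surrounding class definition and imports are elided):

def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]:
    # 1. Ask the workers which layers need KV cache and in what format.
    kv_cache_spec = self.model_executor.get_kv_cache_spec()
    # 2. Profile a forward pass to find how much memory (in bytes) is left for KV cache.
    available_gpu_memory = self.model_executor.determine_available_memory()
    # 3. Let EngineCore turn the spec plus the memory budget into a KVCacheConfig.
    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                          available_gpu_memory)
    # 4. Allocate the KV cache tensors on the workers and warm up the model.
    self.model_executor.initialize(kv_cache_config)
    return kv_cache_config.num_blocks, 0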

@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Tuple, Type
from typing import Type
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
@ -31,11 +32,15 @@ class Executor(ABC):
raise NotImplementedError
@abstractmethod
def initialize(self, num_gpu_blocks: int) -> None:
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
raise NotImplementedError
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
def determine_available_memory(self) -> int: # in bytes
raise NotImplementedError
@abstractmethod
def get_kv_cache_spec(self) -> KVCacheSpec:
raise NotImplementedError
@abstractmethod


@ -23,6 +23,7 @@ from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
@ -90,29 +91,33 @@ class MultiprocExecutor(Executor):
for w in self.workers:
w.worker_response_mq.wait_until_ready()
def initialize(self, num_gpu_blocks: int) -> None:
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_num_available_blocks(self) -> Tuple[int, int]:
def determine_available_memory(self) -> int:
"""
Determine the number of available KV blocks by invoking the
Determine the available memory (in bytes) for KV cache by invoking the
underlying worker.
"""
num_blocks = self.collective_rpc("determine_num_available_blocks")
memory_sizes = self.collective_rpc("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return min(memory_sizes)
return num_gpu_blocks, num_cpu_blocks
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Get all kv cache needed by the model by invoking the underlying worker.
"""
kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
return kv_cache_specs[0]
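Note: a tiny sketch of why the minimum is taken; the per-rank numbers are hypothetical:

# Hypothetical results of collective_rpc("determine_available_memory") with TP=2.
memory_sizes = [21_474_836_480, 20_401_094_656]   # bytes reported by rank 0 and rank 1
assert min(memory_sizes) == 20_401_094_656        # plan for the most constrained rank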
def collective_rpc(self,
method: str,


@ -10,6 +10,7 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
initialize_ray_cluster, ray)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
if ray is not None:
@ -211,39 +212,40 @@ class RayExecutor(Executor):
distributed_init_method=distributed_init_method,
)
def determine_num_available_blocks(self) -> Tuple[int, int]:
def determine_available_memory(self) -> int:
"""
Determine the number of available KV blocks.
Determine the available GPU memory in bytes.
This invokes `determine_num_available_blocks` on each worker and takes
This invokes `determine_available_memory` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks")
memory_sizes = self._run_workers("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return min(memory_sizes)
return num_gpu_blocks, num_cpu_blocks
def initialize(self, num_gpu_blocks: int) -> None:
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the KV cache in all workers.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d", num_gpu_blocks)
self._run_workers("initialize_cache", num_gpu_blocks)
self._run_workers("initialize_cache", kv_cache_config)
self._run_workers("compile_or_warm_up_model")
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Get all kv cache needed by the model
This invokes `get_kv_cache_spec` on each worker and asserts that
they are identical. The KVCacheSpec is then returned.
"""
kv_cache_specs = self._run_workers("get_kv_cache_spec")
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
return kv_cache_specs[0]
def _run_workers(
self,
method: str,


@ -1,10 +1,11 @@
import os
from typing import Optional, Tuple
from typing import Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
@ -49,20 +50,22 @@ class UniprocExecutor(Executor):
distributed_init_method=distributed_init_method,
)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
def determine_available_memory(self) -> int:
"""Determine the available memory (in bytes) for KV cache by invoking
the underlying worker.
"""
return self.worker.determine_num_available_blocks()
return self.worker.determine_available_memory()
def initialize(self, num_gpu_blocks: int) -> None:
def get_kv_cache_spec(self) -> KVCacheSpec:
"""Get all kv cache needed by the model by invoking the underlying
worker.
"""
return self.worker.get_kv_cache_spec()
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.worker.initialize_cache(num_gpu_blocks)
self.worker.initialize_cache(kv_cache_config)
self.worker.compile_or_warm_up_model()
def execute_model(


@ -0,0 +1,111 @@
from dataclasses import dataclass
from typing import Dict, List
import torch
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size
logger = init_logger(__name__)
@dataclass
class KVCacheSpecBase:
"""
A base class for specifying the KV cache format of one layer.
"""
# number of tokens in a block
block_size: int
@property
def type_id(self) -> str:
"""
The type identifier of this KV cache.
Returns different strings for layers with different KV cache types (e.g.,
a different number of cached tokens, as with full attention vs. sliding
window attention, or a different KV cache size per token, as with layers
that have a different number of heads).
Returns:
The type identifier of this KV cache.
"""
raise NotImplementedError
@property
def page_size_bytes(self) -> int:
"""
The size of a page with `block_size` tokens in bytes.
Returns:
The page size
"""
raise NotImplementedError
def bytes_for_tokens(self, num_tokens: int) -> int:
"""
The KV cache size for `num_tokens` tokens in bytes. Returns the real
memory size after padding `num_tokens` to full blocks.
Returns:
The KV cache size
"""
raise NotImplementedError
@dataclass
class FullAttentionSpec(KVCacheSpecBase):
num_kv_heads: int
head_size: int
dtype: torch.dtype
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
@property
def page_size_bytes(self) -> int:
return 2 * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype)
def bytes_for_tokens(self, num_tokens: int) -> int:
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
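Note: a minimal sketch of the spec arithmetic, assuming FullAttentionSpec is imported from vllm.v1.kv_cache_interface as added in this commit; the layer parameters are made up:

import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec

spec = FullAttentionSpec(block_size=16, num_kv_heads=8,
                         head_size=128, dtype=torch.float16)
assert spec.page_size_bytes == 2 * 16 * 8 * 128 * 2             # K and V planes, fp16
assert spec.bytes_for_tokens(100) == 7 * spec.page_size_bytes   # 100 tokens -> 7 blocks
assert spec.type_id == "full_attention_16_65536"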
KVCacheSpec = Dict[str, KVCacheSpecBase]
@dataclass
class KVCacheTensor:
"""
A dataclass for specifying how the workers should initialize the KV cache
for a layer. For now it only contains the size of the KV cache for that layer;
it will be extended to support multiple layers sharing the same memory pool.
"""
size: int # The size of KV cache Tensor in bytes
@dataclass
class KVCacheConfig:
"""
The KV cache configuration of a model.
"""
"""The number of KV cache blocks"""
num_blocks: int
"""layer_name -> how to initialize KV cache for that layer"""
tensors: Dict[str, KVCacheTensor]
"""
A list of kv-cache groups. Each group includes a set of layers with
the same kv-cache spec, and the total page_size of layers inside a group
is the same across all groups (as the KVCacheManager only supports allocating
pages of the same size). For example:
1. A model that only uses full attention: one group with all layers in the model.
2. (not implemented yet) A model with the same number of full attention
layers and sliding window attention layers: two groups, one for full
attention layers and one for sliding window attention layers.
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
"""
groups: List[List[str]]
"""the KVCacheSpec of the model"""
kv_cache_spec: KVCacheSpec
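Note: an illustrative KVCacheConfig for a two-layer, full-attention-only model (case 1 above); the layer names, block count and spec parameters are made up:

import torch
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheTensor)

spec = FullAttentionSpec(block_size=16, num_kv_heads=8,
                         head_size=128, dtype=torch.float16)    # 64 KiB pages
layer_names = ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"]
num_blocks = 10240
kv_cache_config = KVCacheConfig(
    num_blocks=num_blocks,
    tensors={name: KVCacheTensor(size=num_blocks * spec.page_size_bytes)
             for name in layer_names},
    groups=[layer_names],            # a single group: every layer is full attention
    kv_cache_spec={name: spec for name in layer_names})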


@ -1,13 +1,20 @@
import multiprocessing
import os
import weakref
from collections import defaultdict
from collections.abc import Sequence
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
Union, overload)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List,
Optional, TypeVar, Union, overload)
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.utils import get_mp_context, kill_process_tree
if TYPE_CHECKING:
from vllm.attention.layer import Attention
logger = init_logger(__name__)
T = TypeVar("T")
@ -134,3 +141,48 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)
def bind_kv_cache(
kv_caches: Dict[str, torch.Tensor],
forward_context: Dict[str, "Attention"],
runner_kv_caches: List[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is an encoder-decoder model, e.g., BART.
# The cross attention and self attention in the same decoder layer
# have different layer_names but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]


@ -7,6 +7,8 @@ import torch
import torch.distributed
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
@ -16,14 +18,16 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, bind_kv_cache, cdiv,
is_pin_memory_available)
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@ -856,15 +860,71 @@ class GPUModelRunner:
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def initialize_kv_cache(self, num_blocks: int) -> None:
assert len(self.kv_caches) == 0
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
for _ in range(self.num_attn_layers):
self.kv_caches.append(
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: Dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
[self.kv_caches])
self.kv_caches)
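Note: the per-layer allocation arithmetic inside the loop above, with assumed numbers; the actual tensor shape comes from FlashAttentionBackend.get_kv_cache_shape and is not repeated here:

tensor_size = 640 * 1024 * 1024             # kv_cache_config.tensors[layer_name].size (assumed)
page_size_bytes = 64 * 1024                 # layer_spec.page_size_bytes (assumed)
assert tensor_size % page_size_bytes == 0   # enforced by the assert in the hunk above
num_blocks = tensor_size // page_size_bytes # 10240 blocks back this layer's tensor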
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec


@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
@ -16,6 +16,7 @@ from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@ -112,20 +113,18 @@ class Worker:
self.model_runner.load_model()
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Then, it calculates the free memory that can be used for KV cache in
bytes.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
@ -161,33 +160,14 @@ class Worker:
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
cache_block_size = _get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
return num_gpu_blocks, 0
return int(available_kv_cache_memory)
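Note: the returned value is simple bookkeeping over the profiling run; the device size and peak usage below are assumed:

total_gpu_memory = 80 * 1024**3       # e.g. an 80 GiB device (assumed)
gpu_memory_utilization = 0.9          # self.cache_config.gpu_memory_utilization
peak_memory = 30 * 1024**3            # non-KV-cache memory measured during profiling (assumed)
available_kv_cache_memory = (total_gpu_memory * gpu_memory_utilization
                             - peak_memory)   # 42 GiB, returned as an int in bytes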
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks."""
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()
max_seq_len = self.cache_config.block_size * num_gpu_blocks
max_model_len = self.model_config.max_model_len
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
self.model_runner.initialize_kv_cache(num_gpu_blocks)
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager: