Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 08:04:58 +08:00)

[V1] Move more control of kv cache initialization from model_executor to EngineCore (#11960)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>

parent 8027a72461
commit 69d765f5a5
tests/v1/test_utils.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from typing import List

import torch

from vllm.v1.utils import bind_kv_cache


def test_bind_kv_cache():
    from vllm.attention import Attention

    ctx = {
        'layers.0.self_attn': Attention(32, 128, 0.1),
        'layers.1.self_attn': Attention(32, 128, 0.1),
        'layers.2.self_attn': Attention(32, 128, 0.1),
        'layers.3.self_attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'layers.0.self_attn': torch.zeros((1, )),
        'layers.1.self_attn': torch.zeros((1, )),
        'layers.2.self_attn': torch.zeros((1, )),
        'layers.3.self_attn': torch.zeros((1, )),
    }
    runner_kv_caches: List[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
        'layers.0.self_attn']
    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
        'layers.1.self_attn']
    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
        'layers.2.self_attn']
    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
        'layers.3.self_attn']

    assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
    assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
    assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
    assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']


def test_bind_kv_cache_non_attention():
    from vllm.attention import Attention

    # example from Jamba PP=2
    ctx = {
        'model.layers.20.attn': Attention(32, 128, 0.1),
        'model.layers.28.attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'model.layers.20.attn': torch.zeros((1, )),
        'model.layers.28.attn': torch.zeros((1, )),
    }

    runner_kv_caches: List[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
        'model.layers.20.attn']
    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
        'model.layers.28.attn']

    assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
    assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
@@ -101,7 +101,9 @@ class Attention(nn.Module):
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
@@ -3,7 +3,10 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
                                        KVCacheTensor)
from vllm.v1.request import Request

logger = init_logger(__name__)
@@ -305,3 +308,124 @@ def hash_request_tokens(block_size: int,
        ret.append(block_hash)
        parent_block_hash_value = block_hash.hash_value
    return ret


def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                                 kv_cache_spec: KVCacheSpec,
                                 available_memory: int):
    """
    Checks whether `available_memory` is enough for the KV cache to hold at
    least one request with the model's max_model_len.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory: Memory available for KV cache in bytes.

    Raises:
        ValueError: If there is not enough memory available for the KV cache.
    """

    if available_memory <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    max_model_len = vllm_config.model_config.max_model_len
    needed_memory = 0
    for layer_spec in kv_cache_spec.values():
        needed_memory += layer_spec.bytes_for_tokens(max_model_len)

    if needed_memory > available_memory:
        raise ValueError(
            f"To serve at least one request with the models's max seq len "
            f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV "
            f"cache is needed, which is larger than the available KV cache "
            f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
            f"increasing `gpu_memory_utilization` or decreasing "
            f"`max_model_len` when initializing the engine.")


def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
    """
    Whether all layers in the given KVCacheSpec have the same type of KV cache.

    Args:
        kv_cache_spec: The KVCacheSpec of the model

    Returns:
        True if all layers have the same type, False otherwise.
    """

    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
    return len(layer_keys) == 1


def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                      kv_cache_spec: KVCacheSpec,
                                      available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model with one type of KV cache.
    Divide the available memory equally among all layers.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """

    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
    assert len(page_sizes) == 1
    page_size = page_sizes.pop()

    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
    num_blocks = max(num_blocks, 0)

    if vllm_config.cache_config.num_gpu_blocks_override is not None:
        num_gpu_blocks_override = \
            vllm_config.cache_config.num_gpu_blocks_override
        logger.info(
            "Overriding num_gpu_blocks=%d with "
            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
        num_blocks = num_gpu_blocks_override

    logger.info("# GPU blocks: %d", num_blocks)

    per_layer_size = page_size * num_blocks

    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        tensors={
            layer_name: KVCacheTensor(size=per_layer_size)
            for layer_name in kv_cache_spec
        },
        groups=[[layer_name for layer_name in kv_cache_spec]],
        kv_cache_spec=kv_cache_spec)
    return kv_cache_config


def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
                        available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model
    TODO: support hybrid models with more than one type of KV cache.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
    if is_kv_cache_type_uniform(kv_cache_spec):
        # KV cache of all layers are the same, which is true for most models.
        # Allocate the same amount of memory for each layer.
        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
                                                 available_memory)
    else:
        raise NotImplementedError
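As a rough illustration of the uniform split above (all numbers below are invented for the example, not taken from the commit), the block count is just the profiled free memory divided by the per-block page size and by the number of layers:

# Illustrative arithmetic only; mirrors _get_kv_cache_config_uniform_type.
available_memory = 8 * 1024**3      # assume 8 GiB left after profiling
page_size = 2 * 16 * 8 * 128 * 2    # K+V, block_size=16, 8 kv heads, head_size=128, fp16 -> 65536 B
num_layers = 32
num_blocks = available_memory // page_size // num_layers   # 4096 blocks
per_layer_size = page_size * num_blocks                     # 256 MiB reserved per layer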
@@ -11,11 +11,12 @@ import zmq
import zmq.asyncio
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                            EngineCoreRequest, EngineCoreRequestType,
@@ -49,7 +50,7 @@ class EngineCore:

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
            vllm_config.cache_config)
            vllm_config)
        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
@@ -65,21 +66,25 @@ class EngineCore:
            vllm_config.model_config)

    def _initialize_kv_caches(self,
                              cache_config: CacheConfig) -> Tuple[int, int]:
                              vllm_config: VllmConfig) -> Tuple[int, int]:
        start = time.time()
        num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
        )

        if cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override
        # Get all kv cache needed by the model
        kv_cache_spec = self.model_executor.get_kv_cache_spec()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        availble_gpu_memory = self.model_executor.determine_available_memory()

        # Get the kv cache tensor size
        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                              availble_gpu_memory)
        num_gpu_blocks = kv_cache_config.num_blocks
        num_cpu_blocks = 0
        self.model_executor.initialize(num_gpu_blocks)

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize(kv_cache_config)

        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Tuple, Type
from typing import Type

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
@@ -31,11 +32,15 @@ class Executor(ABC):
        raise NotImplementedError

    @abstractmethod
    def initialize(self, num_gpu_blocks: int) -> None:
    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
    def determine_available_memory(self) -> int:  # in bytes
        raise NotImplementedError

    @abstractmethod
    def get_kv_cache_spec(self) -> KVCacheSpec:
        raise NotImplementedError

    @abstractmethod
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
@@ -90,29 +91,33 @@ class MultiprocExecutor(Executor):
        for w in self.workers:
            w.worker_response_mq.wait_until_ready()

    def initialize(self, num_gpu_blocks: int) -> None:
    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")

    def determine_num_available_blocks(self) -> Tuple[int, int]:
    def determine_available_memory(self) -> int:
        """
        Determine the number of available KV blocks by invoking the
        Determine the available memory (in bytes) for KV cache by invoking the
        underlying worker.
        """
        num_blocks = self.collective_rpc("determine_num_available_blocks")
        memory_sizes = self.collective_rpc("determine_available_memory")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return min(memory_sizes)

        return num_gpu_blocks, num_cpu_blocks
    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Get all kv cache needed by the model by invoking the underlying worker.
        """
        kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
        assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
        return kv_cache_specs[0]

    def collective_rpc(self,
                       method: str,
@@ -10,6 +10,7 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
                                        initialize_ray_cluster, ray)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput

if ray is not None:
@@ -211,39 +212,40 @@ class RayExecutor(Executor):
            distributed_init_method=distributed_init_method,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
    def determine_available_memory(self) -> int:
        """
        Determine the number of available KV blocks.
        Determine the available GPU memory in bytes.

        This invokes `determine_num_available_blocks` on each worker and takes
        This invokes `determine_available_memory` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        memory_sizes = self._run_workers("determine_available_memory")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return min(memory_sizes)

        return num_gpu_blocks, num_cpu_blocks

    def initialize(self, num_gpu_blocks: int) -> None:
    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV cache in all workers.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self._run_workers("initialize_cache", num_gpu_blocks)
        self._run_workers("initialize_cache", kv_cache_config)
        self._run_workers("compile_or_warm_up_model")

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Get all kv cache needed by the model

        This invokes `get_kv_cache_spec` on each worker and asserts that
        they are identical. The KVCacheSpec is then returned.
        """
        kv_cache_specs = self._run_workers("get_kv_cache_spec")
        assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
        return kv_cache_specs[0]

    def _run_workers(
        self,
        method: str,
@@ -1,10 +1,11 @@
import os
from typing import Optional, Tuple
from typing import Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
@@ -49,20 +50,22 @@ class UniprocExecutor(Executor):
            distributed_init_method=distributed_init_method,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
    def determine_available_memory(self) -> int:
        """Determine the available memory (in bytes) for KV cache by invoking
        the underlying worker.
        """
        return self.worker.determine_num_available_blocks()
        return self.worker.determine_available_memory()

    def initialize(self, num_gpu_blocks: int) -> None:
    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Get all kv cache needed by the model by invoking the underlying
        worker.
        """
        return self.worker.get_kv_cache_spec()

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self.worker.initialize_cache(num_gpu_blocks)
        self.worker.initialize_cache(kv_cache_config)
        self.worker.compile_or_warm_up_model()

    def execute_model(
vllm/v1/kv_cache_interface.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from dataclasses import dataclass
from typing import Dict, List

import torch

from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size

logger = init_logger(__name__)


@dataclass
class KVCacheSpecBase:
    """
    A base class for specifying the KV cache format of one layer.
    """

    # number of tokens in a block
    block_size: int

    @property
    def type_id(self) -> str:
        """
        The type identifier of this KV cache.
        Return different strings for layers with different KV cache type (e.g.,
        different number of tokens like full attention vs sliding window
        attention, different KV cache size per token like layers with different
        number of heads)

        Returns:
            The type identifier of this KV cache.
        """
        raise NotImplementedError

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    def bytes_for_tokens(self, num_tokens: int) -> int:
        """
        The KV cache size for `num_tokens` tokens in bytes. Returns the real
        memory size after padding `num_tokens` to full blocks.

        Returns:
            The KV cache size
        """
        raise NotImplementedError


@dataclass
class FullAttentionSpec(KVCacheSpecBase):
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def type_id(self) -> str:
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"

    @property
    def page_size_bytes(self) -> int:
        return 2 * self.block_size * self.num_kv_heads * self.head_size \
            * get_dtype_size(self.dtype)

    def bytes_for_tokens(self, num_tokens: int) -> int:
        return cdiv(num_tokens, self.block_size) * self.page_size_bytes


KVCacheSpec = Dict[str, KVCacheSpecBase]


@dataclass
class KVCacheTensor:
    """
    A dataclass for specifying how the workers should initialize the KV cache
    for a layer. Only contains the size of KV cache for that layer for now. Will
    be extended to support multiple layers sharing the same memory pool.
    """
    size: int  # The size of KV cache Tensor in bytes


@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    """The number of KV cache blocks"""
    num_blocks: int
    """layer_name -> how to initialize KV cache for that layer"""
    tensors: Dict[str, KVCacheTensor]
    """
    A list of kv-cache groups. Each group includes a set of layers with
    the same kv-cache spec, and the total page_size of layers inside a group
    is same across all groups (as the KVCacheManager only supports allocating
    pages of the same size). For example:
    1. A model only uses full attention: one group with all layers in the model.
    2. (not implemented yet) A model with the same number of full attention
       layers and sliding window attention layers: two groups, one for full
       attention layers and one for sliding window attention layers.
    3. (not implemented yet) A model with 2 full attention layers and 4 sliding
       window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
    """
    groups: List[List[str]]
    """the KVCacheSpec of the model"""
    kv_cache_spec: KVCacheSpec
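A small worked example of the byte accounting above (hypothetical layer shape, chosen only to make the numbers concrete):

# Hypothetical layer: 16-token blocks, 8 KV heads, head_size 128, fp16.
spec = FullAttentionSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype=torch.float16)
spec.page_size_bytes          # 2 * 16 * 8 * 128 * 2 = 65536 bytes per block
spec.bytes_for_tokens(100)    # cdiv(100, 16) = 7 blocks -> 458752 bytes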
@@ -1,13 +1,20 @@
import multiprocessing
import os
import weakref
from collections import defaultdict
from collections.abc import Sequence
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
                    Union, overload)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List,
                    Optional, TypeVar, Union, overload)

import torch

from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.utils import get_mp_context, kill_process_tree

if TYPE_CHECKING:
    from vllm.attention.layer import Attention

logger = init_logger(__name__)

T = TypeVar("T")
@@ -134,3 +141,48 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
        socket_file = ipc_socket.replace("ipc://", "")
        if os and os.path.exists(socket_file):
            os.remove(socket_file)


def bind_kv_cache(
    kv_caches: Dict[str, torch.Tensor],
    forward_context: Dict[str, "Attention"],
    runner_kv_caches: List[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.
            raise NotImplementedError
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]
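For intuition, this is roughly what the binding does for the Jamba PP=2 case from the new test; `forward_context` stands for the dict of Attention modules (called `ctx` in the test), the tensor shapes are placeholders, and the dict order is deliberately reversed to show that `extract_layer_index` drives the ordering:

# Sketch based on test_bind_kv_cache_non_attention; shapes are placeholders.
kv_caches = {
    'model.layers.28.attn': torch.zeros((1, )),
    'model.layers.20.attn': torch.zeros((1, )),
}
runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
# runner_kv_caches == [kv_caches['model.layers.20.attn'],
#                      kv_caches['model.layers.28.attn']]  # sorted by layer index
# and forward_context['model.layers.20.attn'].kv_cache[0] is its tensor.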
@@ -7,6 +7,8 @@ import torch
import torch.distributed
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
@@ -16,14 +18,16 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, bind_kv_cache, cdiv,
                        is_pin_memory_available)
                        LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                   FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
@@ -856,15 +860,71 @@ class GPUModelRunner:
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_kv_cache(self, num_blocks: int) -> None:
        assert len(self.kv_caches) == 0
        kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        for _ in range(self.num_attn_layers):
            self.kv_caches.append(
                torch.zeros(kv_cache_shape,
                            dtype=self.kv_cache_dtype,
                            device=self.device))
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer
        """
        if len(kv_cache_config.groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: Dict[str, torch.Tensor] = {}

        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
            tensor_config = kv_cache_config.tensors[layer_name]
            assert tensor_config.size % layer_spec.page_size_bytes == 0
            num_blocks = tensor_config.size // layer_spec.page_size_bytes
            if isinstance(layer_spec, FullAttentionSpec):
                kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
                    num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
                    layer_spec.head_size)
                dtype = layer_spec.dtype
                kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                    dtype=dtype,
                                                    device=self.device)
            else:
                raise NotImplementedError

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            [self.kv_caches])
            self.kv_caches)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.
        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
                format. Layers that do not need KV cache are not included.
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: KVCacheSpec = {}
        for layer_name, attn_module in forward_ctx.items():
            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention, MLA.
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                )
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec
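For a plain decoder-only model, the spec returned by get_kv_cache_spec() is simply one FullAttentionSpec per decoder attention layer; a hypothetical two-layer example (layer names and sizes invented for illustration):

# Hypothetical result for a 2-layer decoder, 8 KV heads, head_size 128, fp16.
{
    'model.layers.0.self_attn.attn': FullAttentionSpec(
        block_size=16, num_kv_heads=8, head_size=128, dtype=torch.float16),
    'model.layers.1.self_attn.attn': FullAttentionSpec(
        block_size=16, num_kv_heads=8, head_size=128, dtype=torch.float16),
}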
@@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional

import torch
import torch.distributed
@@ -16,6 +16,7 @@ from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -112,20 +113,18 @@ class Worker:
        self.model_runner.load_model()

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculate the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.
        Then, it calculate the free memory that can be used for KV cache in
        bytes.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

@@ -161,33 +160,14 @@ class Worker:
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = _get_cache_block_size(self.cache_config,
                                                 self.model_config,
                                                 self.parallel_config)
        num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        return num_gpu_blocks, 0
        return int(available_kv_cache_memory)

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks."""
        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")
    def get_kv_cache_spec(self) -> KVCacheSpec:
        return self.model_runner.get_kv_cache_spec()

        max_seq_len = self.cache_config.block_size * num_gpu_blocks
        max_model_len = self.model_config.max_model_len
        if max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`gpu_memory_utilization` or decreasing `max_model_len` when "
                "initializing the engine.")

        self.model_runner.initialize_kv_cache(num_gpu_blocks)
    def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
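As a back-of-the-envelope check of the memory computation above (all figures invented for illustration):

# Illustrative numbers only.
total_gpu_memory = 80 * 1024**3            # 80 GiB card
gpu_memory_utilization = 0.9
peak_memory = 60 * 1024**3                 # profiled weights + activations
available_kv_cache_memory = (
    total_gpu_memory * gpu_memory_utilization - peak_memory)   # 12 GiB
# The worker now returns these bytes directly; converting them into blocks
# happens later in get_kv_cache_config on the EngineCore side.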