Print warnings/errors for large swap space (#123)

Woosuk Kwon 2023-05-23 18:22:26 -07:00 committed by GitHub
parent a283ec2eec
commit aedba6d5ec
3 changed files with 34 additions and 0 deletions


@@ -3,6 +3,11 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig
 
+from cacheflow.logger import init_logger
+from cacheflow.utils import get_cpu_memory
+
+logger = init_logger(__name__)
+
 _GiB = 1 << 30
@@ -73,11 +78,37 @@ class CacheConfig:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GiB
+        self._verify_args()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
         self.num_cpu_blocks = None
 
+    def _verify_args(self) -> None:
+        if self.gpu_memory_utilization > 1.0:
+            raise ValueError(
+                "GPU memory utilization must be less than 1.0. Got "
+                f"{self.gpu_memory_utilization}.")
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_cpu_memory = get_cpu_memory()
+        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
+        # group are in the same node. However, the GPUs may span multiple nodes.
+        num_gpus_per_node = parallel_config.tensor_parallel_size
+        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
+
+        msg = (
+            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
+            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
+            "allocated for the swap space.")
+        if cpu_memory_usage > 0.7 * total_cpu_memory:
+            raise ValueError("Too large swap space. " + msg)
+        elif cpu_memory_usage > 0.4 * total_cpu_memory:
+            logger.warn("Possibly too large swap space. " + msg)
 
 
 class ParallelConfig:
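To see how these thresholds behave, here is a minimal standalone sketch of the same check (the 128 GiB node size, the 16 GiB-per-GPU swap, and the swap_check name are hypothetical, not part of the commit):

_GiB = 1 << 30

def swap_check(swap_space_gib: int, tensor_parallel_size: int,
               total_cpu_memory: int = 128 * _GiB) -> str:
    # Every GPU in the tensor-parallel group gets its own swap region,
    # so CPU usage scales with the number of GPUs assumed per node.
    usage = swap_space_gib * _GiB * tensor_parallel_size
    if usage > 0.7 * total_cpu_memory:
        return "error"    # CacheConfig raises ValueError here
    if usage > 0.4 * total_cpu_memory:
        return "warning"  # CacheConfig logs a warning here
    return "ok"

print(swap_check(16, 2))  # 32 GiB  = 25%  -> ok
print(swap_check(16, 4))  # 64 GiB  = 50%  -> warning
print(swap_check(16, 8))  # 128 GiB = 100% -> error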


@@ -84,6 +84,7 @@ class LLMServer:
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
 
     def _init_cache(self) -> None:
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
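With this call added, an oversized swap space now aborts server construction up front instead of surfacing later during cache profiling. A minimal sketch of that behavior, using hypothetical stub classes in place of the real configs and server:

_GiB = 1 << 30

class StubParallelConfig:
    tensor_parallel_size = 8  # hypothetical 8-way tensor parallelism

class StubCacheConfig:
    swap_space_bytes = 16 * _GiB  # 16 GiB of swap per GPU

    def verify_with_parallel_config(self, parallel_config) -> None:
        total_cpu_memory = 128 * _GiB  # assume a 128 GiB node
        usage = self.swap_space_bytes * parallel_config.tensor_parallel_size
        if usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space.")

class StubServer:
    def __init__(self) -> None:
        # Mirrors LLMServer._verify_args: validation runs inside __init__.
        StubCacheConfig().verify_with_parallel_config(StubParallelConfig())

try:
    StubServer()
except ValueError as e:
    print(f"Server construction failed: {e}")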


@@ -24,8 +24,10 @@ class Counter:
 
 def get_gpu_memory(gpu: int = 0) -> int:
+    """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory
 
 
 def get_cpu_memory() -> int:
+    """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
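As a sanity check, both helpers can be exercised directly; the printed values are machine-dependent, and the GPU line assumes CUDA is available:

import psutil
import torch

# The same calls the helpers wrap.
print(f"CPU: {psutil.virtual_memory().total / (1 << 30):.2f} GiB")
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU 0: {props.total_memory / (1 << 30):.2f} GiB")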