From 8917782af6a5892a0afb697badbe761dc8558619 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 May 2023 23:03:35 -0700 Subject: [PATCH] Add a system logger (#85) --- cacheflow/logger.py | 51 +++++++++++++++++++++++++++++ cacheflow/master/server.py | 19 +++++++++-- cacheflow/master/simple_frontend.py | 8 +++-- cacheflow/models/memory_analyzer.py | 18 ++++++---- simple_server.py | 2 +- 5 files changed, 85 insertions(+), 13 deletions(-) create mode 100644 cacheflow/logger.py diff --git a/cacheflow/logger.py b/cacheflow/logger.py new file mode 100644 index 000000000000..30d35b34ef55 --- /dev/null +++ b/cacheflow/logger.py @@ -0,0 +1,51 @@ +# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py + +import logging +import sys + + +_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" +_DATE_FORMAT = "%m-%d %H:%M:%S" + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None): + logging.Formatter.__init__(self, fmt, datefmt) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg + + +_root_logger = logging.getLogger("cacheflow") +_default_handler = None + + +def _setup_logger(): + _root_logger.setLevel(logging.DEBUG) + global _default_handler + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.flush = sys.stdout.flush # type: ignore + _default_handler.setLevel(logging.INFO) + _root_logger.addHandler(_default_handler) + fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _default_handler.setFormatter(fmt) + # Setting this will avoid the message + # being propagated to the parent logger. + _root_logger.propagate = False + + +# The logger is initialized when the module is imported. 
+# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. +_setup_logger() + + +def init_logger(name: str): + return logging.getLogger(name) diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index b0296513e9b1..694422e8c5ee 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -8,6 +8,7 @@ try: except ImportError: ray = None +from cacheflow.logger import init_logger from cacheflow.master.scheduler import Scheduler from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.models import get_memory_analyzer @@ -17,6 +18,9 @@ from cacheflow.sampling_params import SamplingParams from cacheflow.utils import get_gpu_memory, get_cpu_memory +logger = init_logger(__name__) + + class Server: def __init__( self, @@ -42,6 +46,17 @@ class Server: collect_stats: bool = False, do_memory_analysis: bool = False, ): + logger.info( + "Initializing a server with config: " + f"model={model!r}, " + f"dtype={dtype}, " + f"use_dummy_weights={use_dummy_weights}, " + f"cache_dir={cache_dir}, " + f"use_np_cache={use_np_cache}, " + f"tensor_parallel_size={tensor_parallel_size}, " + f"block_size={block_size}, " + f"seed={seed})" + ) self.num_nodes = num_nodes self.num_devices_per_node = num_devices_per_node self.world_size = pipeline_parallel_size * tensor_parallel_size @@ -61,9 +76,7 @@ class Server: self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks( max_num_batched_tokens=max_num_batched_tokens) self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks( - swap_space=swap_space) - print(f'# GPU blocks: {self.num_gpu_blocks}, ' - f'# CPU blocks: {self.num_cpu_blocks}') + swap_space_gib=swap_space) # Create a controller for each pipeline stage. 
self.controllers: List[Controller] = [] diff --git a/cacheflow/master/simple_frontend.py b/cacheflow/master/simple_frontend.py index f8396269874f..eca81c9a167e 100644 --- a/cacheflow/master/simple_frontend.py +++ b/cacheflow/master/simple_frontend.py @@ -1,13 +1,17 @@ import time -from typing import List, Optional, Set, Tuple +from typing import List, Optional, Tuple from transformers import AutoTokenizer +from cacheflow.logger import init_logger from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence, SequenceGroup from cacheflow.utils import Counter +logger = init_logger(__name__) + + class SimpleFrontend: def __init__( @@ -66,4 +70,4 @@ class SimpleFrontend: token_ids = seq.get_token_ids() output = self.tokenizer.decode(token_ids, skip_special_tokens=True) output = output.strip() - print(f'Seq {seq.seq_id}: {output!r}') + logger.info(f"Seq {seq.seq_id}: {output!r}") diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 0adc2e79a624..41f76df0b2ff 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -1,8 +1,12 @@ import torch from transformers import AutoConfig +from cacheflow.logger import init_logger from cacheflow.models.utils import get_dtype_size + +logger = init_logger(__name__) + _GiB = 1 << 30 @@ -23,20 +27,20 @@ class CacheFlowMemoryAnalyzer: def get_max_num_cpu_blocks( self, - swap_space: int, + swap_space_gib: int, ) -> int: - swap_space = swap_space * _GiB + swap_space = swap_space_gib * _GiB cpu_memory = self.cpu_memory if swap_space > 0.8 * cpu_memory: - raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) ' + raise ValueError(f'The swap space ({swap_space_gib:.2f} GiB) ' 'takes more than 80% of the available memory ' f'({cpu_memory / _GiB:.2f} GiB).' 
'Please check the swap space size.') if swap_space > 0.5 * cpu_memory: - print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) ' - 'takes more than 50% of the available memory ' - f'({cpu_memory / _GiB:.2f} GiB).' - 'This may slow the system performance.') + logger.warning(f'The swap space ({swap_space_gib:.2f} GiB) ' + 'takes more than 50% of the available memory ' + f'({cpu_memory / _GiB:.2f} GiB). ' + 'This may slow the system performance.') max_num_blocks = swap_space // self.get_cache_block_size() return max_num_blocks diff --git a/simple_server.py b/simple_server.py index 4df46dc16226..7b46f938c47b 100644 --- a/simple_server.py +++ b/simple_server.py @@ -1,11 +1,11 @@ import argparse -from typing import List from cacheflow.master.server import ( add_server_arguments, process_server_arguments, init_local_server_and_frontend_with_arguments) from cacheflow.sampling_params import SamplingParams + def main(args: argparse.Namespace): server, frontend = init_local_server_and_frontend_with_arguments(args) # Test the following inputs.