[Log] Optimize Startup Log (#26740)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent d95d0f4b98
commit 52efc34ebf
@@ -245,10 +245,14 @@ class CompilerManager:
         if graph_index == 0:
             # adds some info logging for the first graph
             if runtime_shape is None:
-                logger.info("Cache the graph for dynamic shape for later use")
+                logger.info_once(
+                    "Cache the graph for dynamic shape for later use", scope="local"
+                )
             else:
-                logger.info(
-                    "Cache the graph of shape %s for later use", str(runtime_shape)
+                logger.info_once(
+                    "Cache the graph of shape %s for later use",
+                    str(runtime_shape),
+                    scope="local",
                 )
         if runtime_shape is None:
             logger.debug(

@@ -272,12 +276,17 @@ class CompilerManager:
         elapsed = now - compilation_start_time
         compilation_config.compilation_time += elapsed
         if runtime_shape is None:
-            logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
+            logger.info_once(
+                "Compiling a graph for dynamic shape takes %.2f s",
+                elapsed,
+                scope="local",
+            )
         else:
-            logger.info(
+            logger.info_once(
                 "Compiling a graph for shape %s takes %.2f s",
                 runtime_shape,
                 elapsed,
+                scope="local",
             )

         return compiled_graph

@@ -604,10 +613,12 @@ class VllmBackend:
         disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

         if disable_cache:
-            logger.info("vLLM's torch.compile cache is disabled.")
+            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
         else:
-            logger.info(
-                "Using cache directory: %s for vLLM's torch.compile", local_cache_dir
+            logger.info_once(
+                "Using cache directory: %s for vLLM's torch.compile",
+                local_cache_dir,
+                scope="local",
             )

         self.compiler_manager.initialize_cache(

@@ -620,7 +631,9 @@ class VllmBackend:
         from .monitor import torch_compile_start_time

         dynamo_time = time.time() - torch_compile_start_time
-        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
+        logger.info_once(
+            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
+        )
        self.compilation_config.compilation_time += dynamo_time

        # we control the compilation process, each instance can only be

@@ -672,7 +685,9 @@ class VllmBackend:
             with open(graph_path, "w") as f:
                 f.write(src)

-            logger.debug("Computation graph saved to %s", graph_path)
+            logger.debug_once(
+                "Computation graph saved to %s", graph_path, scope="local"
+            )

         self._called = True

@@ -31,8 +31,10 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
 def end_monitoring_torch_compile(vllm_config: VllmConfig):
     compilation_config: CompilationConfig = vllm_config.compilation_config
     if compilation_config.mode == CompilationMode.VLLM_COMPILE:
-        logger.info(
-            "torch.compile takes %.2f s in total", compilation_config.compilation_time
+        logger.info_once(
+            "torch.compile takes %.2f s in total",
+            compilation_config.compilation_time,
+            scope="local",
         )
     global context_manager
     if context_manager is not None:

@@ -13,7 +13,6 @@ from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric
 from vllm.distributed.device_communicators.pynccl_allocator import (
     is_symmetric_memory_enabled,
 )
-from vllm.distributed.parallel_state import is_global_first_rank
 from vllm.logger import init_logger
 from vllm.platforms import current_platform

@@ -118,11 +117,11 @@ class CudaCommunicator(DeviceCommunicatorBase):
         else:
             raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")

-        if is_global_first_rank():
-            logger.info(
-                "Using %s all2all manager.",
-                self.all2all_manager.__class__.__name__,
-            )
+        logger.info_once(
+            "Using %s all2all manager.",
+            self.all2all_manager.__class__.__name__,
+            scope="global",
+        )

     def all_reduce(self, input_):
         # since currently we perform copy input -> symm_input -> out-of-place AR

@@ -34,7 +34,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
         if i == rank:
             continue
         if envs.VLLM_SKIP_P2P_CHECK:
-            logger.info("Skipping P2P check and trusting the driver's P2P report.")
+            logger.debug("Skipping P2P check and trusting the driver's P2P report.")
             return torch.cuda.can_device_access_peer(rank, i)
         if not gpu_p2p_access_check(rank, i):
             return False

@@ -108,7 +108,9 @@ class PyNcclCommunicator:
         if self.rank == 0:
             # get the unique id from NCCL
             self.unique_id = self.nccl.ncclGetUniqueId()
-            logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
+            logger.info_once(
+                "vLLM is using nccl==%s", self.nccl.ncclGetVersion(), scope="local"
+            )
         else:
             # construct an empty unique id
             self.unique_id = ncclUniqueId()

@@ -312,7 +312,7 @@ class MessageQueue:
             remote_addr_ipv6=remote_addr_ipv6,
         )

-        logger.info("vLLM message queue communication handle: %s", self.handle)
+        logger.debug("vLLM message queue communication handle: %s", self.handle)

     def export_handle(self) -> Handle:
         return self.handle

@@ -1157,7 +1157,7 @@ def init_distributed_environment(
         ip = parallel_config.data_parallel_master_ip
         port = parallel_config.get_next_dp_init_port()
         distributed_init_method = get_distributed_init_method(ip, port)
-        logger.info(
+        logger.debug(
             "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
             world_size,
             rank,

@@ -1322,7 +1322,7 @@ def initialize_model_parallel(
         group_ranks, get_world_group().local_rank, backend, group_name="ep"
     )

-    logger.info(
+    logger.info_once(
         "rank %s in world size %s is assigned as "
         "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
         rank,

@@ -1625,6 +1625,29 @@ def is_global_first_rank() -> bool:
     return True


+def is_local_first_rank() -> bool:
+    """
+    Check if the current process is the first local rank (rank 0 on its node).
+    """
+    try:
+        # prefer the initialized world group if available
+        global _WORLD
+        if _WORLD is not None:
+            return _WORLD.local_rank == 0
+
+        if not torch.distributed.is_initialized():
+            return True
+
+        # fallback to environment-provided local rank if available
+        # note: envs.LOCAL_RANK is set when using env:// launchers (e.g., torchrun)
+        try:
+            return int(envs.LOCAL_RANK) == 0  # type: ignore[arg-type]
+        except Exception:
+            return torch.distributed.get_rank() == 0
+    except Exception:
+        return True
+
+
 def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
     """
     Returns the total number of nodes in the process group.
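
For orientation, a minimal usage sketch of how the new helper pairs with the existing is_global_first_rank(); this is not part of the diff, and report_startup_banner is a hypothetical name used only for illustration.

# Minimal usage sketch (not from this commit); assumes vLLM at this revision
# is importable. `report_startup_banner` is a hypothetical helper name.
from vllm.distributed.parallel_state import is_global_first_rank, is_local_first_rank


def report_startup_banner() -> None:
    if is_global_first_rank():
        # True on the first rank of the whole deployment (and on every process
        # before torch.distributed is initialized, per the fallbacks above).
        print("global-first-rank banner")
    if is_local_first_rank():
        # True on local rank 0 of each node, per the helper added above.
        print("per-node banner")
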
@@ -13,7 +13,7 @@ from logging import Logger
 from logging.config import dictConfig
 from os import path
 from types import MethodType
-from typing import Any, cast
+from typing import Any, Literal, cast

 import vllm.envs as envs

@@ -59,20 +59,37 @@ DEFAULT_LOGGING_CONFIG = {

 @lru_cache
 def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
-    # Set the stacklevel to 2 to print the original caller's line info
-    logger.debug(msg, *args, stacklevel=2)
+    # Set the stacklevel to 3 to print the original caller's line info
+    logger.debug(msg, *args, stacklevel=3)


 @lru_cache
 def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
-    # Set the stacklevel to 2 to print the original caller's line info
-    logger.info(msg, *args, stacklevel=2)
+    # Set the stacklevel to 3 to print the original caller's line info
+    logger.info(msg, *args, stacklevel=3)


 @lru_cache
 def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
-    # Set the stacklevel to 2 to print the original caller's line info
-    logger.warning(msg, *args, stacklevel=2)
+    # Set the stacklevel to 3 to print the original caller's line info
+    logger.warning(msg, *args, stacklevel=3)


+LogScope = Literal["process", "global", "local"]
+
+
+def _should_log_with_scope(scope: LogScope) -> bool:
+    """Decide whether to log based on scope"""
+    if scope == "global":
+        from vllm.distributed.parallel_state import is_global_first_rank
+
+        return is_global_first_rank()
+    if scope == "local":
+        from vllm.distributed.parallel_state import is_local_first_rank
+
+        return is_local_first_rank()
+    # default "process" scope: always log
+    return True
+
+
 class _VllmLogger(Logger):

@@ -84,33 +101,43 @@ class _VllmLogger(Logger):
     `intel_extension_for_pytorch.utils._logger`.
     """

-    def debug_once(self, msg: str, *args: Hashable) -> None:
+    def debug_once(
+        self, msg: str, *args: Hashable, scope: LogScope = "process"
+    ) -> None:
         """
         As [`debug`][logging.Logger.debug], but subsequent calls with
         the same message are silently dropped.
         """
+        if not _should_log_with_scope(scope):
+            return
         _print_debug_once(self, msg, *args)

-    def info_once(self, msg: str, *args: Hashable) -> None:
+    def info_once(self, msg: str, *args: Hashable, scope: LogScope = "process") -> None:
         """
         As [`info`][logging.Logger.info], but subsequent calls with
         the same message are silently dropped.
         """
+        if not _should_log_with_scope(scope):
+            return
         _print_info_once(self, msg, *args)

-    def warning_once(self, msg: str, *args: Hashable) -> None:
+    def warning_once(
+        self, msg: str, *args: Hashable, scope: LogScope = "process"
+    ) -> None:
         """
         As [`warning`][logging.Logger.warning], but subsequent calls with
         the same message are silently dropped.
         """
+        if not _should_log_with_scope(scope):
+            return
         _print_warning_once(self, msg, *args)


 # Pre-defined methods mapping to avoid repeated dictionary creation
 _METHODS_TO_PATCH = {
-    "debug_once": _print_debug_once,
-    "info_once": _print_info_once,
-    "warning_once": _print_warning_once,
+    "debug_once": _VllmLogger.debug_once,
+    "info_once": _VllmLogger.info_once,
+    "warning_once": _VllmLogger.warning_once,
 }

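For context, a minimal usage sketch of the scoped *_once helpers introduced above; the messages are placeholders, not log lines from vLLM, and the sketch assumes vLLM at this revision is importable.

# Minimal usage sketch of scoped once-logging (placeholder messages).
from vllm.logger import init_logger

logger = init_logger(__name__)

# Default scope="process": deduplicated per process, emitted by every process.
logger.info_once("startup step finished")

# scope="local": only the first rank on each node emits the message.
logger.info_once("per-node startup summary", scope="local")

# scope="global": only the global first rank of the deployment emits it.
logger.warning_once("cluster-wide startup notice", scope="global")
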
@@ -368,11 +368,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 logger.info_once(
                     "FlashInfer CUTLASS MoE is available for EP"
                     " but not enabled, consider setting"
-                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it."
+                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
+                    scope="local",
                 )
             elif self.moe.moe_parallel_config.dp_size > 1:
                 logger.info_once(
-                    "FlashInfer CUTLASS MoE is currently not available for DP."
+                    "FlashInfer CUTLASS MoE is currently not available for DP.",
+                    scope="local",
                 )
                 self.flashinfer_cutlass_moe = None  # type: ignore

@@ -311,9 +311,10 @@ class DefaultModelLoader(BaseModelLoader):
             loaded_weights = load_weights_and_online_quantize(self, model, model_config)

         self.counter_after_loading_weights = time.perf_counter()
-        logger.info(
+        logger.info_once(
             "Loading weights took %.2f seconds",
             self.counter_after_loading_weights - self.counter_before_loading_weights,
+            scope="local",
         )
         # We only enable strict check for non-quantized models
         # that have loaded weights tracking currently.

@@ -416,7 +416,7 @@ def download_weights_from_hf(
            e,
        )

-    logger.info("Using model weights format %s", allow_patterns)
+    logger.debug("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):

@@ -222,10 +222,12 @@ def resolve_current_platform_cls_qualname() -> str:
         )
     elif len(activated_builtin_plugins) == 1:
         platform_cls_qualname = builtin_platform_plugins[activated_builtin_plugins[0]]()
-        logger.info("Automatically detected platform %s.", activated_builtin_plugins[0])
+        logger.debug(
+            "Automatically detected platform %s.", activated_builtin_plugins[0]
+        )
     else:
         platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform"
-        logger.info("No platform detected, vLLM is running on UnspecifiedPlatform")
+        logger.debug("No platform detected, vLLM is running on UnspecifiedPlatform")
     return platform_cls_qualname

@@ -298,7 +298,9 @@ class CudaPlatformBase(Platform):
             )

         if use_cutlassmla:
-            logger.info_once("Using Cutlass MLA backend on V1 engine.")
+            logger.info_once(
+                "Using Cutlass MLA backend on V1 engine.", scope="local"
+            )
             return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
         if use_flashinfermla:
             from vllm.v1.attention.backends.utils import set_kv_cache_layout

@@ -37,7 +37,7 @@ class GCDebugConfig:
         except Exception:
             self.enabled = False
             logger.error("Failed to parse VLLM_GC_DEBUG(%s)", envs.VLLM_GC_DEBUG)
-        logger.info("GC Debug Config. %s", str(self))
+        logger.debug("GC Debug Config. %s", str(self))

     def __repr__(self) -> str:
         return f"enabled:{self.enabled},top_objects:{self.top_objects}"

@@ -1226,7 +1226,7 @@ def _report_kv_cache_config(
         vllm_config.parallel_config.decode_context_parallel_size,
     )
     num_tokens_str = f"{num_tokens:,}"
-    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
     max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
     max_concurrency = get_max_concurrency_for_kv_cache_config(
         vllm_config, kv_cache_config

@@ -19,7 +19,6 @@ import zmq

 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
-from vllm.distributed.parallel_state import is_global_first_rank
 from vllm.envs import enable_envs_cache
 from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception

@@ -90,7 +89,7 @@ class EngineCore:
         load_general_plugins()

         self.vllm_config = vllm_config
-        if is_global_first_rank():
+        if vllm_config.parallel_config.data_parallel_rank == 0:
             logger.info(
                 "Initializing a V1 LLM engine (v%s) with config: %s",
                 VLLM_VERSION,

@@ -235,9 +234,10 @@ class EngineCore:
         self.model_executor.initialize_from_config(kv_cache_configs)

         elapsed = time.time() - start
-        logger.info(
+        logger.info_once(
             ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
             elapsed,
+            scope="local",
         )
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

@@ -713,7 +713,7 @@ class EngineCoreProc(EngineCore):
         )

         # Receive initialization message.
-        logger.info("Waiting for init message from front-end.")
+        logger.debug("Waiting for init message from front-end.")
         if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
             raise RuntimeError(
                 "Did not receive response from front-end "

@@ -215,7 +215,7 @@ class LoggingStatLogger(StatLoggerBase):

     def log_engine_initialized(self):
         if self.vllm_config.cache_config.num_gpu_blocks:
-            logger.info(
+            logger.debug(
                 "Engine %03d: vllm cache_config_info with initialization "
                 "after num_gpu_blocks is: %d",
                 self.engine_index,

@@ -33,7 +33,10 @@ class TopKTopPSampler(nn.Module):
     ):
         if envs.VLLM_USE_FLASHINFER_SAMPLER:
             # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
-            logger.info_once("Using FlashInfer for top-p & top-k sampling.")
+            logger.info_once(
+                "Using FlashInfer for top-p & top-k sampling.",
+                scope="global",
+            )
             self.forward = self.forward_cuda
         else:
             logger.debug_once(

@@ -6,7 +6,7 @@ import torch
 import torch.distributed as dist

 from vllm.config import ParallelConfig
-from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,

@@ -132,12 +132,12 @@ def _synchronize_dp_ranks(
     should_ubatch = _post_process_ubatch(tensor)

     if should_ubatch and not should_dp_pad:
-        if is_global_first_rank():
-            logger.debug(
-                "Microbatching has been triggered and requires DP padding. "
-                "Enabling DP padding even though it has been explicitly "
-                "disabled."
-            )
+        logger.debug_once(
+            "Microbatching has been triggered and requires DP padding. "
+            "Enabling DP padding even though it has been explicitly "
+            "disabled.",
+            scope="global",
+        )
         should_dp_pad = True

     # Pad all DP ranks up to the maximum token count across ranks if

@@ -2850,7 +2850,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         Args:
             eep_scale_up: the model loading is for elastic EP scale up.
         """
-        logger.info("Starting to load model %s...", self.model_config.model)
+        logger.info_once(
+            "Starting to load model %s...",
+            self.model_config.model,
+            scope="global",
+        )
         if eep_scale_up:
             from vllm.distributed.parallel_state import get_ep_group

@@ -2911,10 +2915,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.model.set_aux_hidden_state_layers(aux_layers)
         time_after_load = time.perf_counter()
         self.model_memory_usage = m.consumed_memory
-        logger.info(
+        logger.info_once(
             "Model loading took %.4f GiB and %.6f seconds",
             self.model_memory_usage / GiB_bytes,
             time_after_load - time_before_load,
+            scope="local",
         )
         prepare_communication_buffer_for_model(self.model)

@@ -3838,10 +3843,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         elapsed_time = end_time - start_time
         cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
         # This usually takes 5~20 seconds.
-        logger.info(
+        logger.info_once(
             "Graph capturing finished in %.0f secs, took %.2f GiB",
             elapsed_time,
             cuda_graph_size / (1 << 30),
+            scope="local",
         )
         return cuda_graph_size

@@ -20,7 +20,10 @@ from vllm.distributed import (
     set_custom_all_reduce,
 )
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
-from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+from vllm.distributed.parallel_state import (
+    get_pp_group,
+    get_tp_group,
+)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed

@@ -334,9 +337,10 @@ class Worker(WorkerBase):
             GiB(free_gpu_memory - unrequested_memory),
         )
         logger.debug(profile_result)
-        logger.info(
+        logger.info_once(
             "Available KV cache memory: %.2f GiB",
             GiB(self.available_kv_cache_memory_bytes),
+            scope="local",
         )
         gc.collect()