From 43742731381348c2b84fee029b69422597180784 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 16 Nov 2024 22:18:12 +0200
Subject: [PATCH] possible compile fixes

---
 model_loading.py | 36 ++++++++++++++++++++++++++++++++++--
 nodes.py         | 13 +++++++++++--
 utils.py         | 12 ++++++++++--
 3 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/model_loading.py b/model_loading.py
index 387de3f..d1482d3 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -1,9 +1,41 @@
 import os
-import torch
-import torch.nn as nn
 import json
 import folder_paths
 import comfy.model_management as mm
+from typing import Union
+
+def patched_write_atomic(
+    path_: str,
+    content: Union[str, bytes],
+    make_dirs: bool = False,
+    encode_utf_8: bool = False,
+) -> None:
+    # Write into a temporary file first to avoid conflicts between threads.
+    # Avoid using a named temporary file, as those have restricted permissions.
+    from pathlib import Path
+    import os
+    import shutil
+    import threading
+    assert isinstance(
+        content, (str, bytes)
+    ), "Only strings and byte arrays can be saved in the cache"
+    path = Path(path_)
+    if make_dirs:
+        path.parent.mkdir(parents=True, exist_ok=True)
+    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
+    write_mode = "w" if isinstance(content, str) else "wb"
+    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
+        f.write(content)
+    shutil.copy2(src=tmp_path, dst=path)  # changed to allow overwriting existing cache files
+    os.remove(tmp_path)
+try:
+    import torch._inductor.codecache
+    torch._inductor.codecache.write_atomic = patched_write_atomic
+except Exception:
+    pass
+
+import torch
+import torch.nn as nn
 
 from diffusers.models import AutoencoderKLCogVideoX
 from diffusers.schedulers import CogVideoXDDIMScheduler
diff --git a/nodes.py b/nodes.py
index 6350e98..dd9589a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -5,7 +5,7 @@ import comfy.model_management as mm
 from einops import rearrange
 from contextlib import nullcontext
 
-from .utils import log, check_diffusers_version
+from .utils import log, check_diffusers_version, print_memory
 check_diffusers_version()
 from diffusers.schedulers import (
     CogVideoXDDIMScheduler,
@@ -864,6 +864,10 @@ class CogVideoSampler:
         # if sigmas is not None:
         #     sigma_list = sigmas.tolist()
 
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except Exception:
+            pass
         autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
         autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
 
@@ -898,8 +902,13 @@ class CogVideoSampler:
             if hasattr(block, "cached_hidden_states") and block.cached_hidden_states is not None:
                 block.cached_hidden_states = None
                 block.cached_encoder_hidden_states = None
-
+
+        print_memory(device)
         mm.soft_empty_cache()
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except Exception:
+            pass
 
         return (pipeline, {"samples": latents})
 
diff --git a/utils.py b/utils.py
index e3c6fd4..d667097 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,5 @@
 import importlib.metadata
-
+import torch
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 log = logging.getLogger(__name__)
@@ -19,4 +19,12 @@ def remove_specific_blocks(model, block_indices_to_remove):
     new_blocks = [block for i, block in enumerate(transformer_blocks) if i not in block_indices_to_remove]
     model.transformer_blocks = nn.ModuleList(new_blocks)
 
-    return model
\ No newline at end of file
+    return model
+
+def print_memory(device):
+    memory = torch.cuda.memory_allocated(device) / 1024**3
+    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
+    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
+    log.info(f"Allocated memory: {memory:.3f} GB")
+    log.info(f"Max allocated memory: {max_memory:.3f} GB")
+    log.info(f"Max reserved memory: {max_reserved:.3f} GB")
\ No newline at end of file
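
Two untested sketches of how these changes can be exercised; neither is part of the patch itself.

First, the intent of the write_atomic monkeypatch: the patch comment says the shutil.copy2 call was changed in "to allow overwriting cache files", i.e. a second write to the same inductor cache path should succeed rather than error. patched_write_atomic has no torch dependency, so pasting the function into a REPL alongside this hypothetical check is enough to see the behavior:

import os
import tempfile

with tempfile.TemporaryDirectory() as d:
    target = os.path.join(d, "cache", "kernel.py")
    patched_write_atomic(target, "first", make_dirs=True)   # creates cache/ and the file
    patched_write_atomic(target, "second")                  # second write overwrites without error
    with open(target) as f:
        assert f.read() == "second"

Second, the measurement pattern the nodes.py hunks wire into CogVideoSampler (reset the CUDA peak counters, run the work, log, reset again), shown in isolation. The standalone "from utils import print_memory" import and the matmul stand-in workload are assumptions for the sketch, and a CUDA device is assumed to be available:

import torch
from utils import print_memory  # helper added by this patch

device = torch.device("cuda")
torch.cuda.reset_peak_memory_stats(device)   # measure only what follows

x = torch.randn(2048, 2048, device=device)   # stand-in for the sampling call
y = x @ x

print_memory(device)                         # logs allocated / max allocated / max reserved, in GB
torch.cuda.reset_peak_memory_stats(device)   # clean slate for the next run

Resetting the peak counters on both sides of the sampling call is what makes the logged "max" numbers reflect that call alone rather than everything since process start.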