mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
possible compile fixes
This commit is contained in:
parent
f21432bea1
commit
4374273138
@ -1,9 +1,41 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import json
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
from typing import Union
|
||||
|
||||
def patched_write_atomic(
|
||||
path_: str,
|
||||
content: Union[str, bytes],
|
||||
make_dirs: bool = False,
|
||||
encode_utf_8: bool = False,
|
||||
) -> None:
|
||||
# Write into temporary file first to avoid conflicts between threads
|
||||
# Avoid using a named temporary file, as those have restricted permissions
|
||||
from pathlib import Path
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
assert isinstance(
|
||||
content, (str, bytes)
|
||||
), "Only strings and byte arrays can be saved in the cache"
|
||||
path = Path(path_)
|
||||
if make_dirs:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
|
||||
write_mode = "w" if isinstance(content, str) else "wb"
|
||||
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
|
||||
f.write(content)
|
||||
shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
|
||||
os.remove(tmp_path)
|
||||
try:
|
||||
import torch._inductor.codecache
|
||||
torch._inductor.codecache.write_atomic = patched_write_atomic
|
||||
except:
|
||||
pass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusers.models import AutoencoderKLCogVideoX
|
||||
from diffusers.schedulers import CogVideoXDDIMScheduler
|
||||
|
||||
11
nodes.py
11
nodes.py
@ -5,7 +5,7 @@ import comfy.model_management as mm
|
||||
from einops import rearrange
|
||||
from contextlib import nullcontext
|
||||
|
||||
from .utils import log, check_diffusers_version
|
||||
from .utils import log, check_diffusers_version, print_memory
|
||||
check_diffusers_version()
|
||||
from diffusers.schedulers import (
|
||||
CogVideoXDDIMScheduler,
|
||||
@ -864,6 +864,10 @@ class CogVideoSampler:
|
||||
|
||||
# if sigmas is not None:
|
||||
# sigma_list = sigmas.tolist()
|
||||
try:
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
except:
|
||||
pass
|
||||
|
||||
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
|
||||
autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
|
||||
@ -899,7 +903,12 @@ class CogVideoSampler:
|
||||
block.cached_hidden_states = None
|
||||
block.cached_encoder_hidden_states = None
|
||||
|
||||
print_memory(device)
|
||||
mm.soft_empty_cache()
|
||||
try:
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
except:
|
||||
pass
|
||||
|
||||
return (pipeline, {"samples": latents})
|
||||
|
||||
|
||||
10
utils.py
10
utils.py
@ -1,5 +1,5 @@
|
||||
import importlib.metadata
|
||||
|
||||
import torch
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
@ -20,3 +20,11 @@ def remove_specific_blocks(model, block_indices_to_remove):
|
||||
model.transformer_blocks = nn.ModuleList(new_blocks)
|
||||
|
||||
return model
|
||||
|
||||
def print_memory(device):
|
||||
memory = torch.cuda.memory_allocated(device) / 1024**3
|
||||
max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
|
||||
max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
|
||||
log.info(f"Allocated memory: {memory=:.3f} GB")
|
||||
log.info(f"Max allocated memory: {max_memory=:.3f} GB")
|
||||
log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
|
||||
Loading…
x
Reference in New Issue
Block a user