possible compile fixes

kijai 2024-11-16 22:18:12 +02:00
parent f21432bea1
commit 4374273138
3 changed files with 55 additions and 6 deletions

View File

@@ -1,9 +1,41 @@
import os
import torch
import torch.nn as nn
import json
import folder_paths
import comfy.model_management as mm
from typing import Union
def patched_write_atomic(
    path_: str,
    content: Union[str, bytes],
    make_dirs: bool = False,
    encode_utf_8: bool = False,
) -> None:
    # Write into temporary file first to avoid conflicts between threads
    # Avoid using a named temporary file, as those have restricted permissions
    from pathlib import Path
    import os
    import shutil
    import threading
    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"
    path = Path(path_)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
        f.write(content)
    shutil.copy2(src=tmp_path, dst=path)  # changed to allow overwriting cache files
    os.remove(tmp_path)

try:
    import torch._inductor.codecache
    torch._inductor.codecache.write_atomic = patched_write_atomic
except:
    pass
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKLCogVideoX
from diffusers.schedulers import CogVideoXDDIMScheduler
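The patch above mirrors the stock torch._inductor.codecache.write_atomic, except that the final rename of the temp file is replaced by shutil.copy2 plus os.remove, so an existing compiled-cache file can be overwritten (Path.rename onto an existing target raises FileExistsError on Windows). A minimal smoke test of the patched writer; the scratch path is hypothetical:

import torch._inductor.codecache as codecache  # importing applies the monkeypatch above

# Second write would fail with a plain rename on Windows if the file already exists
codecache.write_atomic("/tmp/inductor_overwrite_test.txt", "first", make_dirs=True)
codecache.write_atomic("/tmp/inductor_overwrite_test.txt", "second")
with open("/tmp/inductor_overwrite_test.txt") as f:
    assert f.read() == "second"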

View File

@@ -5,7 +5,7 @@ import comfy.model_management as mm
from einops import rearrange
from contextlib import nullcontext
-from .utils import log, check_diffusers_version
+from .utils import log, check_diffusers_version, print_memory
check_diffusers_version()
from diffusers.schedulers import (
    CogVideoXDDIMScheduler,
@@ -864,6 +864,10 @@ class CogVideoSampler:
        # if sigmas is not None:
        #     sigma_list = sigmas.tolist()
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
        autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
@@ -899,7 +903,12 @@ class CogVideoSampler:
            block.cached_hidden_states = None
            block.cached_encoder_hidden_states = None

        print_memory(device)
        mm.soft_empty_cache()
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        return (pipeline, {"samples": latents})
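The two sampler hunks bracket the denoising call with CUDA peak-memory bookkeeping: stats are reset before sampling, logged via print_memory afterwards, and reset again so the next run starts from a clean slate. A standalone sketch of that pattern, with run_denoise as a hypothetical stand-in for the pipeline call:

import torch
from .utils import print_memory  # assumes the repo package context; defined in utils.py below

def sample_with_memory_report(run_denoise, device):
    try:
        torch.cuda.reset_peak_memory_stats(device)  # guarded so non-CUDA setups don't error
    except Exception:
        pass
    latents = run_denoise()
    print_memory(device)  # logs allocated / max allocated / max reserved
    try:
        torch.cuda.reset_peak_memory_stats(device)  # clear the stats for the next sampling run
    except Exception:
        pass
    return latents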

View File

@@ -1,5 +1,5 @@
import importlib.metadata
import torch
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
@@ -20,3 +20,11 @@ def remove_specific_blocks(model, block_indices_to_remove):
    model.transformer_blocks = nn.ModuleList(new_blocks)
    return model
def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    log.info(f"Allocated memory: {memory=:.3f} GB")
    log.info(f"Max allocated memory: {max_memory=:.3f} GB")
    log.info(f"Max reserved memory: {max_reserved=:.3f} GB")