diff --git a/nodes/nodes.py b/nodes/nodes.py
index e951bd9..764d7dc 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -1,11 +1,10 @@
 import torch
 import numpy as np
 from PIL import Image
-
-import json, re, os, io, time
+from typing import Union
+import json, re, os, io, time, platform
 import re
 import importlib
-from contextlib import contextmanager
 
 import model_management
 import folder_paths
@@ -2300,6 +2299,8 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
                 print("NOT LOADED {}".format(x))
 
     if patch_keys:
+        compile_settings = getattr(model.model, "compile_settings")
+        print("compile_settings: ", compile_settings)
         for k in patch_keys:
             if "diffusion_model." in k:
                 # Remove the prefix to get the attribute path
@@ -2314,12 +2315,36 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
                     else:
                         block = getattr(block, attr)
 
                 # Compile the block
-                compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
+                compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
                 # Add the compiled block back as an object patch
                 model.add_object_patch(k, compiled_block)
     return (new_modelpatcher, new_clip)
-
+def patched_write_atomic(
+    path_: str,
+    content: Union[str, bytes],
+    make_dirs: bool = False,
+    encode_utf_8: bool = False,
+) -> None:
+    # Write into temporary file first to avoid conflicts between threads
+    # Avoid using a named temporary file, as those have restricted permissions
+    from pathlib import Path
+    import os
+    import shutil
+    import threading
+    assert isinstance(
+        content, (str, bytes)
+    ), "Only strings and byte arrays can be saved in the cache"
+    path = Path(path_)
+    if make_dirs:
+        path.parent.mkdir(parents=True, exist_ok=True)
+    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
+    write_mode = "w" if isinstance(content, str) else "wb"
+    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
+        f.write(content)
+    shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
+    os.remove(tmp_path)
+
 class PatchModelPatcherOrder:
     @classmethod
     def INPUT_TYPES(s):
@@ -2357,6 +2382,7 @@ class TorchCompileModelFluxAdvanced:
                 "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                 "double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
                 "single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
+                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
             }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
@@ -2375,7 +2401,13 @@ class TorchCompileModelFluxAdvanced:
                     blocks.append(int(part))
         return blocks
 
-    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
+    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
+        if platform.system() == 'Windows':
+            try:
+                import torch._inductor.codecache
+                torch._inductor.codecache.write_atomic = patched_write_atomic #temporary workaround for the cache write bug in Windows
+            except:
+                pass
         single_block_list = self.parse_blocks(single_blocks)
         double_block_list = self.parse_blocks(double_blocks)
         m = model.clone()
@@ -2386,12 +2418,19 @@ class TorchCompileModelFluxAdvanced:
             for i, block in enumerate(diffusion_model.double_blocks):
                 if i in double_block_list:
                     #print("Compiling double_block", i)
-                    m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
+                    m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
             for i, block in enumerate(diffusion_model.single_blocks):
                 if i in single_block_list:
                     #print("Compiling single block", i)
-                    m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
+                    m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
             self._compiled = True
+            compile_settings = {
+                "backend": backend,
+                "mode": mode,
+                "fullgraph": fullgraph,
+                "dynamic": dynamic,
+            }
+            setattr(m.model, "compile_settings", compile_settings)
         except:
             raise RuntimeError("Failed to compile model")
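Two of the changes above work as a pair: TorchCompileModelFluxAdvanced.patch now records the exact torch.compile arguments on m.model, and patched_load_lora_for_models reads them back so blocks recompiled after a LoRA load reuse the user's settings instead of the previously hard-coded mode="default", fullgraph=False, backend="inductor". A minimal standalone sketch of that round-trip (the nn.Sequential model and the settings values are illustrative, not taken from the node):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

# Producer side: stash the arguments used for the first compile on the model.
compile_settings = {"backend": "inductor", "mode": "default",
                    "fullgraph": False, "dynamic": False}
setattr(model, "compile_settings", compile_settings)

# Consumer side: read the stored settings back and compile with the same arguments.
settings = getattr(model, "compile_settings")
compiled = torch.compile(model, mode=settings["mode"], dynamic=settings["dynamic"],
                         fullgraph=settings["fullgraph"], backend=settings["backend"])
print(compiled(torch.randn(2, 8)).shape)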
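The patched_write_atomic replacement differs from the stock torch._inductor.codecache version mainly in its final step: it copies the temp file over the destination (shutil.copy2 plus os.remove) rather than moving it, which, per the inline comment, lets an existing cache file be overwritten on Windows. A quick self-contained check of that behaviour, using the patched_write_atomic defined in the diff (paths are hypothetical):

import os
import tempfile

cache_dir = tempfile.mkdtemp()
target = os.path.join(cache_dir, "inductor", "kernel.py")

# The first write creates the directory and the file; the second overwrites
# it, which is exactly the case the Windows workaround needs to survive.
patched_write_atomic(target, "print('first')\n", make_dirs=True)
patched_write_atomic(target, "print('second')\n")

with open(target) as f:
    assert f.read() == "print('second')\n"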