save and retrieve compile settings when re-compiling

This commit is contained in:
kijai 2024-11-05 16:10:27 +02:00
parent 365e0699b1
commit cdf8ca8298

View File

@ -1,11 +1,10 @@
import torch import torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from typing import Union
import json, re, os, io, time import json, re, os, io, time, platform
import re import re
import importlib import importlib
from contextlib import contextmanager
import model_management import model_management
import folder_paths import folder_paths
@ -2300,6 +2299,8 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
print("NOT LOADED {}".format(x)) print("NOT LOADED {}".format(x))
if patch_keys: if patch_keys:
compile_settings = getattr(model.model, "compile_settings")
print("compile_settings: ", compile_settings)
for k in patch_keys: for k in patch_keys:
if "diffusion_model." in k: if "diffusion_model." in k:
# Remove the prefix to get the attribute path # Remove the prefix to get the attribute path
@ -2314,12 +2315,36 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
else: else:
block = getattr(block, attr) block = getattr(block, attr)
# Compile the block # Compile the block
compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor") compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
# Add the compiled block back as an object patch # Add the compiled block back as an object patch
model.add_object_patch(k, compiled_block) model.add_object_patch(k, compiled_block)
return (new_modelpatcher, new_clip) return (new_modelpatcher, new_clip)
def patched_write_atomic(
    path_: str,
    content: Union[str, bytes],
    make_dirs: bool = False,
    encode_utf_8: bool = False,
) -> None:
    """Write ``content`` to ``path_`` through a temporary file, overwriting any existing file.

    The content is first written to a hidden temporary file that lives next to
    the destination and is named after the current process id and thread id.
    It is then copied over the destination with ``shutil.copy2`` (so an
    existing destination file is replaced) and the temporary file is deleted.

    Note that, unlike a rename, the copy step itself is not an atomic
    operation; the function name is kept for compatibility with the function
    it is meant to stand in for.

    Args:
        path_: Destination file path.
        content: Text or bytes to store. ``str`` is written in text mode,
            ``bytes`` in binary mode.
        make_dirs: When true, create the destination's parent directory
            (and any missing ancestors) before writing.
        encode_utf_8: When true, the temporary file is opened with
            ``encoding="utf-8"``; otherwise no explicit encoding is passed.

    Raises:
        AssertionError: If ``content`` is neither ``str`` nor ``bytes``.
        OSError: If the temporary file cannot be created, copied or removed.
    """
    # Write into temporary file first to avoid conflicts between threads
    # Avoid using a named temporary file, as those have restricted permissions
    from pathlib import Path
    import os
    import shutil
    import threading

    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"

    path = Path(path_)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)

    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    try:
        with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
            f.write(content)
        # Copy instead of rename so an already existing cache file is overwritten.
        shutil.copy2(src=tmp_path, dst=path)
    finally:
        # Always clean up the temporary file, including when the write or the
        # copy failed; otherwise stale ".<pid>.<tid>.tmp" files pile up in the
        # cache directory.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            # The temporary file was never created (e.g. open() itself failed).
            pass
class PatchModelPatcherOrder: class PatchModelPatcherOrder:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -2357,6 +2382,7 @@ class TorchCompileModelFluxAdvanced:
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"double_blocks": ("STRING", {"default": "0-18", "multiline": True}), "double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
"single_blocks": ("STRING", {"default": "0-37", "multiline": True}), "single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
@ -2375,7 +2401,13 @@ class TorchCompileModelFluxAdvanced:
blocks.append(int(part)) blocks.append(int(part))
return blocks return blocks
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks): def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
if platform.system() == 'Windows':
try:
import torch._inductor.codecache
torch._inductor.codecache.write_atomic = patched_write_atomic #temporary workaround for the cache write bug in Windows
except:
pass
single_block_list = self.parse_blocks(single_blocks) single_block_list = self.parse_blocks(single_blocks)
double_block_list = self.parse_blocks(double_blocks) double_block_list = self.parse_blocks(double_blocks)
m = model.clone() m = model.clone()
@ -2386,12 +2418,19 @@ class TorchCompileModelFluxAdvanced:
for i, block in enumerate(diffusion_model.double_blocks): for i, block in enumerate(diffusion_model.double_blocks):
if i in double_block_list: if i in double_block_list:
#print("Compiling double_block", i) #print("Compiling double_block", i)
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)) m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
for i, block in enumerate(diffusion_model.single_blocks): for i, block in enumerate(diffusion_model.single_blocks):
if i in single_block_list: if i in single_block_list:
#print("Compiling single block", i) #print("Compiling single block", i)
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)) m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
self._compiled = True self._compiled = True
compile_settings = {
"backend": backend,
"mode": mode,
"fullgraph": fullgraph,
"dynamic": dynamic,
}
setattr(m.model, "compile_settings", compile_settings)
except: except:
raise RuntimeError("Failed to compile model") raise RuntimeError("Failed to compile model")