save and retrieve compile settings when re-compiling

kijai 2024-11-05 16:10:27 +02:00
parent 365e0699b1
commit cdf8ca8298

@@ -1,11 +1,10 @@
 import torch
 import numpy as np
 from PIL import Image
-import json, re, os, io, time
-import re
+from typing import Union
+import json, re, os, io, time, platform
 import importlib
 from contextlib import contextmanager
 import model_management
 import folder_paths
@@ -2300,6 +2299,8 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
             print("NOT LOADED {}".format(x))
 
     if patch_keys:
+        compile_settings = getattr(model.model, "compile_settings")
+        print("compile_settings: ", compile_settings)
         for k in patch_keys:
             if "diffusion_model." in k:
                 # Remove the prefix to get the attribute path
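
For readers skimming the diff: the loop entered above walks a dotted attribute path such as "double_blocks.0" down from the diffusion model to locate each block that needs re-compiling. Below is a minimal sketch of that walk using a stand-in module tree; the names are illustrative, not part of the commit:

import torch.nn as nn

# Stand-in for the diffusion model; only the list-of-blocks shape matters here.
root = nn.Module()
root.double_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

key = "double_blocks.1"
block = root
for attr in key.split('.'):
    # Numeric path segments index into a ModuleList; the rest are attributes.
    block = block[int(attr)] if attr.isdigit() else getattr(block, attr)

print(type(block).__name__)  # Linear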
@@ -2314,12 +2315,36 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
                     else:
                         block = getattr(block, attr)
                 # Compile the block
-                compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
+                compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
                 # Add the compiled block back as an object patch
                 model.add_object_patch(k, compiled_block)
     return (new_modelpatcher, new_clip)
 
+def patched_write_atomic(
+    path_: str,
+    content: Union[str, bytes],
+    make_dirs: bool = False,
+    encode_utf_8: bool = False,
+) -> None:
+    # Write into temporary file first to avoid conflicts between threads
+    # Avoid using a named temporary file, as those have restricted permissions
+    from pathlib import Path
+    import os
+    import shutil
+    import threading
+    assert isinstance(
+        content, (str, bytes)
+    ), "Only strings and byte arrays can be saved in the cache"
+    path = Path(path_)
+    if make_dirs:
+        path.parent.mkdir(parents=True, exist_ok=True)
+    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
+    write_mode = "w" if isinstance(content, str) else "wb"
+    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
+        f.write(content)
+    shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
+    os.remove(tmp_path)
 
 class PatchModelPatcherOrder:
     @classmethod
     def INPUT_TYPES(s):
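
patched_write_atomic above is a near-verbatim copy of torch._inductor.codecache.write_atomic; the one functional change is the final shutil.copy2 plus os.remove, replacing the rename-based last step used upstream, which can fail on Windows when the destination file already exists. A quick sanity check of the overwrite case, assuming patched_write_atomic from the diff is in scope; the path is illustrative:

from pathlib import Path

target = Path("inductor_cache_test") / "kernel.py"

# The second call overwrites an existing cache file, which is the case the
# rename-based upstream version can trip over on Windows.
patched_write_atomic(str(target), "print('first')", make_dirs=True)
patched_write_atomic(str(target), "print('second')")
assert target.read_text() == "print('second')"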
@@ -2357,6 +2382,7 @@ class TorchCompileModelFluxAdvanced:
                 "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                 "double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
                 "single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
+                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
             }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
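
The new dynamic input is passed straight through to torch.compile. A minimal sketch of what the flag does, assuming PyTorch 2.x with a working inductor backend; the function is illustrative:

import torch

def scale(x):
    return x * 2

# dynamic=True asks the compiler for shape-polymorphic kernels up front, so a
# change in input resolution between runs does not force a re-compile.
compiled = torch.compile(scale, mode="default", dynamic=True, backend="inductor")
print(compiled(torch.randn(4, 8)).shape)
print(compiled(torch.randn(4, 16)).shape)  # same compiled artifact, new shape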
@@ -2375,7 +2401,13 @@ class TorchCompileModelFluxAdvanced:
                 blocks.append(int(part))
         return blocks
 
-    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
+    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
+        if platform.system() == 'Windows':
+            try:
+                import torch._inductor.codecache
+                torch._inductor.codecache.write_atomic = patched_write_atomic #temporary workaround for the cache write bug in Windows
+            except:
+                pass
         single_block_list = self.parse_blocks(single_blocks)
         double_block_list = self.parse_blocks(double_blocks)
         m = model.clone()
@@ -2386,12 +2418,19 @@ class TorchCompileModelFluxAdvanced:
             for i, block in enumerate(diffusion_model.double_blocks):
                 if i in double_block_list:
                     #print("Compiling double_block", i)
-                    m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
+                    m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
             for i, block in enumerate(diffusion_model.single_blocks):
                 if i in single_block_list:
                     #print("Compiling single block", i)
-                    m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
+                    m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
             self._compiled = True
+            compile_settings = {
+                "backend": backend,
+                "mode": mode,
+                "fullgraph": fullgraph,
+                "dynamic": dynamic,
+            }
+            setattr(m.model, "compile_settings", compile_settings)
         except:
             raise RuntimeError("Failed to compile model")
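
Taken together, the two halves of the commit form the save/retrieve round-trip named in the commit message: TorchCompileModelFluxAdvanced.patch stores the compile knobs on the model object, and patched_load_lora_for_models reads them back when it re-compiles patched blocks after a LoRA load, instead of the hard-coded defaults it used before. A minimal sketch of that round-trip with a stand-in object; DummyModel and the values are illustrative:

class DummyModel:
    pass

m_model = DummyModel()

# First compile: remember the knobs on the model object.
setattr(m_model, "compile_settings", {
    "backend": "inductor",
    "mode": "default",
    "fullgraph": False,
    "dynamic": False,
})

# Later, on re-compile after a LoRA load: read the knobs back and reuse them.
settings = getattr(m_model, "compile_settings")
print(settings["backend"], settings["mode"], settings["dynamic"])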