mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-25 21:50:59 +08:00
save and retrieve compile settings when re-compiling
This commit is contained in:
parent
365e0699b1
commit
cdf8ca8298
@ -1,11 +1,10 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import json, re, os, io, time
|
||||
from typing import Union
|
||||
import json, re, os, io, time, platform
|
||||
import re
|
||||
import importlib
|
||||
from contextlib import contextmanager
|
||||
|
||||
import model_management
|
||||
import folder_paths
|
||||
@ -2300,6 +2299,8 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
|
||||
print("NOT LOADED {}".format(x))
|
||||
|
||||
if patch_keys:
|
||||
compile_settings = getattr(model.model, "compile_settings")
|
||||
print("compile_settings: ", compile_settings)
|
||||
for k in patch_keys:
|
||||
if "diffusion_model." in k:
|
||||
# Remove the prefix to get the attribute path
|
||||
@ -2314,12 +2315,36 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
|
||||
else:
|
||||
block = getattr(block, attr)
|
||||
# Compile the block
|
||||
compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
|
||||
compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
|
||||
# Add the compiled block back as an object patch
|
||||
model.add_object_patch(k, compiled_block)
|
||||
return (new_modelpatcher, new_clip)
|
||||
|
||||
|
||||
def patched_write_atomic(
|
||||
path_: str,
|
||||
content: Union[str, bytes],
|
||||
make_dirs: bool = False,
|
||||
encode_utf_8: bool = False,
|
||||
) -> None:
|
||||
# Write into temporary file first to avoid conflicts between threads
|
||||
# Avoid using a named temporary file, as those have restricted permissions
|
||||
from pathlib import Path
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
assert isinstance(
|
||||
content, (str, bytes)
|
||||
), "Only strings and byte arrays can be saved in the cache"
|
||||
path = Path(path_)
|
||||
if make_dirs:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
|
||||
write_mode = "w" if isinstance(content, str) else "wb"
|
||||
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
|
||||
f.write(content)
|
||||
shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
|
||||
os.remove(tmp_path)
|
||||
|
||||
class PatchModelPatcherOrder:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -2357,6 +2382,7 @@ class TorchCompileModelFluxAdvanced:
|
||||
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||
"double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
|
||||
"single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
|
||||
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
@ -2375,7 +2401,13 @@ class TorchCompileModelFluxAdvanced:
|
||||
blocks.append(int(part))
|
||||
return blocks
|
||||
|
||||
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
|
||||
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
|
||||
if platform.system() == 'Windows':
|
||||
try:
|
||||
import torch._inductor.codecache
|
||||
torch._inductor.codecache.write_atomic = patched_write_atomic #temporary workaround for the cache write bug in Windows
|
||||
except:
|
||||
pass
|
||||
single_block_list = self.parse_blocks(single_blocks)
|
||||
double_block_list = self.parse_blocks(double_blocks)
|
||||
m = model.clone()
|
||||
@ -2386,12 +2418,19 @@ class TorchCompileModelFluxAdvanced:
|
||||
for i, block in enumerate(diffusion_model.double_blocks):
|
||||
if i in double_block_list:
|
||||
#print("Compiling double_block", i)
|
||||
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
|
||||
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
|
||||
for i, block in enumerate(diffusion_model.single_blocks):
|
||||
if i in single_block_list:
|
||||
#print("Compiling single block", i)
|
||||
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
|
||||
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
|
||||
self._compiled = True
|
||||
compile_settings = {
|
||||
"backend": backend,
|
||||
"mode": mode,
|
||||
"fullgraph": fullgraph,
|
||||
"dynamic": dynamic,
|
||||
}
|
||||
setattr(m.model, "compile_settings", compile_settings)
|
||||
except:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user