mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-28 23:17:17 +08:00
save and retrieve compile settings when re-compiling
This commit is contained in:
parent
365e0699b1
commit
cdf8ca8298
@ -1,11 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
import json, re, os, io, time
|
import json, re, os, io, time, platform
|
||||||
import re
|
import re
|
||||||
import importlib
|
import importlib
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import model_management
|
import model_management
|
||||||
import folder_paths
|
import folder_paths
|
||||||
@ -2300,6 +2299,8 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
|
|||||||
print("NOT LOADED {}".format(x))
|
print("NOT LOADED {}".format(x))
|
||||||
|
|
||||||
if patch_keys:
|
if patch_keys:
|
||||||
|
compile_settings = getattr(model.model, "compile_settings")
|
||||||
|
print("compile_settings: ", compile_settings)
|
||||||
for k in patch_keys:
|
for k in patch_keys:
|
||||||
if "diffusion_model." in k:
|
if "diffusion_model." in k:
|
||||||
# Remove the prefix to get the attribute path
|
# Remove the prefix to get the attribute path
|
||||||
@ -2314,12 +2315,36 @@ def patched_load_lora_for_models(model, clip, lora, strength_model, strength_cli
|
|||||||
else:
|
else:
|
||||||
block = getattr(block, attr)
|
block = getattr(block, attr)
|
||||||
# Compile the block
|
# Compile the block
|
||||||
compiled_block = torch.compile(block, mode="default", fullgraph=False, backend="inductor")
|
compiled_block = torch.compile(block, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
|
||||||
# Add the compiled block back as an object patch
|
# Add the compiled block back as an object patch
|
||||||
model.add_object_patch(k, compiled_block)
|
model.add_object_patch(k, compiled_block)
|
||||||
return (new_modelpatcher, new_clip)
|
return (new_modelpatcher, new_clip)
|
||||||
|
|
||||||
|
def patched_write_atomic(
|
||||||
|
path_: str,
|
||||||
|
content: Union[str, bytes],
|
||||||
|
make_dirs: bool = False,
|
||||||
|
encode_utf_8: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# Write into temporary file first to avoid conflicts between threads
|
||||||
|
# Avoid using a named temporary file, as those have restricted permissions
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import threading
|
||||||
|
assert isinstance(
|
||||||
|
content, (str, bytes)
|
||||||
|
), "Only strings and byte arrays can be saved in the cache"
|
||||||
|
path = Path(path_)
|
||||||
|
if make_dirs:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
|
||||||
|
write_mode = "w" if isinstance(content, str) else "wb"
|
||||||
|
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
|
||||||
|
f.write(content)
|
||||||
|
shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
|
||||||
|
os.remove(tmp_path)
|
||||||
|
|
||||||
class PatchModelPatcherOrder:
|
class PatchModelPatcherOrder:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -2357,6 +2382,7 @@ class TorchCompileModelFluxAdvanced:
|
|||||||
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||||
"double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
|
"double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
|
||||||
"single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
|
"single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
|
||||||
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
@ -2375,7 +2401,13 @@ class TorchCompileModelFluxAdvanced:
|
|||||||
blocks.append(int(part))
|
blocks.append(int(part))
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
|
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
try:
|
||||||
|
import torch._inductor.codecache
|
||||||
|
torch._inductor.codecache.write_atomic = patched_write_atomic #temporary workaround for the cache write bug in Windows
|
||||||
|
except:
|
||||||
|
pass
|
||||||
single_block_list = self.parse_blocks(single_blocks)
|
single_block_list = self.parse_blocks(single_blocks)
|
||||||
double_block_list = self.parse_blocks(double_blocks)
|
double_block_list = self.parse_blocks(double_blocks)
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
@ -2386,12 +2418,19 @@ class TorchCompileModelFluxAdvanced:
|
|||||||
for i, block in enumerate(diffusion_model.double_blocks):
|
for i, block in enumerate(diffusion_model.double_blocks):
|
||||||
if i in double_block_list:
|
if i in double_block_list:
|
||||||
#print("Compiling double_block", i)
|
#print("Compiling double_block", i)
|
||||||
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
|
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
|
||||||
for i, block in enumerate(diffusion_model.single_blocks):
|
for i, block in enumerate(diffusion_model.single_blocks):
|
||||||
if i in single_block_list:
|
if i in single_block_list:
|
||||||
#print("Compiling single block", i)
|
#print("Compiling single block", i)
|
||||||
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
|
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
|
||||||
self._compiled = True
|
self._compiled = True
|
||||||
|
compile_settings = {
|
||||||
|
"backend": backend,
|
||||||
|
"mode": mode,
|
||||||
|
"fullgraph": fullgraph,
|
||||||
|
"dynamic": dynamic,
|
||||||
|
}
|
||||||
|
setattr(m.model, "compile_settings", compile_settings)
|
||||||
except:
|
except:
|
||||||
raise RuntimeError("Failed to compile model")
|
raise RuntimeError("Failed to compile model")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user