TorchCompileVAE and controlnet

This commit is contained in:
kijai 2024-10-28 03:57:58 +02:00
parent 530c5d7eaf
commit fe5fbb03ff
2 changed files with 72 additions and 1 deletions

View File

@ -153,6 +153,8 @@ NODE_CONFIG = {
"CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
"CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
"TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
"TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
"TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -2256,7 +2256,7 @@ class TorchCompileModelFluxAdvanced:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
def parse_blocks(self, blocks_str):
@ -2299,4 +2299,73 @@ class TorchCompileModelFluxAdvanced:
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
class TorchCompileVAE:
def __init__(self):
self._compiled_encoder = False
self._compiled_decoder = False
@classmethod
def INPUT_TYPES(s):
return {"required": {
"vae": ("VAE",),
"backend": (["inductor", "cudagraphs"],),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"compile_encoder": ("BOOLEAN", {"default": True, "tooltip": "Compile encoder"}),
"compile_decoder": ("BOOLEAN", {"default": True, "tooltip": "Compile decoder"}),
}}
RETURN_TYPES = ("VAE",)
FUNCTION = "compile"
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
if compile_encoder:
if not self._compiled_encoder:
try:
vae.first_stage_model.encoder = torch.compile(vae.first_stage_model.encoder, mode=mode, fullgraph=fullgraph, backend=backend)
self._compiled_encoder = True
except:
raise RuntimeError("Failed to compile model")
if compile_decoder:
if not self._compiled_decoder:
try:
vae.first_stage_model.decoder = torch.compile(vae.first_stage_model.decoder, mode=mode, fullgraph=fullgraph, backend=backend)
self._compiled_decoder = True
except:
raise RuntimeError("Failed to compile model")
return (vae, )
class TorchCompileControlNet:
def __init__(self):
self._compiled= False
@classmethod
def INPUT_TYPES(s):
return {"required": {
"controlnet": ("CONTROL_NET",),
"backend": (["inductor", "cudagraphs"],),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "compile"
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
def compile(self, controlnet, backend, mode, fullgraph):
print(controlnet.control_model)
if not self._compiled:
try:
# for i, block in enumerate(controlnet.control_model.double_blocks):
# print("Compiling controlnet double_block", i)
# controlnet.control_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
controlnet.control_model = torch.compile(controlnet.control_model, mode=mode, fullgraph=fullgraph, backend=backend)
self._compiled = True
except:
self._compiled = False
raise RuntimeError("Failed to compile model")
return (controlnet, )