diff --git a/nodes/nodes.py b/nodes/nodes.py index 55529d5..8a2078f 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2240,6 +2240,9 @@ class CheckpointLoaderKJ: return model, clip, vae import re class TorchCompileModelFluxAdvanced: + def __init__(self): + self._compiled = False + @classmethod def INPUT_TYPES(s): return {"required": { @@ -2267,33 +2270,33 @@ class TorchCompileModelFluxAdvanced: blocks.append(int(part)) return blocks - def compile_diffusion_model(self, diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list): - #print("Diffusion model object before compilation:", diffusion_model) - for i, block in enumerate(diffusion_model.double_blocks): - if i in double_block_list: - print("Compiling double block", i) - diffusion_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend) - for i, block in enumerate(diffusion_model.single_blocks): - if i in single_block_list: - print("Compiling single block", i) - diffusion_model.single_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend) - + def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks): + single_block_list = self.parse_blocks(single_blocks) + double_block_list = self.parse_blocks(double_blocks) + m = model.clone() + diffusion_model = m.get_model_object("diffusion_model") + + if not self._compiled: + try: + for i, block in enumerate(diffusion_model.double_blocks): + if i in double_block_list: + print("Compiling double_block", i) + m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)) + for i, block in enumerate(diffusion_model.single_blocks): + if i in single_block_list: + print("Compiling single block", i) + m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)) + self._compiled = True + except: + raise RuntimeError("Failed to compile model") + + return (m, ) + # rest of the layers that are not patched # diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend) # diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend) - - #compiled_model = torch.compile(model=diffusion_model, backend=backend) - #print("Compiled diffusion model object:", compiled_model) - return diffusion_model - - def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks): - single_block_list = self.parse_blocks(single_blocks) - double_block_list = self.parse_blocks(double_blocks) - m = model.clone() - diffusion_model = m.get_model_object("diffusion_model") - #self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list) - m.add_object_patch("diffusion_model", self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list)) - return (m, ) + +