fix up TorchCompileModelFluxAdvanced

This commit is contained in:
kijai 2024-10-28 02:32:56 +02:00
parent f59e410568
commit 530c5d7eaf

View File

@ -2240,6 +2240,9 @@ class CheckpointLoaderKJ:
return model, clip, vae
import re
class TorchCompileModelFluxAdvanced:
def __init__(self):
self._compiled = False
@classmethod
def INPUT_TYPES(s):
return {"required": {
@ -2267,33 +2270,33 @@ class TorchCompileModelFluxAdvanced:
blocks.append(int(part))
return blocks
def compile_diffusion_model(self, diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list):
#print("Diffusion model object before compilation:", diffusion_model)
for i, block in enumerate(diffusion_model.double_blocks):
if i in double_block_list:
print("Compiling double block", i)
diffusion_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
for i, block in enumerate(diffusion_model.single_blocks):
if i in single_block_list:
print("Compiling single block", i)
diffusion_model.single_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
single_block_list = self.parse_blocks(single_blocks)
double_block_list = self.parse_blocks(double_blocks)
m = model.clone()
diffusion_model = m.get_model_object("diffusion_model")
if not self._compiled:
try:
for i, block in enumerate(diffusion_model.double_blocks):
if i in double_block_list:
print("Compiling double_block", i)
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
for i, block in enumerate(diffusion_model.single_blocks):
if i in single_block_list:
print("Compiling single block", i)
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
self._compiled = True
except:
raise RuntimeError("Failed to compile model")
return (m, )
# rest of the layers that are not patched
# diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
#compiled_model = torch.compile(model=diffusion_model, backend=backend)
#print("Compiled diffusion model object:", compiled_model)
return diffusion_model
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
single_block_list = self.parse_blocks(single_blocks)
double_block_list = self.parse_blocks(double_blocks)
m = model.clone()
diffusion_model = m.get_model_object("diffusion_model")
#self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list)
m.add_object_patch("diffusion_model", self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list))
return (m, )