mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-04 23:44:06 +08:00
fix up TorchCompileModelFluxAdvanced
This commit is contained in:
parent
f59e410568
commit
530c5d7eaf
@ -2240,6 +2240,9 @@ class CheckpointLoaderKJ:
|
||||
return model, clip, vae
|
||||
import re
|
||||
class TorchCompileModelFluxAdvanced:
|
||||
def __init__(self):
|
||||
self._compiled = False
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
@ -2267,33 +2270,33 @@ class TorchCompileModelFluxAdvanced:
|
||||
blocks.append(int(part))
|
||||
return blocks
|
||||
|
||||
def compile_diffusion_model(self, diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list):
|
||||
#print("Diffusion model object before compilation:", diffusion_model)
|
||||
for i, block in enumerate(diffusion_model.double_blocks):
|
||||
if i in double_block_list:
|
||||
print("Compiling double block", i)
|
||||
diffusion_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
for i, block in enumerate(diffusion_model.single_blocks):
|
||||
if i in single_block_list:
|
||||
print("Compiling single block", i)
|
||||
diffusion_model.single_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
|
||||
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
|
||||
single_block_list = self.parse_blocks(single_blocks)
|
||||
double_block_list = self.parse_blocks(double_blocks)
|
||||
m = model.clone()
|
||||
diffusion_model = m.get_model_object("diffusion_model")
|
||||
|
||||
if not self._compiled:
|
||||
try:
|
||||
for i, block in enumerate(diffusion_model.double_blocks):
|
||||
if i in double_block_list:
|
||||
print("Compiling double_block", i)
|
||||
m.add_object_patch(f"diffusion_model.double_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
|
||||
for i, block in enumerate(diffusion_model.single_blocks):
|
||||
if i in single_block_list:
|
||||
print("Compiling single block", i)
|
||||
m.add_object_patch(f"diffusion_model.single_blocks.{i}", torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend))
|
||||
self._compiled = True
|
||||
except:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
|
||||
return (m, )
|
||||
# rest of the layers that are not patched
|
||||
# diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
|
||||
#compiled_model = torch.compile(model=diffusion_model, backend=backend)
|
||||
#print("Compiled diffusion model object:", compiled_model)
|
||||
return diffusion_model
|
||||
|
||||
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
|
||||
single_block_list = self.parse_blocks(single_blocks)
|
||||
double_block_list = self.parse_blocks(double_blocks)
|
||||
m = model.clone()
|
||||
diffusion_model = m.get_model_object("diffusion_model")
|
||||
#self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list)
|
||||
m.add_object_patch("diffusion_model", self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list))
|
||||
return (m, )
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user