mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-16 08:14:32 +08:00
Better compile compatibility with various patches
Shouldn't drop compile when changing slg or enhance-a-video settings anymore
This commit is contained in:
parent
bb154eb71f
commit
393ec896f7
@ -486,26 +486,32 @@ class TorchCompileModelWanVideo:
|
||||
m = model.clone()
|
||||
diffusion_model = m.get_model_object("diffusion_model")
|
||||
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
|
||||
if not self._compiled:
|
||||
try:
|
||||
if compile_transformer_blocks_only:
|
||||
for i, block in enumerate(diffusion_model.blocks):
|
||||
is_compiled = hasattr(model.model.diffusion_model.blocks[0], "_orig_mod")
|
||||
if is_compiled:
|
||||
logging.info(f"Already compiled, not reapplying")
|
||||
else:
|
||||
logging.info(f"Not compiled, applying")
|
||||
try:
|
||||
if compile_transformer_blocks_only:
|
||||
for i, block in enumerate(diffusion_model.blocks):
|
||||
if is_compiled:
|
||||
compiled_block = torch.compile(block._orig_mod, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
||||
else:
|
||||
compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
||||
m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
|
||||
else:
|
||||
compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
||||
m.add_object_patch("diffusion_model", compiled_model)
|
||||
m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
|
||||
else:
|
||||
compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
|
||||
m.add_object_patch("diffusion_model", compiled_model)
|
||||
|
||||
self._compiled = True
|
||||
compile_settings = {
|
||||
"backend": backend,
|
||||
"mode": mode,
|
||||
"fullgraph": fullgraph,
|
||||
"dynamic": dynamic,
|
||||
}
|
||||
setattr(m.model, "compile_settings", compile_settings)
|
||||
except:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
compile_settings = {
|
||||
"backend": backend,
|
||||
"mode": mode,
|
||||
"fullgraph": fullgraph,
|
||||
"dynamic": dynamic,
|
||||
}
|
||||
setattr(m.model, "compile_settings", compile_settings)
|
||||
except:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
return (m, )
|
||||
|
||||
class TorchCompileVAE:
|
||||
@ -1074,9 +1080,14 @@ class WanVideoEnhanceAVideoKJ:
|
||||
model_clone.model_options['transformer_options'] = {}
|
||||
model_clone.model_options["transformer_options"]["enhance_weight"] = weight
|
||||
diffusion_model = model_clone.get_model_object("diffusion_model")
|
||||
|
||||
compile_settings = getattr(model.model, "compile_settings", None)
|
||||
for idx, block in enumerate(diffusion_model.blocks):
|
||||
self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
|
||||
model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
|
||||
patched_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
|
||||
if compile_settings is not None:
|
||||
patched_attn = torch.compile(patched_attn, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
|
||||
|
||||
model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
|
||||
|
||||
return (model_clone,)
|
||||
|
||||
@ -1098,8 +1109,10 @@ class SkipLayerGuidanceWanVideo:
|
||||
def slg(self, model, start_percent, end_percent, blocks):
|
||||
def skip(args, extra_args):
|
||||
transformer_options = extra_args.get("transformer_options", {})
|
||||
original_block = extra_args["original_block"]
|
||||
|
||||
if not transformer_options:
|
||||
raise ValueError("transformer_options not found in extra_args, currrently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
|
||||
raise ValueError("transformer_options not found in extra_args, currently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
|
||||
if start_percent <= transformer_options["current_percent"] <= end_percent:
|
||||
if args["img"].shape[0] == 2:
|
||||
prev_img_uncond = args["img"][0].unsqueeze(0)
|
||||
@ -1110,7 +1123,8 @@ class SkipLayerGuidanceWanVideo:
|
||||
"vec": args["vec"][1],
|
||||
"pe": args["pe"][1]
|
||||
}
|
||||
block_out = extra_args["original_block"](new_args)
|
||||
|
||||
block_out = original_block(new_args)
|
||||
|
||||
out = {
|
||||
"img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
|
||||
@ -1120,20 +1134,36 @@ class SkipLayerGuidanceWanVideo:
|
||||
}
|
||||
else:
|
||||
if transformer_options.get("cond_or_uncond") == [0]:
|
||||
out = extra_args["original_block"](args)
|
||||
out = original_block(args)
|
||||
else:
|
||||
out = args
|
||||
else:
|
||||
out = extra_args["original_block"](args)
|
||||
out = original_block(args)
|
||||
return out
|
||||
|
||||
block_list = [int(x.strip()) for x in blocks.split(",")]
|
||||
double_layers = [int(i) for i in block_list]
|
||||
logging.info(f"Selected blocks to skip uncond on: {double_layers}")
|
||||
blocks = [int(i) for i in block_list]
|
||||
logging.info(f"Selected blocks to skip uncond on: {blocks}")
|
||||
|
||||
m = model.clone()
|
||||
|
||||
for layer in double_layers:
|
||||
m.set_model_patch_replace(skip, "dit", "double_block", layer)
|
||||
for b in blocks:
|
||||
#m.set_model_patch_replace(skip, "dit", "double_block", b)
|
||||
model_options = m.model_options["transformer_options"].copy()
|
||||
if "patches_replace" not in model_options:
|
||||
model_options["patches_replace"] = {}
|
||||
else:
|
||||
model_options["patches_replace"] = model_options["patches_replace"].copy()
|
||||
|
||||
if "dit" not in model_options["patches_replace"]:
|
||||
model_options["patches_replace"]["dit"] = {}
|
||||
else:
|
||||
model_options["patches_replace"]["dit"] = model_options["patches_replace"]["dit"].copy()
|
||||
|
||||
block = ("double_block", b)
|
||||
|
||||
model_options["patches_replace"]["dit"][block] = skip
|
||||
m.model_options["transformer_options"] = model_options
|
||||
|
||||
|
||||
return (m, )
|
||||
Loading…
x
Reference in New Issue
Block a user