Better compile compatibility with various patches

Shouldn't drop torch.compile when changing SLG or Enhance-A-Video settings anymore
kijai 2025-03-17 14:11:30 +02:00
parent bb154eb71f
commit 393ec896f7


@@ -486,26 +486,32 @@ class TorchCompileModelWanVideo:
         m = model.clone()
         diffusion_model = m.get_model_object("diffusion_model")
         torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
-        if not self._compiled:
-            try:
-                if compile_transformer_blocks_only:
-                    for i, block in enumerate(diffusion_model.blocks):
+        is_compiled = hasattr(model.model.diffusion_model.blocks[0], "_orig_mod")
+        if is_compiled:
+            logging.info(f"Already compiled, not reapplying")
+        else:
+            logging.info(f"Not compiled, applying")
+        try:
+            if compile_transformer_blocks_only:
+                for i, block in enumerate(diffusion_model.blocks):
+                    if is_compiled:
+                        compiled_block = torch.compile(block._orig_mod, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+                    else:
                         compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
-                        m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
-                else:
-                    compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
-                    m.add_object_patch("diffusion_model", compiled_model)
+                    m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
+            else:
+                compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+                m.add_object_patch("diffusion_model", compiled_model)
-                self._compiled = True
-                compile_settings = {
-                    "backend": backend,
-                    "mode": mode,
-                    "fullgraph": fullgraph,
-                    "dynamic": dynamic,
-                }
-                setattr(m.model, "compile_settings", compile_settings)
-            except:
-                raise RuntimeError("Failed to compile model")
+            compile_settings = {
+                "backend": backend,
+                "mode": mode,
+                "fullgraph": fullgraph,
+                "dynamic": dynamic,
+            }
+            setattr(m.model, "compile_settings", compile_settings)
+        except:
+            raise RuntimeError("Failed to compile model")
         return (m, )
 class TorchCompileVAE:
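For context on the re-compile guard above: torch.compile wraps an nn.Module in an OptimizedModule that keeps the original module on `_orig_mod`, which is the attribute the hunk checks. A minimal standalone sketch of that pattern, assuming a recent PyTorch 2.x and no ComfyUI objects:

import torch
import torch.nn as nn

block = nn.Linear(16, 16)
compiled = torch.compile(block)

# torch.compile returns an OptimizedModule wrapper; the original module is kept
# on `_orig_mod`, so its presence signals "this block was already compiled".
assert hasattr(compiled, "_orig_mod") and compiled._orig_mod is block

# Re-compiling from the original module avoids stacking wrapper on wrapper,
# which is what the block loop above does when is_compiled is True.
recompiled = torch.compile(compiled._orig_mod, mode="default")
print(type(compiled).__name__, type(recompiled).__name__)  # OptimizedModule OptimizedModule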
@@ -1074,9 +1080,14 @@ class WanVideoEnhanceAVideoKJ:
             model_clone.model_options['transformer_options'] = {}
         model_clone.model_options["transformer_options"]["enhance_weight"] = weight
         diffusion_model = model_clone.get_model_object("diffusion_model")
+        compile_settings = getattr(model.model, "compile_settings", None)
         for idx, block in enumerate(diffusion_model.blocks):
-            self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
-            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
+            patched_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
+            if compile_settings is not None:
+                patched_attn = torch.compile(patched_attn, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"])
+            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
         return (model_clone,)
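The `WanAttentionPatch(...).__get__(block.self_attn, block.__class__)` call uses the descriptor protocol to turn a callable into a method bound to the existing attention module, so it can be installed as that module's forward. A rough sketch of the same binding trick with a hypothetical patch function (the real WanAttentionPatch carries num_frames/weight state; this only shows the mechanism plus the compile_settings reuse added in this commit):

import torch
import torch.nn as nn

class TinyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
    def forward(self, x):
        return self.proj(x)

def make_patched_forward(weight):
    # Hypothetical stand-in for WanAttentionPatch: a plain function that will
    # replace a module's forward once bound to that module.
    def forward(self, x):
        return self.proj(x) * weight  # 'self' is the module it gets bound to
    return forward

attn = TinyAttn()
# Plain functions are descriptors, so __get__(instance, cls) yields a bound
# method, mirroring WanAttentionPatch(...).__get__(block.self_attn, block.__class__).
patched_forward = make_patched_forward(2.0).__get__(attn, TinyAttn)

# If the compile node stored its settings on the model (first hunk above), the
# patched forward can be compiled with the same options instead of running eager.
compile_settings = {"backend": "inductor", "mode": "default", "fullgraph": False, "dynamic": False}
patched_forward = torch.compile(patched_forward, **compile_settings)

print(patched_forward(torch.randn(2, 8)).shape)  # torch.Size([2, 8])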
@@ -1098,8 +1109,10 @@ class SkipLayerGuidanceWanVideo:
     def slg(self, model, start_percent, end_percent, blocks):
         def skip(args, extra_args):
             transformer_options = extra_args.get("transformer_options", {})
+            original_block = extra_args["original_block"]
             if not transformer_options:
-                raise ValueError("transformer_options not found in extra_args, currrently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
+                raise ValueError("transformer_options not found in extra_args, currently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
             if start_percent <= transformer_options["current_percent"] <= end_percent:
                 if args["img"].shape[0] == 2:
                     prev_img_uncond = args["img"][0].unsqueeze(0)
@@ -1110,7 +1123,8 @@ class SkipLayerGuidanceWanVideo:
                         "vec": args["vec"][1],
                         "pe": args["pe"][1]
                     }
-                    block_out = extra_args["original_block"](new_args)
+                    block_out = original_block(new_args)
                     out = {
                         "img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
@@ -1120,20 +1134,36 @@ class SkipLayerGuidanceWanVideo:
                     }
                 else:
                     if transformer_options.get("cond_or_uncond") == [0]:
-                        out = extra_args["original_block"](args)
+                        out = original_block(args)
                     else:
                         out = args
             else:
-                out = extra_args["original_block"](args)
+                out = original_block(args)
             return out
         block_list = [int(x.strip()) for x in blocks.split(",")]
-        double_layers = [int(i) for i in block_list]
-        logging.info(f"Selected blocks to skip uncond on: {double_layers}")
+        blocks = [int(i) for i in block_list]
+        logging.info(f"Selected blocks to skip uncond on: {blocks}")
         m = model.clone()
-        for layer in double_layers:
-            m.set_model_patch_replace(skip, "dit", "double_block", layer)
+        for b in blocks:
+            #m.set_model_patch_replace(skip, "dit", "double_block", b)
+            model_options = m.model_options["transformer_options"].copy()
+            if "patches_replace" not in model_options:
+                model_options["patches_replace"] = {}
+            else:
+                model_options["patches_replace"] = model_options["patches_replace"].copy()
+            if "dit" not in model_options["patches_replace"]:
+                model_options["patches_replace"]["dit"] = {}
+            else:
+                model_options["patches_replace"]["dit"] = model_options["patches_replace"]["dit"].copy()
+            block = ("double_block", b)
+            model_options["patches_replace"]["dit"][block] = skip
+            m.model_options["transformer_options"] = model_options
         return (m, )
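The manual bookkeeping that replaces `set_model_patch_replace` above copies each nested dict before inserting the `("double_block", b)` entry, so the patch lands only in this clone's transformer_options rather than in dictionaries shared with the original model. A simplified sketch of that copy-on-write update, with an assumed dict layout and independent of ComfyUI:

def add_patch_replace(transformer_options: dict, patch, block_idx: int) -> dict:
    # Copy each level before writing, so dicts possibly shared with the original
    # model are never mutated (the same copy-on-write dance as in the node above).
    options = transformer_options.copy()
    patches = options.get("patches_replace", {}).copy()
    dit = patches.get("dit", {}).copy()
    dit[("double_block", block_idx)] = patch
    patches["dit"] = dit
    options["patches_replace"] = patches
    return options

def skip(args, extra_args):
    # Placeholder patch with the same call signature as `skip` above.
    return extra_args["original_block"](args)

shared = {"patches_replace": {"dit": {("double_block", 0): "existing"}}}
updated = add_patch_replace(shared, skip, 9)

assert ("double_block", 9) not in shared["patches_replace"]["dit"]   # original untouched
assert ("double_block", 9) in updated["patches_replace"]["dit"]      # clone gets the patch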