mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Add SkipLayerGuidanceWanVideo
This commit is contained in:
parent
6b7eeebe44
commit
7c488a16ef
@ -186,6 +186,7 @@ NODE_CONFIG = {
|
||||
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
||||
"WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
|
||||
"WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
|
||||
"SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
|
||||
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
|
||||
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
|
||||
|
||||
|
||||
@ -791,7 +791,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
@ -853,7 +853,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
|
||||
|
||||
if coefficients == "disabled" and rel_l1_thresh > 0.1:
|
||||
logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.")
|
||||
if coefficients != "disabled" and rel_l1_thresh < 0.1:
|
||||
if coefficients != "disabled" and rel_l1_thresh < 0.1 and "1.3B" not in coefficients:
|
||||
logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.")
|
||||
|
||||
# type_str = str(type(model.model.model_config).__name__)
|
||||
@ -914,15 +914,15 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
|
||||
logging.info("\nResetting TeaCache state")
|
||||
|
||||
current_percent = current_step_index / (len(sigmas) - 1)
|
||||
c["transformer_options"]["current_percent"] = current_percent
|
||||
if start_percent <= current_percent <= end_percent:
|
||||
c["transformer_options"]["teacache_enabled"] = True
|
||||
|
||||
context = patch.multiple(
|
||||
diffusion_model,
|
||||
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
|
||||
)
|
||||
else:
|
||||
context = nullcontext()
|
||||
|
||||
context = patch.multiple(
|
||||
diffusion_model,
|
||||
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
|
||||
)
|
||||
|
||||
with context:
|
||||
out = model_function(input, timestep, **c)
|
||||
if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
|
||||
@ -1072,4 +1072,62 @@ class WanVideoEnhanceAVideoKJ:
|
||||
self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
|
||||
model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
|
||||
|
||||
return (model_clone,)
|
||||
return (model_clone,)
|
||||
|
||||
class SkipLayerGuidanceWanVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
"blocks": ("STRING", {"default": "10", "multiline": False}),
|
||||
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "slg"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def slg(self, model, start_percent, end_percent, blocks):
|
||||
def skip(args, extra_args):
|
||||
transformer_options = extra_args.get("transformer_options", {})
|
||||
if not transformer_options:
|
||||
raise ValueError("transformer_options not found in extra_args, currrently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
|
||||
if start_percent <= transformer_options["current_percent"] <= end_percent:
|
||||
if args["img"].shape[0] == 2:
|
||||
prev_img_uncond = args["img"][0].unsqueeze(0)
|
||||
|
||||
new_args = {
|
||||
"img": args["img"][1],
|
||||
"txt": args["txt"][1],
|
||||
"vec": args["vec"][1],
|
||||
"pe": args["pe"][1]
|
||||
}
|
||||
block_out = extra_args["original_block"](new_args)
|
||||
|
||||
out = {
|
||||
"img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
|
||||
"txt": args["txt"],
|
||||
"vec": args["vec"],
|
||||
"pe": args["pe"]
|
||||
}
|
||||
else:
|
||||
if transformer_options.get("cond_or_uncond") == [0]:
|
||||
out = extra_args["original_block"](args)
|
||||
else:
|
||||
out = args
|
||||
else:
|
||||
out = extra_args["original_block"](args)
|
||||
return out
|
||||
|
||||
block_list = [int(x.strip()) for x in blocks.split(",")]
|
||||
double_layers = [int(i) for i in block_list]
|
||||
logging.info(f"Selected blocks to skip uncond on: {double_layers}")
|
||||
|
||||
m = model.clone()
|
||||
|
||||
for layer in double_layers:
|
||||
m.set_model_patch_replace(skip, "dit", "double_block", layer)
|
||||
|
||||
return (m, )
|
||||
Loading…
x
Reference in New Issue
Block a user