mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-16 08:14:32 +08:00
Add SkipLayerGuidanceWanVideo
This commit is contained in:
parent
6b7eeebe44
commit
7c488a16ef
@ -186,6 +186,7 @@ NODE_CONFIG = {
|
|||||||
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
||||||
"WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
|
"WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
|
||||||
"WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
|
"WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
|
||||||
|
"SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
|
||||||
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
|
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
|
||||||
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
|
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
|
||||||
|
|
||||||
|
|||||||
@ -791,7 +791,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
|||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(x, e=e0, freqs=freqs, context=context)
|
x = block(x, e=e0, freqs=freqs, context=context)
|
||||||
@ -853,7 +853,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
|
|||||||
|
|
||||||
if coefficients == "disabled" and rel_l1_thresh > 0.1:
|
if coefficients == "disabled" and rel_l1_thresh > 0.1:
|
||||||
logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.")
|
logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.")
|
||||||
if coefficients != "disabled" and rel_l1_thresh < 0.1:
|
if coefficients != "disabled" and rel_l1_thresh < 0.1 and "1.3B" not in coefficients:
|
||||||
logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.")
|
logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.")
|
||||||
|
|
||||||
# type_str = str(type(model.model.model_config).__name__)
|
# type_str = str(type(model.model.model_config).__name__)
|
||||||
@ -914,6 +914,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
|
|||||||
logging.info("\nResetting TeaCache state")
|
logging.info("\nResetting TeaCache state")
|
||||||
|
|
||||||
current_percent = current_step_index / (len(sigmas) - 1)
|
current_percent = current_step_index / (len(sigmas) - 1)
|
||||||
|
c["transformer_options"]["current_percent"] = current_percent
|
||||||
if start_percent <= current_percent <= end_percent:
|
if start_percent <= current_percent <= end_percent:
|
||||||
c["transformer_options"]["teacache_enabled"] = True
|
c["transformer_options"]["teacache_enabled"] = True
|
||||||
|
|
||||||
@ -921,8 +922,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
|
|||||||
diffusion_model,
|
diffusion_model,
|
||||||
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
|
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
context = nullcontext()
|
|
||||||
with context:
|
with context:
|
||||||
out = model_function(input, timestep, **c)
|
out = model_function(input, timestep, **c)
|
||||||
if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
|
if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
|
||||||
@ -1073,3 +1073,61 @@ class WanVideoEnhanceAVideoKJ:
|
|||||||
model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
|
model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
|
||||||
|
|
||||||
return (model_clone,)
|
return (model_clone,)
|
||||||
|
|
||||||
|
class SkipLayerGuidanceWanVideo:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"model": ("MODEL", ),
|
||||||
|
"blocks": ("STRING", {"default": "10", "multiline": False}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "slg"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/guidance"
|
||||||
|
|
||||||
|
def slg(self, model, start_percent, end_percent, blocks):
|
||||||
|
def skip(args, extra_args):
|
||||||
|
transformer_options = extra_args.get("transformer_options", {})
|
||||||
|
if not transformer_options:
|
||||||
|
raise ValueError("transformer_options not found in extra_args, currrently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
|
||||||
|
if start_percent <= transformer_options["current_percent"] <= end_percent:
|
||||||
|
if args["img"].shape[0] == 2:
|
||||||
|
prev_img_uncond = args["img"][0].unsqueeze(0)
|
||||||
|
|
||||||
|
new_args = {
|
||||||
|
"img": args["img"][1],
|
||||||
|
"txt": args["txt"][1],
|
||||||
|
"vec": args["vec"][1],
|
||||||
|
"pe": args["pe"][1]
|
||||||
|
}
|
||||||
|
block_out = extra_args["original_block"](new_args)
|
||||||
|
|
||||||
|
out = {
|
||||||
|
"img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
|
||||||
|
"txt": args["txt"],
|
||||||
|
"vec": args["vec"],
|
||||||
|
"pe": args["pe"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if transformer_options.get("cond_or_uncond") == [0]:
|
||||||
|
out = extra_args["original_block"](args)
|
||||||
|
else:
|
||||||
|
out = args
|
||||||
|
else:
|
||||||
|
out = extra_args["original_block"](args)
|
||||||
|
return out
|
||||||
|
|
||||||
|
block_list = [int(x.strip()) for x in blocks.split(",")]
|
||||||
|
double_layers = [int(i) for i in block_list]
|
||||||
|
logging.info(f"Selected blocks to skip uncond on: {double_layers}")
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
|
||||||
|
for layer in double_layers:
|
||||||
|
m.set_model_patch_replace(skip, "dit", "double_block", layer)
|
||||||
|
|
||||||
|
return (m, )
|
||||||
Loading…
x
Reference in New Issue
Block a user