Add CFGZeroStar

https://github.com/WeichenFan/CFG-Zero-star/
This commit is contained in:
kijai 2025-03-26 01:05:20 +02:00
parent a5bd3c86c8
commit b1ec996ba3
2 changed files with 64 additions and 13 deletions

View File

@ -189,6 +189,7 @@ NODE_CONFIG = {
"SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
"CFGZeroStar": {"class": CFGZeroStar, "name": "CFG Zero Star"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -485,19 +485,13 @@ class TorchCompileModelWanVideo:
def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only):
m = model.clone()
diffusion_model = m.get_model_object("diffusion_model")
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
is_compiled = hasattr(model.model.diffusion_model.blocks[0], "_orig_mod")
if is_compiled:
logging.info(f"Already compiled, not reapplying")
else:
logging.info(f"Not compiled, applying")
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
try:
if compile_transformer_blocks_only:
for i, block in enumerate(diffusion_model.blocks):
if is_compiled:
compiled_block = torch.compile(block._orig_mod, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
else:
compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
if hasattr(block, "_orig_mod"):
block = block._orig_mod
compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
else:
compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
@ -924,9 +918,10 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
current_step_index = 0
if current_step_index == 0:
if hasattr(diffusion_model, "teacache_state"):
delattr(diffusion_model, "teacache_state")
logging.info("\nResetting TeaCache state")
if (len(cond_or_uncond) == 1 and cond_or_uncond[0] == 1) or len(cond_or_uncond) == 2:
if hasattr(diffusion_model, "teacache_state"):
delattr(diffusion_model, "teacache_state")
logging.info("\nResetting TeaCache state")
current_percent = current_step_index / (len(sigmas) - 1)
c["transformer_options"]["current_percent"] = current_percent
@ -1176,4 +1171,59 @@ class SkipLayerGuidanceWanVideo:
m.model_options["transformer_options"] = model_options
return (m, )
class CFGZeroStar:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MODEL",),
"use_zero_init": ("BOOLEAN", {"default": True}),
"zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
DESCRIPTION = "https://github.com/WeichenFan/CFG-Zero-star"
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
def patch(self, model, use_zero_init, zero_star_steps):
def cfg_zerostar(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
timestep = args["timestep"]
sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
batch_size = cond.shape[0]
noise_pred_text = cond
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = uncond.view(batch_size, -1)
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
alpha = dot_product / squared_norm
alpha = alpha.view(batch_size, 1, 1, 1, 1)
matched_step_index = (sigmas == timestep[0] ).nonzero()
if len(matched_step_index) > 0:
current_step_index = matched_step_index.item()
else:
for i in range(len(sigmas) - 1):
# walk from beginning of steps until crossing the timestep
if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0:
current_step_index = i
break
else:
current_step_index = 0
if (current_step_index <= zero_star_steps) and use_zero_init:
noise_pred = noise_pred_text * 0
else:
noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
return noise_pred
m = model.clone()
m.set_model_sampler_cfg_function(cfg_zerostar)
return (m, )