diff --git a/__init__.py b/__init__.py
index 4bf3951..6d891b6 100644
--- a/__init__.py
+++ b/__init__.py
@@ -189,7 +189,7 @@ NODE_CONFIG = {
     "SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
     "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
     "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
-    "CFGZeroStar": {"class": CFGZeroStar, "name": "CFG Zero Star"},
+    "CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},

    #instance diffusion
    "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index cc193bc..44f8c84 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -1173,7 +1173,7 @@ class SkipLayerGuidanceWanVideo:

        return (m, )

-class CFGZeroStar:
+class CFGZeroStarAndInit:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
@@ -1195,17 +1195,6 @@ class CFGZeroStar:
             timestep = args["timestep"]
             sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]

-            batch_size = cond.shape[0]
-
-            noise_pred_text = cond
-            positive_flat = noise_pred_text.view(batch_size, -1)
-            negative_flat = uncond.view(batch_size, -1)
-
-            dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
-            squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
-            alpha = dot_product / squared_norm
-            alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
-
             matched_step_index = (sigmas == timestep[0] ).nonzero()
             if len(matched_step_index) > 0:
                 current_step_index = matched_step_index.item()
@@ -1219,9 +1208,20 @@ class CFGZeroStar:
                 current_step_index = 0

             if (current_step_index <= zero_star_steps) and use_zero_init:
-                noise_pred = noise_pred_text * 0
-            else:
-                noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
+                return cond * 0
+
+            batch_size = cond.shape[0]
+
+            noise_pred_text = cond
+            positive_flat = noise_pred_text.view(batch_size, -1)
+            negative_flat = uncond.view(batch_size, -1)
+
+            dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+            squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+            alpha = dot_product / squared_norm
+            alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
+
+            noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
             return noise_pred

         m = model.clone()
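
For context: besides the rename, the functional change is that the zero-init early-out (`return cond * 0`) now happens before the alpha projection is computed, so on zeroed steps the dot-product/norm tensor work is skipped entirely instead of being computed and discarded. A minimal standalone sketch of the CFG-Zero* combination step this diff patches in; the function name `cfg_zero_star_combine` and the random test tensors are illustrative, not from the repo:

```python
import torch

def cfg_zero_star_combine(cond, uncond, cond_scale, zero_init=False):
    # Zero-init: on early steps the output is simply zeroed, so the
    # projection below never runs (the reordering in this diff).
    if zero_init:
        return cond * 0

    batch_size = cond.shape[0]
    positive_flat = cond.view(batch_size, -1)
    negative_flat = uncond.view(batch_size, -1)

    # alpha is the projection coefficient of cond onto uncond:
    # <cond, uncond> / (||uncond||^2 + eps), one scalar per batch item.
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    alpha = (dot_product / squared_norm).view(batch_size, *([1] * (cond.dim() - 1)))

    # Standard CFG, but with the uncond branch rescaled by alpha (CFG-Zero*).
    return uncond * alpha + cond_scale * (cond - uncond * alpha)

# Example with random stand-ins for real model predictions:
cond = torch.randn(2, 4, 8, 8)
uncond = torch.randn(2, 4, 8, 8)
out = cfg_zero_star_combine(cond, uncond, cond_scale=7.5)
print(out.shape)  # torch.Size([2, 4, 8, 8])
```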