diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index b777ba2..9072bcd 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -1179,7 +1179,7 @@ class CFGZeroStarAndInit: return {"required": { "model": ("MODEL",), "use_zero_init": ("BOOLEAN", {"default": True}), - "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}), + "zero_init_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" @@ -1187,12 +1187,13 @@ class CFGZeroStarAndInit: CATEGORY = "KJNodes/experimental" EXPERIMENTAL = True - def patch(self, model, use_zero_init, zero_star_steps): + def patch(self, model, use_zero_init, zero_init_steps): def cfg_zerostar(args): #zero init + cond = args["cond"] timestep = args["timestep"] sigmas = args["model_options"]["transformer_options"]["sample_sigmas"] - matched_step_index = (sigmas == timestep[0] ).nonzero() + matched_step_index = (sigmas == timestep[0]).nonzero() if len(matched_step_index) > 0: current_step_index = matched_step_index.item() else: @@ -1203,17 +1204,15 @@ class CFGZeroStarAndInit: else: current_step_index = 0 - if (current_step_index <= zero_star_steps) and use_zero_init: + if (current_step_index <= zero_init_steps) and use_zero_init: return cond * 0 - cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] batch_size = cond.shape[0] - noise_pred_text = cond - positive_flat = noise_pred_text.view(batch_size, -1) + positive_flat = cond.view(batch_size, -1) negative_flat = uncond.view(batch_size, -1) dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) @@ -1221,7 +1220,7 @@ class CFGZeroStarAndInit: alpha = dot_product / squared_norm alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1))) - noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha) + noise_pred = uncond * alpha + cond_scale * (cond - uncond * alpha) return noise_pred m = model.clone()