fix zero_init

This commit is contained in:
kijai 2025-03-29 11:28:08 +02:00
parent fc59fff1b5
commit 916461c432

View File

@ -1179,7 +1179,7 @@ class CFGZeroStarAndInit:
return {"required": {
"model": ("MODEL",),
"use_zero_init": ("BOOLEAN", {"default": True}),
"zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
"zero_init_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@ -1187,12 +1187,13 @@ class CFGZeroStarAndInit:
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
def patch(self, model, use_zero_init, zero_star_steps):
def patch(self, model, use_zero_init, zero_init_steps):
def cfg_zerostar(args):
#zero init
cond = args["cond"]
timestep = args["timestep"]
sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
matched_step_index = (sigmas == timestep[0] ).nonzero()
matched_step_index = (sigmas == timestep[0]).nonzero()
if len(matched_step_index) > 0:
current_step_index = matched_step_index.item()
else:
@ -1203,17 +1204,15 @@ class CFGZeroStarAndInit:
else:
current_step_index = 0
if (current_step_index <= zero_star_steps) and use_zero_init:
if (current_step_index <= zero_init_steps) and use_zero_init:
return cond * 0
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
batch_size = cond.shape[0]
noise_pred_text = cond
positive_flat = noise_pred_text.view(batch_size, -1)
positive_flat = cond.view(batch_size, -1)
negative_flat = uncond.view(batch_size, -1)
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
@ -1221,7 +1220,7 @@ class CFGZeroStarAndInit:
alpha = dot_product / squared_norm
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
noise_pred = uncond * alpha + cond_scale * (cond - uncond * alpha)
return noise_pred
m = model.clone()