mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-06-03 10:06:36 +08:00
fix zero_init
This commit is contained in:
parent
fc59fff1b5
commit
916461c432
@ -1179,7 +1179,7 @@ class CFGZeroStarAndInit:
|
|||||||
return {"required": {
|
return {"required": {
|
||||||
"model": ("MODEL",),
|
"model": ("MODEL",),
|
||||||
"use_zero_init": ("BOOLEAN", {"default": True}),
|
"use_zero_init": ("BOOLEAN", {"default": True}),
|
||||||
"zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
|
"zero_init_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
@ -1187,12 +1187,13 @@ class CFGZeroStarAndInit:
|
|||||||
CATEGORY = "KJNodes/experimental"
|
CATEGORY = "KJNodes/experimental"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
def patch(self, model, use_zero_init, zero_star_steps):
|
def patch(self, model, use_zero_init, zero_init_steps):
|
||||||
def cfg_zerostar(args):
|
def cfg_zerostar(args):
|
||||||
#zero init
|
#zero init
|
||||||
|
cond = args["cond"]
|
||||||
timestep = args["timestep"]
|
timestep = args["timestep"]
|
||||||
sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
|
sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
|
||||||
matched_step_index = (sigmas == timestep[0] ).nonzero()
|
matched_step_index = (sigmas == timestep[0]).nonzero()
|
||||||
if len(matched_step_index) > 0:
|
if len(matched_step_index) > 0:
|
||||||
current_step_index = matched_step_index.item()
|
current_step_index = matched_step_index.item()
|
||||||
else:
|
else:
|
||||||
@ -1203,17 +1204,15 @@ class CFGZeroStarAndInit:
|
|||||||
else:
|
else:
|
||||||
current_step_index = 0
|
current_step_index = 0
|
||||||
|
|
||||||
if (current_step_index <= zero_star_steps) and use_zero_init:
|
if (current_step_index <= zero_init_steps) and use_zero_init:
|
||||||
return cond * 0
|
return cond * 0
|
||||||
|
|
||||||
cond = args["cond"]
|
|
||||||
uncond = args["uncond"]
|
uncond = args["uncond"]
|
||||||
cond_scale = args["cond_scale"]
|
cond_scale = args["cond_scale"]
|
||||||
|
|
||||||
batch_size = cond.shape[0]
|
batch_size = cond.shape[0]
|
||||||
|
|
||||||
noise_pred_text = cond
|
positive_flat = cond.view(batch_size, -1)
|
||||||
positive_flat = noise_pred_text.view(batch_size, -1)
|
|
||||||
negative_flat = uncond.view(batch_size, -1)
|
negative_flat = uncond.view(batch_size, -1)
|
||||||
|
|
||||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||||
@ -1221,7 +1220,7 @@ class CFGZeroStarAndInit:
|
|||||||
alpha = dot_product / squared_norm
|
alpha = dot_product / squared_norm
|
||||||
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
|
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
|
||||||
|
|
||||||
noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
|
noise_pred = uncond * alpha + cond_scale * (cond - uncond * alpha)
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user