mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-06-01 02:27:07 +08:00
rename CFGZeroStar to avoid conflict
This commit is contained in:
parent
f77c0dc5d8
commit
52c2e31a90
@ -189,7 +189,7 @@ NODE_CONFIG = {
|
|||||||
"SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
|
"SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
|
||||||
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
|
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
|
||||||
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
|
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
|
||||||
"CFGZeroStar": {"class": CFGZeroStar, "name": "CFG Zero Star"},
|
"CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
|
||||||
|
|
||||||
#instance diffusion
|
#instance diffusion
|
||||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||||
|
|||||||
@ -1173,7 +1173,7 @@ class SkipLayerGuidanceWanVideo:
|
|||||||
|
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
class CFGZeroStar:
|
class CFGZeroStarAndInit:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
@ -1195,17 +1195,6 @@ class CFGZeroStar:
|
|||||||
timestep = args["timestep"]
|
timestep = args["timestep"]
|
||||||
sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
|
sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
|
||||||
|
|
||||||
batch_size = cond.shape[0]
|
|
||||||
|
|
||||||
noise_pred_text = cond
|
|
||||||
positive_flat = noise_pred_text.view(batch_size, -1)
|
|
||||||
negative_flat = uncond.view(batch_size, -1)
|
|
||||||
|
|
||||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
|
||||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
|
||||||
alpha = dot_product / squared_norm
|
|
||||||
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
|
|
||||||
|
|
||||||
matched_step_index = (sigmas == timestep[0] ).nonzero()
|
matched_step_index = (sigmas == timestep[0] ).nonzero()
|
||||||
if len(matched_step_index) > 0:
|
if len(matched_step_index) > 0:
|
||||||
current_step_index = matched_step_index.item()
|
current_step_index = matched_step_index.item()
|
||||||
@ -1219,9 +1208,20 @@ class CFGZeroStar:
|
|||||||
current_step_index = 0
|
current_step_index = 0
|
||||||
|
|
||||||
if (current_step_index <= zero_star_steps) and use_zero_init:
|
if (current_step_index <= zero_star_steps) and use_zero_init:
|
||||||
noise_pred = noise_pred_text * 0
|
return cond * 0
|
||||||
else:
|
|
||||||
noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
|
batch_size = cond.shape[0]
|
||||||
|
|
||||||
|
noise_pred_text = cond
|
||||||
|
positive_flat = noise_pred_text.view(batch_size, -1)
|
||||||
|
negative_flat = uncond.view(batch_size, -1)
|
||||||
|
|
||||||
|
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||||
|
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||||
|
alpha = dot_product / squared_norm
|
||||||
|
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
|
||||||
|
|
||||||
|
noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user