rename CFGZeroStar to avoid conflict

This commit is contained in:
kijai 2025-03-26 12:29:56 +02:00
parent f77c0dc5d8
commit 52c2e31a90
2 changed files with 16 additions and 16 deletions

View File

@ -189,7 +189,7 @@ NODE_CONFIG = {
"SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
"CFGZeroStar": {"class": CFGZeroStar, "name": "CFG Zero Star"},
"CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -1173,7 +1173,7 @@ class SkipLayerGuidanceWanVideo:
return (m, )
class CFGZeroStar:
class CFGZeroStarAndInit:
@classmethod
def INPUT_TYPES(s):
return {"required": {
@ -1195,17 +1195,6 @@ class CFGZeroStar:
timestep = args["timestep"]
sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
batch_size = cond.shape[0]
noise_pred_text = cond
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = uncond.view(batch_size, -1)
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
alpha = dot_product / squared_norm
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
matched_step_index = (sigmas == timestep[0] ).nonzero()
if len(matched_step_index) > 0:
current_step_index = matched_step_index.item()
@ -1219,9 +1208,20 @@ class CFGZeroStar:
current_step_index = 0
if (current_step_index <= zero_star_steps) and use_zero_init:
noise_pred = noise_pred_text * 0
else:
noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
return cond * 0
batch_size = cond.shape[0]
noise_pred_text = cond
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = uncond.view(batch_size, -1)
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
alpha = dot_product / squared_norm
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
return noise_pred
m = model.clone()