Add ScheduledCFGGuidance

This commit is contained in:
kijai 2025-02-22 15:43:58 +02:00
parent 8950c5fe67
commit f3d931a630
2 changed files with 58 additions and 1 deletions

View File

@ -178,6 +178,7 @@ NODE_CONFIG = {
"PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -2399,4 +2399,60 @@ class VAELoaderKJ:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = load_torch_file(vae_path)
vae = VAE(sd=sd, device=device, dtype=dtype)
return (vae,)
return (vae,)
from comfy.samplers import sampling_function, CFGGuider
class Guider_ScheduledCFG(CFGGuider):
def set_cfg(self, cfg, start_percent, end_percent):
self.cfg = cfg
self.start_percent = start_percent
self.end_percent = end_percent
def predict_noise(self, x, timestep, model_options={}, seed=None):
steps = model_options["transformer_options"]["sample_sigmas"]
matched_step_index = (steps == timestep).nonzero()
if len(matched_step_index) > 0:
current_step_index = matched_step_index.item()
else:
for i in range(len(steps) - 1):
# walk from beginning of steps until crossing the timestep
if (steps[i] - timestep) * (steps[i + 1] - timestep) <= 0:
current_step_index = i
break
else:
current_step_index = 0
current_percent = current_step_index / (len(steps) - 1)
if self.start_percent <= current_percent <= self.end_percent:
if isinstance(self.cfg, list):
cfg = self.cfg[current_step_index]
else:
cfg = self.cfg
uncond = self.conds.get("negative", None)
else:
uncond = None
cfg = 1.0
return sampling_function(self.inner_model, x, timestep, uncond, self.conds.get("positive", None), cfg, model_options=model_options, seed=seed)
class ScheduledCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 100.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step":0.01}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01}),
},
}
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "KJNodes/experimental"
def get_guider(self, model, cfg, positive, negative, start_percent, end_percent):
guider = Guider_ScheduledCFG(model)
guider.set_conds(positive, negative)
guider.set_cfg(cfg, start_percent, end_percent)
return (guider, )