mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Add ScheduledCFGGuidance
This commit is contained in:
parent
8950c5fe67
commit
f3d931a630
@ -178,6 +178,7 @@ NODE_CONFIG = {
|
||||
"PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
|
||||
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
|
||||
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
|
||||
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
||||
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
|
||||
@ -2399,4 +2399,60 @@ class VAELoaderKJ:
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||
sd = load_torch_file(vae_path)
|
||||
vae = VAE(sd=sd, device=device, dtype=dtype)
|
||||
return (vae,)
|
||||
return (vae,)
|
||||
|
||||
from comfy.samplers import sampling_function, CFGGuider
|
||||
class Guider_ScheduledCFG(CFGGuider):
|
||||
|
||||
def set_cfg(self, cfg, start_percent, end_percent):
|
||||
self.cfg = cfg
|
||||
self.start_percent = start_percent
|
||||
self.end_percent = end_percent
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
steps = model_options["transformer_options"]["sample_sigmas"]
|
||||
matched_step_index = (steps == timestep).nonzero()
|
||||
if len(matched_step_index) > 0:
|
||||
current_step_index = matched_step_index.item()
|
||||
else:
|
||||
for i in range(len(steps) - 1):
|
||||
# walk from beginning of steps until crossing the timestep
|
||||
if (steps[i] - timestep) * (steps[i + 1] - timestep) <= 0:
|
||||
current_step_index = i
|
||||
break
|
||||
else:
|
||||
current_step_index = 0
|
||||
current_percent = current_step_index / (len(steps) - 1)
|
||||
|
||||
if self.start_percent <= current_percent <= self.end_percent:
|
||||
if isinstance(self.cfg, list):
|
||||
cfg = self.cfg[current_step_index]
|
||||
else:
|
||||
cfg = self.cfg
|
||||
uncond = self.conds.get("negative", None)
|
||||
else:
|
||||
uncond = None
|
||||
cfg = 1.0
|
||||
return sampling_function(self.inner_model, x, timestep, uncond, self.conds.get("positive", None), cfg, model_options=model_options, seed=seed)
|
||||
|
||||
class ScheduledCFGGuidance:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": ("MODEL",),
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step":0.01}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("GUIDER",)
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def get_guider(self, model, cfg, positive, negative, start_percent, end_percent):
|
||||
guider = Guider_ScheduledCFG(model)
|
||||
guider.set_conds(positive, negative)
|
||||
guider.set_cfg(cfg, start_percent, end_percent)
|
||||
return (guider, )
|
||||
Loading…
x
Reference in New Issue
Block a user