mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 05:15:05 +08:00
Fix Scheduled CFG guider for batched timesteps
This commit is contained in:
parent
50e7dd34d3
commit
b44992b7e0
@ -2325,14 +2325,18 @@ class Guider_ScheduledCFG(CFGGuider):
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
steps = model_options["transformer_options"]["sample_sigmas"]
|
||||
matched_step_index = (steps == timestep).nonzero()
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep_value = timestep.reshape(-1)[0].to(steps)
|
||||
else:
|
||||
timestep_value = torch.tensor(timestep, device=steps.device, dtype=steps.dtype)
|
||||
matched_step_index = torch.isclose(steps, timestep_value).nonzero()
|
||||
assert not (isinstance(self.cfg, list) and len(self.cfg) != (len(steps) - 1)), "cfg list length must match step count"
|
||||
if len(matched_step_index) > 0:
|
||||
current_step_index = matched_step_index.item()
|
||||
else:
|
||||
for i in range(len(steps) - 1):
|
||||
# walk from beginning of steps until crossing the timestep
|
||||
if (steps[i] - timestep[0]) * (steps[i + 1] - timestep[0]) <= 0:
|
||||
if (steps[i] - timestep_value) * (steps[i + 1] - timestep_value) <= 0:
|
||||
current_step_index = i
|
||||
break
|
||||
else:
|
||||
@ -2699,4 +2703,4 @@ class LatentInpaintTTM:
|
||||
def patch(self, model, steps, mask=None):
|
||||
m = model.clone()
|
||||
m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, steps))
|
||||
return (m, )
|
||||
return (m, )
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user