From b44992b7e0cce69c899245435a2188cde41b738c Mon Sep 17 00:00:00 2001 From: nolbert82 Date: Mon, 8 Dec 2025 19:32:44 +0100 Subject: [PATCH] Fix Scheduled CFG guider for batched timesteps --- nodes/nodes.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nodes/nodes.py b/nodes/nodes.py index 535289a..da6a527 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2325,14 +2325,18 @@ class Guider_ScheduledCFG(CFGGuider): def predict_noise(self, x, timestep, model_options={}, seed=None): steps = model_options["transformer_options"]["sample_sigmas"] - matched_step_index = (steps == timestep).nonzero() + if isinstance(timestep, torch.Tensor): + timestep_value = timestep.reshape(-1)[0].to(steps) + else: + timestep_value = torch.tensor(timestep, device=steps.device, dtype=steps.dtype) + matched_step_index = torch.isclose(steps, timestep_value).nonzero() assert not (isinstance(self.cfg, list) and len(self.cfg) != (len(steps) - 1)), "cfg list length must match step count" if len(matched_step_index) > 0: current_step_index = matched_step_index.item() else: for i in range(len(steps) - 1): # walk from beginning of steps until crossing the timestep - if (steps[i] - timestep[0]) * (steps[i + 1] - timestep[0]) <= 0: + if (steps[i] - timestep_value) * (steps[i + 1] - timestep_value) <= 0: current_step_index = i break else: @@ -2699,4 +2703,4 @@ class LatentInpaintTTM: def patch(self, model, steps, mask=None): m = model.clone() m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, steps)) - return (m, ) \ No newline at end of file + return (m, )