Fix Scheduled CFG guider for batched timesteps

This commit is contained in:
nolbert82 2025-12-08 19:32:44 +01:00
parent 50e7dd34d3
commit b44992b7e0

View File

@ -2325,14 +2325,18 @@ class Guider_ScheduledCFG(CFGGuider):
def predict_noise(self, x, timestep, model_options={}, seed=None): def predict_noise(self, x, timestep, model_options={}, seed=None):
steps = model_options["transformer_options"]["sample_sigmas"] steps = model_options["transformer_options"]["sample_sigmas"]
matched_step_index = (steps == timestep).nonzero() if isinstance(timestep, torch.Tensor):
timestep_value = timestep.reshape(-1)[0].to(steps)
else:
timestep_value = torch.tensor(timestep, device=steps.device, dtype=steps.dtype)
matched_step_index = torch.isclose(steps, timestep_value).nonzero()
assert not (isinstance(self.cfg, list) and len(self.cfg) != (len(steps) - 1)), "cfg list length must match step count" assert not (isinstance(self.cfg, list) and len(self.cfg) != (len(steps) - 1)), "cfg list length must match step count"
if len(matched_step_index) > 0: if len(matched_step_index) > 0:
current_step_index = matched_step_index.item() current_step_index = matched_step_index.item()
else: else:
for i in range(len(steps) - 1): for i in range(len(steps) - 1):
# walk from beginning of steps until crossing the timestep # walk from beginning of steps until crossing the timestep
if (steps[i] - timestep[0]) * (steps[i + 1] - timestep[0]) <= 0: if (steps[i] - timestep_value) * (steps[i + 1] - timestep_value) <= 0:
current_step_index = i current_step_index = i
break break
else: else:
@ -2699,4 +2703,4 @@ class LatentInpaintTTM:
def patch(self, model, steps, mask=None): def patch(self, model, steps, mask=None):
m = model.clone() m = model.clone()
m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, steps)) m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, steps))
return (m, ) return (m, )