mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 21:34:43 +08:00
Fix Scheduled CFG guider for batched timesteps
This commit is contained in:
parent
50e7dd34d3
commit
b44992b7e0
@ -2325,14 +2325,18 @@ class Guider_ScheduledCFG(CFGGuider):
|
|||||||
|
|
||||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||||
steps = model_options["transformer_options"]["sample_sigmas"]
|
steps = model_options["transformer_options"]["sample_sigmas"]
|
||||||
matched_step_index = (steps == timestep).nonzero()
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep_value = timestep.reshape(-1)[0].to(steps)
|
||||||
|
else:
|
||||||
|
timestep_value = torch.tensor(timestep, device=steps.device, dtype=steps.dtype)
|
||||||
|
matched_step_index = torch.isclose(steps, timestep_value).nonzero()
|
||||||
assert not (isinstance(self.cfg, list) and len(self.cfg) != (len(steps) - 1)), "cfg list length must match step count"
|
assert not (isinstance(self.cfg, list) and len(self.cfg) != (len(steps) - 1)), "cfg list length must match step count"
|
||||||
if len(matched_step_index) > 0:
|
if len(matched_step_index) > 0:
|
||||||
current_step_index = matched_step_index.item()
|
current_step_index = matched_step_index.item()
|
||||||
else:
|
else:
|
||||||
for i in range(len(steps) - 1):
|
for i in range(len(steps) - 1):
|
||||||
# walk from beginning of steps until crossing the timestep
|
# walk from beginning of steps until crossing the timestep
|
||||||
if (steps[i] - timestep[0]) * (steps[i + 1] - timestep[0]) <= 0:
|
if (steps[i] - timestep_value) * (steps[i + 1] - timestep_value) <= 0:
|
||||||
current_step_index = i
|
current_step_index = i
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user