Update model_optimization_nodes.py

This commit is contained in:
kijai 2025-03-26 01:14:16 +02:00
parent b1ec996ba3
commit f77c0dc5d8

View File

@ -1204,7 +1204,7 @@ class CFGZeroStar:
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
alpha = dot_product / squared_norm
alpha = alpha.view(batch_size, 1, 1, 1, 1)
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
matched_step_index = (sigmas == timestep[0] ).nonzero()
if len(matched_step_index) > 0: