diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 2cf00e6..cc193bc 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -1204,7 +1204,7 @@ class CFGZeroStar: dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 alpha = dot_product / squared_norm - alpha = alpha.view(batch_size, 1, 1, 1, 1) + alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1))) matched_step_index = (sigmas == timestep[0] ).nonzero() if len(matched_step_index) > 0: