From 36a8633aff7c958dcb67fbaf2a18c16f6d0f07bc Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 15 Jun 2024 01:19:00 +0300 Subject: [PATCH] Update nodes.py --- nodes/nodes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nodes/nodes.py b/nodes/nodes.py index a14a093..07fe0f2 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1666,6 +1666,7 @@ class CheckpointPerturbWeights: return {"required": { "model": ("MODEL",), "joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}), + "final_layer": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}), "rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}), } } @@ -1675,7 +1676,7 @@ class CheckpointPerturbWeights: CATEGORY = "KJNodes/experimental" - def mod(self, model, joint_blocks, rest_of_the_blocks): + def mod(self, model, joint_blocks, final_layer, rest_of_the_blocks): import copy device = model_management.get_torch_device() model_copy = copy.deepcopy(model) @@ -1688,10 +1689,12 @@ class CheckpointPerturbWeights: pbar = ProgressBar(len(keys)) for k in keys: - v = dict[k] # This is common to both conditions, so we can take it out + v = dict[k] print(f'{k}: {v.std()}') if k.startswith('joint_blocks'): multiplier = joint_blocks + elif k.startswith('final_layer'): + multiplier = final_layer else: multiplier = rest_of_the_blocks dict[k] += torch.normal(torch.zeros_like(v) * v.mean(), torch.ones_like(v) * v.std() * multiplier).to(device)