Update nodes.py

This commit is contained in:
kijai 2024-06-15 01:19:00 +03:00
parent 3ba3ddf0b1
commit 36a8633aff

View File

@ -1666,6 +1666,7 @@ class CheckpointPerturbWeights:
return {"required": {
"model": ("MODEL",),
"joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
"final_layer": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
"rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
}
}
@ -1675,7 +1676,7 @@ class CheckpointPerturbWeights:
CATEGORY = "KJNodes/experimental"
def mod(self, model, joint_blocks, rest_of_the_blocks):
def mod(self, model, joint_blocks, final_layer, rest_of_the_blocks):
import copy
device = model_management.get_torch_device()
model_copy = copy.deepcopy(model)
@ -1688,10 +1689,12 @@ class CheckpointPerturbWeights:
pbar = ProgressBar(len(keys))
for k in keys:
v = dict[k] # This is common to both conditions, so we can take it out
v = dict[k]
print(f'{k}: {v.std()}')
if k.startswith('joint_blocks'):
multiplier = joint_blocks
elif k.startswith('final_layer'):
multiplier = final_layer
else:
multiplier = rest_of_the_blocks
dict[k] += torch.normal(torch.zeros_like(v) * v.mean(), torch.ones_like(v) * v.std() * multiplier).to(device)