mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-06 22:36:37 +08:00
Update nodes.py
This commit is contained in:
parent
3ba3ddf0b1
commit
36a8633aff
@ -1666,6 +1666,7 @@ class CheckpointPerturbWeights:
|
||||
return {"required": {
|
||||
"model": ("MODEL",),
|
||||
"joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
||||
"final_layer": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
||||
"rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
||||
}
|
||||
}
|
||||
@ -1675,7 +1676,7 @@ class CheckpointPerturbWeights:
|
||||
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def mod(self, model, joint_blocks, rest_of_the_blocks):
|
||||
def mod(self, model, joint_blocks, final_layer, rest_of_the_blocks):
|
||||
import copy
|
||||
device = model_management.get_torch_device()
|
||||
model_copy = copy.deepcopy(model)
|
||||
@ -1688,10 +1689,12 @@ class CheckpointPerturbWeights:
|
||||
|
||||
pbar = ProgressBar(len(keys))
|
||||
for k in keys:
|
||||
v = dict[k] # This is common to both conditions, so we can take it out
|
||||
v = dict[k]
|
||||
print(f'{k}: {v.std()}')
|
||||
if k.startswith('joint_blocks'):
|
||||
multiplier = joint_blocks
|
||||
elif k.startswith('final_layer'):
|
||||
multiplier = final_layer
|
||||
else:
|
||||
multiplier = rest_of_the_blocks
|
||||
dict[k] += torch.normal(torch.zeros_like(v) * v.mean(), torch.ones_like(v) * v.std() * multiplier).to(device)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user