mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-06-04 00:51:23 +08:00
Update nodes.py
This commit is contained in:
parent
3ba3ddf0b1
commit
36a8633aff
@ -1666,6 +1666,7 @@ class CheckpointPerturbWeights:
|
|||||||
return {"required": {
|
return {"required": {
|
||||||
"model": ("MODEL",),
|
"model": ("MODEL",),
|
||||||
"joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
"joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
||||||
|
"final_layer": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
||||||
"rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
"rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1675,7 +1676,7 @@ class CheckpointPerturbWeights:
|
|||||||
|
|
||||||
CATEGORY = "KJNodes/experimental"
|
CATEGORY = "KJNodes/experimental"
|
||||||
|
|
||||||
def mod(self, model, joint_blocks, rest_of_the_blocks):
|
def mod(self, model, joint_blocks, final_layer, rest_of_the_blocks):
|
||||||
import copy
|
import copy
|
||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
model_copy = copy.deepcopy(model)
|
model_copy = copy.deepcopy(model)
|
||||||
@ -1688,10 +1689,12 @@ class CheckpointPerturbWeights:
|
|||||||
|
|
||||||
pbar = ProgressBar(len(keys))
|
pbar = ProgressBar(len(keys))
|
||||||
for k in keys:
|
for k in keys:
|
||||||
v = dict[k] # This is common to both conditions, so we can take it out
|
v = dict[k]
|
||||||
print(f'{k}: {v.std()}')
|
print(f'{k}: {v.std()}')
|
||||||
if k.startswith('joint_blocks'):
|
if k.startswith('joint_blocks'):
|
||||||
multiplier = joint_blocks
|
multiplier = joint_blocks
|
||||||
|
elif k.startswith('final_layer'):
|
||||||
|
multiplier = final_layer
|
||||||
else:
|
else:
|
||||||
multiplier = rest_of_the_blocks
|
multiplier = rest_of_the_blocks
|
||||||
dict[k] += torch.normal(torch.zeros_like(v) * v.mean(), torch.ones_like(v) * v.std() * multiplier).to(device)
|
dict[k] += torch.normal(torch.zeros_like(v) * v.mean(), torch.ones_like(v) * v.std() * multiplier).to(device)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user