Update nodes.py

This commit is contained in:
kijai 2024-06-15 01:29:41 +03:00
parent 36a8633aff
commit c98c94125e

View File

@ -1668,6 +1668,7 @@ class CheckpointPerturbWeights:
"joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
"final_layer": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
"rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
"seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
}
}
RETURN_TYPES = ("MODEL",)
@ -1676,8 +1677,10 @@ class CheckpointPerturbWeights:
CATEGORY = "KJNodes/experimental"
def mod(self, model, joint_blocks, final_layer, rest_of_the_blocks):
def mod(self, seed, model, joint_blocks, final_layer, rest_of_the_blocks):
import copy
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
device = model_management.get_torch_device()
model_copy = copy.deepcopy(model)
model_copy.model.to(device)