diff --git a/__init__.py b/__init__.py index 1f5de7b..9b5d6bf 100644 --- a/__init__.py +++ b/__init__.py @@ -127,6 +127,7 @@ NODE_CONFIG = { "Superprompt": {"class": Superprompt, "name": "Superprompt"}, "GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords}, "Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"}, + "CheckpointPerturbWeights": {"class": CheckpointPerturbWeights, "name": "CheckpointPerturbWeights"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 3afb26c..a14a093 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -9,7 +9,7 @@ import json, re, os, io, time import model_management import folder_paths from nodes import MAX_RESOLUTION -from comfy.utils import common_upscale +from comfy.utils import common_upscale, ProgressBar script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) folder_paths.add_model_folder_path("kjnodes_fonts", os.path.join(script_directory, "fonts")) @@ -1657,4 +1657,44 @@ If no image is provided, mode is set to text-to-image raise Exception(f"Server error: {error_data}") except json.JSONDecodeError: # If the response is not valid JSON, raise a different exception - raise Exception(f"Server error: {response.text}") \ No newline at end of file + raise Exception(f"Server error: {response.text}") + +class CheckpointPerturbWeights: + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("MODEL",), + "joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}), + "rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "mod" + OUTPUT_NODE = True + + CATEGORY = "KJNodes/experimental" + + def mod(self, model, joint_blocks, rest_of_the_blocks): + import copy + device = model_management.get_torch_device() + model_copy = copy.deepcopy(model) + model_copy.model.to(device) + keys = model_copy.model.diffusion_model.state_dict().keys() + + dict = {} + for key in keys: + dict[key] = model_copy.model.diffusion_model.state_dict()[key] + + pbar = ProgressBar(len(keys)) + for k in keys: + v = dict[k] # This is common to both conditions, so we can take it out + print(f'{k}: {v.std()}') + if k.startswith('joint_blocks'): + multiplier = joint_blocks + else: + multiplier = rest_of_the_blocks + dict[k] += torch.normal(torch.zeros_like(v) * v.mean(), torch.ones_like(v) * v.std() * multiplier).to(device) + pbar.update(1) + model_copy.model.diffusion_model.load_state_dict(dict) + return model_copy, \ No newline at end of file