diff --git a/nodes/nodes.py b/nodes/nodes.py index 76a1fa3..0a0b726 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1646,7 +1646,7 @@ class LoadICLightUnet: CATEGORY = "KJNodes/experimental" def load(self, model, model_path): - print("LoadICLightUnet: Checking ResAdapter path") + print("LoadICLightUnet: Checking LoadICLightUnet path") model_full_path = folder_paths.get_full_path("unet", model_path) if not os.path.exists(model_full_path): raise Exception("Invalid model path") @@ -1657,6 +1657,7 @@ class LoadICLightUnet: conv_layer = model_clone.model.diffusion_model.input_blocks[0][0] print(f"Current number of input channels: {conv_layer.in_channels}") + # Create a new Conv2d layer with 8 input channels new_conv_layer = torch.nn.Conv2d(8, conv_layer.out_channels, kernel_size=conv_layer.kernel_size, stride=conv_layer.stride, padding=conv_layer.padding) new_conv_layer.weight.zero_() @@ -1668,7 +1669,19 @@ class LoadICLightUnet: model_clone.model.diffusion_model.input_blocks[0][0] = new_conv_layer # Verify the change print(f"New number of input channels: {model_clone.model.diffusion_model.input_blocks[0][0].in_channels}") + + # Monkey patch because I don't know what I'm doing + from comfy.model_base import IP2P + import types + base_model_instance = model_clone.model + # Dynamically add the extra_conds method from IP2P to the instance of BaseModel + def bound_extra_conds(self, **kwargs): + return IP2P.extra_conds(self, **kwargs) + base_model_instance.process_ip2p_image_in = lambda image: image + base_model_instance.extra_conds = types.MethodType(bound_extra_conds, base_model_instance) + + # Some Proper patching new_state_dict = load_torch_file(model_full_path) prefix_to_remove = 'model.' new_keys_dict = {key[len(prefix_to_remove):]: new_state_dict[key] for key in new_state_dict if key.startswith(prefix_to_remove)} @@ -1677,11 +1690,9 @@ class LoadICLightUnet: try: for key in new_keys_dict: model_clone.add_patches({key: (new_keys_dict[key],)}, 1.0, 1.0) - #print(f"Added patch for: {key}") except: raise Exception("Could not patch model") print("LoadICLightUnet: Added LoadICLightUnet patches") - #model_clone.model.diffusion_model.in_channels = 8 return (model_clone, )