From 928e5cc778a80481cd1642c1d876975df746421d Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 9 May 2024 01:48:53 +0300 Subject: [PATCH] Update nodes.py --- nodes/nodes.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/nodes/nodes.py b/nodes/nodes.py index ae4181a..76a1fa3 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1654,6 +1654,21 @@ class LoadICLightUnet: print("LoadICLightUnet: Loading LoadICLightUnet weights") from comfy.utils import load_torch_file model_clone = model.clone() + + conv_layer = model_clone.model.diffusion_model.input_blocks[0][0] + print(f"Current number of input channels: {conv_layer.in_channels}") + # Create a new Conv2d layer with 8 input channels + new_conv_layer = torch.nn.Conv2d(8, conv_layer.out_channels, kernel_size=conv_layer.kernel_size, stride=conv_layer.stride, padding=conv_layer.padding) + new_conv_layer.weight.zero_() + new_conv_layer.weight[:, :4, :, :].copy_(conv_layer.weight) + new_conv_layer.bias = conv_layer.bias + new_conv_layer = new_conv_layer.to(torch.float16) + conv_layer.conv_in = new_conv_layer + # Replace the old layer with the new one + model_clone.model.diffusion_model.input_blocks[0][0] = new_conv_layer + # Verify the change + print(f"New number of input channels: {model_clone.model.diffusion_model.input_blocks[0][0].in_channels}") + new_state_dict = load_torch_file(model_full_path) prefix_to_remove = 'model.' new_keys_dict = {key[len(prefix_to_remove):]: new_state_dict[key] for key in new_state_dict if key.startswith(prefix_to_remove)}