From cb6864c7c83d482dd13d09fbd009d310955ca542 Mon Sep 17 00:00:00 2001 From: Kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 16 May 2024 12:38:48 +0300 Subject: [PATCH] Update nodes.py --- nodes/nodes.py | 99 +------------------------------------------------- 1 file changed, 1 insertion(+), 98 deletions(-) diff --git a/nodes/nodes.py b/nodes/nodes.py index 99f159f..440934b 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1650,101 +1650,4 @@ If no image is provided, mode is set to text-to-image raise Exception(f"Server error: {error_data}") except json.JSONDecodeError: # If the response is not valid JSON, raise a different exception - raise Exception(f"Server error: {response.text}") - -class LoadICLightUnet: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "model_path": (folder_paths.get_filename_list("unet"), ) - } - } - - RETURN_TYPES = ("MODEL",) - FUNCTION = "load" - CATEGORY = "KJNodes/experimental" - DESCRIPTION = """ -LoadICLightUnet: Loads an ICLightUnet model. (Experimental) -WORK IN PROGRESS -Very hacky (but currently working) way to load the converted IC-Light model available here: -https://huggingface.co/Kijai/iclight-comfy/blob/main/iclight_fc_converted.safetensors - -Used with InstructPixToPixConditioning -node - -""" - - def load(self, model, model_path): - print("LoadICLightUnet: Checking LoadICLightUnet path") - model_full_path = folder_paths.get_full_path("unet", model_path) - if not os.path.exists(model_full_path): - raise Exception("Invalid model path") - else: - print("LoadICLightUnet: Loading LoadICLightUnet weights") - from comfy.utils import load_torch_file - model_clone = model.clone() - - conv_layer = model_clone.model.diffusion_model.input_blocks[0][0] - print(f"Current number of input channels: {conv_layer.in_channels}") - - # Create a new Conv2d layer with 8 input channels - new_conv_layer = torch.nn.Conv2d(8, conv_layer.out_channels, kernel_size=conv_layer.kernel_size, stride=conv_layer.stride, padding=conv_layer.padding) - new_conv_layer.weight.zero_() - new_conv_layer.weight[:, :4, :, :].copy_(conv_layer.weight) - new_conv_layer.bias = conv_layer.bias - new_conv_layer = new_conv_layer.to(torch.float16) - conv_layer.conv_in = new_conv_layer - # Replace the old layer with the new one - model_clone.model.diffusion_model.input_blocks[0][0] = new_conv_layer - # Verify the change - print(f"New number of input channels: {model_clone.model.diffusion_model.input_blocks[0][0].in_channels}") - - # Monkey patch because I don't know what I'm doing - from comfy.model_base import IP2P - import types - base_model_instance = model_clone.model - - # Dynamically add the extra_conds method from IP2P to the instance of BaseModel - def bound_extra_conds(self, **kwargs): - return IP2P.extra_conds(self, **kwargs) - base_model_instance.process_ip2p_image_in = lambda image: image - base_model_instance.extra_conds = types.MethodType(bound_extra_conds, base_model_instance) - - # Some Proper patching - new_state_dict = load_torch_file(model_full_path) - prefix_to_remove = 'model.' - new_keys_dict = {key[len(prefix_to_remove):]: new_state_dict[key] for key in new_state_dict if key.startswith(prefix_to_remove)} - - print("LoadICLightUnet: Attempting to add patches with LoadICLightUnet weights") - try: - for key in new_keys_dict: - model_clone.add_patches({key: (new_keys_dict[key],)}, 1.0, 1.0) - except: - raise Exception("Could not patch model") - print("LoadICLightUnet: Added LoadICLightUnet patches") - - return (model_clone, ) - - -# Change UNet - -# with torch.no_grad(): -# new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) -# new_conv_in.weight.zero_() -# new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) -# new_conv_in.bias = unet.conv_in.bias -# unet.conv_in = new_conv_in - -# unet_original_forward = unet.forward - - -# def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): -# c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) -# c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) -# new_sample = torch.cat([sample, c_concat], dim=1) -# kwargs['cross_attention_kwargs'] = {} -# return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) - - -# unet.forward = hooked_unet_forward \ No newline at end of file + raise Exception(f"Server error: {response.text}") \ No newline at end of file