diff --git a/__init__.py b/__init__.py index b570161..3f248bd 100644 --- a/__init__.py +++ b/__init__.py @@ -114,6 +114,7 @@ NODE_CONFIG = { "Superprompt": {"class": Superprompt, "name": "Superprompt"}, "GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords}, "Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"}, + "LoadICLightUnet": {"class": LoadICLightUnet}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, "AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 5d87e5b..ae4181a 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1629,4 +1629,66 @@ If no image is provided, mode is set to text-to-image raise Exception(f"Server error: {error_data}") except json.JSONDecodeError: # If the response is not valid JSON, raise a different exception - raise Exception(f"Server error: {response.text}") \ No newline at end of file + raise Exception(f"Server error: {response.text}") + +class LoadICLightUnet: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "model_path": (folder_paths.get_filename_list("unet"), ) + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "load" + CATEGORY = "KJNodes/experimental" + + def load(self, model, model_path): + print("LoadICLightUnet: Checking ResAdapter path") + model_full_path = folder_paths.get_full_path("unet", model_path) + if not os.path.exists(model_full_path): + raise Exception("Invalid model path") + else: + print("LoadICLightUnet: Loading LoadICLightUnet weights") + from comfy.utils import load_torch_file + model_clone = model.clone() + new_state_dict = load_torch_file(model_full_path) + prefix_to_remove = 'model.' + new_keys_dict = {key[len(prefix_to_remove):]: new_state_dict[key] for key in new_state_dict if key.startswith(prefix_to_remove)} + + print("LoadICLightUnet: Attempting to add patches with LoadICLightUnet weights") + try: + for key in new_keys_dict: + model_clone.add_patches({key: (new_keys_dict[key],)}, 1.0, 1.0) + #print(f"Added patch for: {key}") + except: + raise Exception("Could not patch model") + print("LoadICLightUnet: Added LoadICLightUnet patches") + #model_clone.model.diffusion_model.in_channels = 8 + + return (model_clone, ) + + +# Change UNet + +# with torch.no_grad(): +# new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) +# new_conv_in.weight.zero_() +# new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) +# new_conv_in.bias = unet.conv_in.bias +# unet.conv_in = new_conv_in + +# unet_original_forward = unet.forward + + +# def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): +# c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) +# c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) +# new_sample = torch.cat([sample, c_concat], dim=1) +# kwargs['cross_attention_kwargs'] = {} +# return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) + + +# unet.forward = hooked_unet_forward \ No newline at end of file