mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 21:34:43 +08:00
Update nodes.py
This commit is contained in:
parent
2f6e38220c
commit
cb6864c7c8
@ -1651,100 +1651,3 @@ If no image is provided, mode is set to text-to-image
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# If the response is not valid JSON, raise a different exception
|
# If the response is not valid JSON, raise a different exception
|
||||||
raise Exception(f"Server error: {response.text}")
|
raise Exception(f"Server error: {response.text}")
|
||||||
|
|
||||||
class LoadICLightUnet:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("MODEL",),
|
|
||||||
"model_path": (folder_paths.get_filename_list("unet"), )
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("MODEL",)
|
|
||||||
FUNCTION = "load"
|
|
||||||
CATEGORY = "KJNodes/experimental"
|
|
||||||
DESCRIPTION = """
|
|
||||||
LoadICLightUnet: Loads an ICLightUnet model. (Experimental)
|
|
||||||
WORK IN PROGRESS
|
|
||||||
Very hacky (but currently working) way to load the converted IC-Light model available here:
|
|
||||||
https://huggingface.co/Kijai/iclight-comfy/blob/main/iclight_fc_converted.safetensors
|
|
||||||
|
|
||||||
Used with InstructPixToPixConditioning -node
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def load(self, model, model_path):
|
|
||||||
print("LoadICLightUnet: Checking LoadICLightUnet path")
|
|
||||||
model_full_path = folder_paths.get_full_path("unet", model_path)
|
|
||||||
if not os.path.exists(model_full_path):
|
|
||||||
raise Exception("Invalid model path")
|
|
||||||
else:
|
|
||||||
print("LoadICLightUnet: Loading LoadICLightUnet weights")
|
|
||||||
from comfy.utils import load_torch_file
|
|
||||||
model_clone = model.clone()
|
|
||||||
|
|
||||||
conv_layer = model_clone.model.diffusion_model.input_blocks[0][0]
|
|
||||||
print(f"Current number of input channels: {conv_layer.in_channels}")
|
|
||||||
|
|
||||||
# Create a new Conv2d layer with 8 input channels
|
|
||||||
new_conv_layer = torch.nn.Conv2d(8, conv_layer.out_channels, kernel_size=conv_layer.kernel_size, stride=conv_layer.stride, padding=conv_layer.padding)
|
|
||||||
new_conv_layer.weight.zero_()
|
|
||||||
new_conv_layer.weight[:, :4, :, :].copy_(conv_layer.weight)
|
|
||||||
new_conv_layer.bias = conv_layer.bias
|
|
||||||
new_conv_layer = new_conv_layer.to(torch.float16)
|
|
||||||
conv_layer.conv_in = new_conv_layer
|
|
||||||
# Replace the old layer with the new one
|
|
||||||
model_clone.model.diffusion_model.input_blocks[0][0] = new_conv_layer
|
|
||||||
# Verify the change
|
|
||||||
print(f"New number of input channels: {model_clone.model.diffusion_model.input_blocks[0][0].in_channels}")
|
|
||||||
|
|
||||||
# Monkey patch because I don't know what I'm doing
|
|
||||||
from comfy.model_base import IP2P
|
|
||||||
import types
|
|
||||||
base_model_instance = model_clone.model
|
|
||||||
|
|
||||||
# Dynamically add the extra_conds method from IP2P to the instance of BaseModel
|
|
||||||
def bound_extra_conds(self, **kwargs):
|
|
||||||
return IP2P.extra_conds(self, **kwargs)
|
|
||||||
base_model_instance.process_ip2p_image_in = lambda image: image
|
|
||||||
base_model_instance.extra_conds = types.MethodType(bound_extra_conds, base_model_instance)
|
|
||||||
|
|
||||||
# Some Proper patching
|
|
||||||
new_state_dict = load_torch_file(model_full_path)
|
|
||||||
prefix_to_remove = 'model.'
|
|
||||||
new_keys_dict = {key[len(prefix_to_remove):]: new_state_dict[key] for key in new_state_dict if key.startswith(prefix_to_remove)}
|
|
||||||
|
|
||||||
print("LoadICLightUnet: Attempting to add patches with LoadICLightUnet weights")
|
|
||||||
try:
|
|
||||||
for key in new_keys_dict:
|
|
||||||
model_clone.add_patches({key: (new_keys_dict[key],)}, 1.0, 1.0)
|
|
||||||
except:
|
|
||||||
raise Exception("Could not patch model")
|
|
||||||
print("LoadICLightUnet: Added LoadICLightUnet patches")
|
|
||||||
|
|
||||||
return (model_clone, )
|
|
||||||
|
|
||||||
|
|
||||||
# Change UNet
|
|
||||||
|
|
||||||
# with torch.no_grad():
|
|
||||||
# new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
|
|
||||||
# new_conv_in.weight.zero_()
|
|
||||||
# new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
|
|
||||||
# new_conv_in.bias = unet.conv_in.bias
|
|
||||||
# unet.conv_in = new_conv_in
|
|
||||||
|
|
||||||
# unet_original_forward = unet.forward
|
|
||||||
|
|
||||||
|
|
||||||
# def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
|
|
||||||
# c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
|
|
||||||
# c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
|
|
||||||
# new_sample = torch.cat([sample, c_concat], dim=1)
|
|
||||||
# kwargs['cross_attention_kwargs'] = {}
|
|
||||||
# return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# unet.forward = hooked_unet_forward
|
|
||||||
Loading…
x
Reference in New Issue
Block a user