mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 13:24:44 +08:00
Update nodes.py
This commit is contained in:
parent
928e5cc778
commit
46e91987be
@ -1646,7 +1646,7 @@ class LoadICLightUnet:
|
|||||||
CATEGORY = "KJNodes/experimental"
|
CATEGORY = "KJNodes/experimental"
|
||||||
|
|
||||||
def load(self, model, model_path):
|
def load(self, model, model_path):
|
||||||
print("LoadICLightUnet: Checking ResAdapter path")
|
print("LoadICLightUnet: Checking LoadICLightUnet path")
|
||||||
model_full_path = folder_paths.get_full_path("unet", model_path)
|
model_full_path = folder_paths.get_full_path("unet", model_path)
|
||||||
if not os.path.exists(model_full_path):
|
if not os.path.exists(model_full_path):
|
||||||
raise Exception("Invalid model path")
|
raise Exception("Invalid model path")
|
||||||
@ -1657,6 +1657,7 @@ class LoadICLightUnet:
|
|||||||
|
|
||||||
conv_layer = model_clone.model.diffusion_model.input_blocks[0][0]
|
conv_layer = model_clone.model.diffusion_model.input_blocks[0][0]
|
||||||
print(f"Current number of input channels: {conv_layer.in_channels}")
|
print(f"Current number of input channels: {conv_layer.in_channels}")
|
||||||
|
|
||||||
# Create a new Conv2d layer with 8 input channels
|
# Create a new Conv2d layer with 8 input channels
|
||||||
new_conv_layer = torch.nn.Conv2d(8, conv_layer.out_channels, kernel_size=conv_layer.kernel_size, stride=conv_layer.stride, padding=conv_layer.padding)
|
new_conv_layer = torch.nn.Conv2d(8, conv_layer.out_channels, kernel_size=conv_layer.kernel_size, stride=conv_layer.stride, padding=conv_layer.padding)
|
||||||
new_conv_layer.weight.zero_()
|
new_conv_layer.weight.zero_()
|
||||||
@ -1668,7 +1669,19 @@ class LoadICLightUnet:
|
|||||||
model_clone.model.diffusion_model.input_blocks[0][0] = new_conv_layer
|
model_clone.model.diffusion_model.input_blocks[0][0] = new_conv_layer
|
||||||
# Verify the change
|
# Verify the change
|
||||||
print(f"New number of input channels: {model_clone.model.diffusion_model.input_blocks[0][0].in_channels}")
|
print(f"New number of input channels: {model_clone.model.diffusion_model.input_blocks[0][0].in_channels}")
|
||||||
|
|
||||||
|
# Monkey patch because I don't know what I'm doing
|
||||||
|
from comfy.model_base import IP2P
|
||||||
|
import types
|
||||||
|
base_model_instance = model_clone.model
|
||||||
|
|
||||||
|
# Dynamically add the extra_conds method from IP2P to the instance of BaseModel
|
||||||
|
def bound_extra_conds(self, **kwargs):
|
||||||
|
return IP2P.extra_conds(self, **kwargs)
|
||||||
|
base_model_instance.process_ip2p_image_in = lambda image: image
|
||||||
|
base_model_instance.extra_conds = types.MethodType(bound_extra_conds, base_model_instance)
|
||||||
|
|
||||||
|
# Some Proper patching
|
||||||
new_state_dict = load_torch_file(model_full_path)
|
new_state_dict = load_torch_file(model_full_path)
|
||||||
prefix_to_remove = 'model.'
|
prefix_to_remove = 'model.'
|
||||||
new_keys_dict = {key[len(prefix_to_remove):]: new_state_dict[key] for key in new_state_dict if key.startswith(prefix_to_remove)}
|
new_keys_dict = {key[len(prefix_to_remove):]: new_state_dict[key] for key in new_state_dict if key.startswith(prefix_to_remove)}
|
||||||
@ -1677,11 +1690,9 @@ class LoadICLightUnet:
|
|||||||
try:
|
try:
|
||||||
for key in new_keys_dict:
|
for key in new_keys_dict:
|
||||||
model_clone.add_patches({key: (new_keys_dict[key],)}, 1.0, 1.0)
|
model_clone.add_patches({key: (new_keys_dict[key],)}, 1.0, 1.0)
|
||||||
#print(f"Added patch for: {key}")
|
|
||||||
except:
|
except:
|
||||||
raise Exception("Could not patch model")
|
raise Exception("Could not patch model")
|
||||||
print("LoadICLightUnet: Added LoadICLightUnet patches")
|
print("LoadICLightUnet: Added LoadICLightUnet patches")
|
||||||
#model_clone.model.diffusion_model.in_channels = 8
|
|
||||||
|
|
||||||
return (model_clone, )
|
return (model_clone, )
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user