From 095c8d4b526ba3c1f12fd9dd1d7f3540c6a11358 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 16 Feb 2025 19:12:42 +0200
Subject: [PATCH] Update image_nodes.py

---
 nodes/image_nodes.py | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py
index f93c81c..ae52eae 100644
--- a/nodes/image_nodes.py
+++ b/nodes/image_nodes.py
@@ -1098,13 +1098,14 @@ class ImagePrepForICLora:
     def INPUT_TYPES(s):
         return {
             "required": {
-                "image": ("IMAGE",),
+                "reference_image": ("IMAGE",),
                 "output_width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                 "output_height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                 "border_width": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
            },
             "optional": {
-                "mask": ("MASK",),
+                "latent_image": ("IMAGE",),
+                "reference_mask": ("MASK",),
            }
        }
 
@@ -1113,17 +1114,19 @@ class ImagePrepForICLora:
 
     CATEGORY = "image"
 
-    def expand_image(self, image, output_width, output_height, border_width, mask=None):
-        if mask is not None:
-            if torch.allclose(mask, torch.zeros_like(mask)):
+    def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None):
+
+        if reference_mask is not None:
+            if torch.allclose(reference_mask, torch.zeros_like(reference_mask)):
                 print("Warning: The incoming mask is fully black. Handling it as None.")
-                mask = None
+                reference_mask = None
+        image = reference_image
         B, H, W, C = image.size()
 
         # Handle mask
-        if mask is not None:
+        if reference_mask is not None:
             resized_mask = torch.nn.functional.interpolate(
-                mask.unsqueeze(1),
+                reference_mask.unsqueeze(1),
                 size=(image.shape[1], image.shape[2]),
                 mode='nearest'
             ).squeeze(1)
@@ -1137,15 +1140,19 @@ class ImagePrepForICLora:
         resized_image = common_upscale(image.movedim(-1,1), new_width, output_height, "lanczos", "disabled").movedim(1,-1)
 
         # Create padded image
-        empty_image = torch.zeros((B, output_height, output_width, C), device=image.device)
+        if latent_image is None:
+            pad_image = torch.zeros((B, output_height, output_width, C), device=image.device)
+        else:
+            resized_latent_image = common_upscale(latent_image.movedim(-1,1), output_width, output_height, "lanczos", "disabled").movedim(1,-1)
+            pad_image = resized_latent_image
 
         if border_width > 0:
             border = torch.zeros((B, output_height, border_width, C), device=image.device)
-            padded_image = torch.cat((resized_image, border, empty_image), dim=2)
+            padded_image = torch.cat((resized_image, border, pad_image), dim=2)
             padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
             padded_mask[:, :, :new_width + border_width] = 0
         else:
-            padded_image = torch.cat((resized_image, empty_image), dim=2)
+            padded_image = torch.cat((resized_image, pad_image), dim=2)
             padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
             padded_mask[:, :, :new_width] = 0
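
For context, the patched ImagePrepForICLora node scales the reference image to the output height and places it on the left, fills the right half either with a black canvas or with the new optional latent_image input, and returns a mask that is 1 over the right-hand region to be generated and 0 over the reference side. The snippet below is a minimal standalone sketch of that behavior in plain PyTorch, not the node's actual code: the function name prep_for_ic_lora is made up, F.interpolate with bilinear resizing stands in for the repo's common_upscale helper, and the optional reference_mask handling is omitted.

```python
import torch
import torch.nn.functional as F

def prep_for_ic_lora(reference_image, output_width, output_height,
                     border_width=0, latent_image=None):
    """Hypothetical sketch of the node's layout logic; tensors are (B, H, W, C) floats."""
    B, H, W, C = reference_image.shape

    # Scale the reference to the output height, preserving its aspect ratio
    new_width = int(W * (output_height / H))
    resized_ref = F.interpolate(reference_image.movedim(-1, 1),
                                size=(output_height, new_width),
                                mode="bilinear", align_corners=False).movedim(1, -1)

    # Right half: a black canvas, or the resized latent_image if one is supplied
    if latent_image is None:
        pad_image = torch.zeros((B, output_height, output_width, C),
                                device=reference_image.device)
    else:
        pad_image = F.interpolate(latent_image.movedim(-1, 1),
                                  size=(output_height, output_width),
                                  mode="bilinear", align_corners=False).movedim(1, -1)

    # Optional black border strip between the two halves
    pieces = [resized_ref]
    if border_width > 0:
        pieces.append(torch.zeros((B, output_height, border_width, C),
                                  device=reference_image.device))
    pieces.append(pad_image)
    padded_image = torch.cat(pieces, dim=2)  # concatenate along width

    # Mask: 1 over the area to be generated, 0 over the reference side
    padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]),
                             device=reference_image.device)
    padded_mask[:, :, :new_width + border_width] = 0
    return padded_image, padded_mask

# Example: a 512x512 reference next to a 1024x1024 target canvas with a 64 px border
ref = torch.rand(1, 512, 512, 3)
img, mask = prep_for_ic_lora(ref, output_width=1024, output_height=1024, border_width=64)
print(img.shape, mask.shape)  # torch.Size([1, 1024, 2112, 3]) torch.Size([1, 1024, 2112])
```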