diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index cd858c4..8d201bc 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -3149,9 +3149,11 @@ class ImagePadKJ: "extra_padding": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }), "pad_mode": (["edge", "color"],), "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}), - } - , "optional": { + }, + "optional": { "mask": ("MASK", ), + "target_width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }), + "target_height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }), } } @@ -3161,7 +3163,7 @@ class ImagePadKJ: CATEGORY = "KJNodes/image" DESCRIPTION = "Pad the input image and optionally mask with the specified padding." - def pad(self, image, left, right, top, bottom, extra_padding, color, pad_mode, mask=None): + def pad(self, image, left, right, top, bottom, extra_padding, color, pad_mode, mask=None, target_width=None, target_height=None): B, H, W, C = image.shape # Resize masks to image dimensions if necessary @@ -3175,15 +3177,23 @@ class ImagePadKJ: if len(bg_color) == 1: bg_color = bg_color * 3 # Grayscale to RGB bg_color = torch.tensor(bg_color, dtype=image.dtype, device=image.device) - + # Calculate padding sizes with extra padding - pad_left = left + extra_padding - pad_right = right + extra_padding - pad_top = top + extra_padding - pad_bottom = bottom + extra_padding + if target_width is not None and target_height is not None: + padded_width = target_width + padded_height = target_height + pad_left = (padded_width - W) // 2 + pad_right = padded_width - W - pad_left + pad_top = (padded_height - H) // 2 + pad_bottom = padded_height - H - pad_top + else: + pad_left = left + extra_padding + pad_right = right + extra_padding + pad_top = top + extra_padding + pad_bottom = bottom + extra_padding - padded_width = W + pad_left + pad_right - padded_height = H + pad_top + pad_bottom + padded_width = W + pad_left + pad_right + padded_height = H + pad_top + pad_bottom out_image = torch.zeros((B, padded_height, padded_width, C), dtype=image.dtype, device=image.device) # Fill padded areas