diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 77344d0..4caac5b 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -1676,7 +1676,8 @@ class LoadAndResizeImage: "repeat": ("INT", { "default": 1, "min": 1, "max": 4096, "step": 1, }), "keep_proportion": ("BOOLEAN", { "default": False }), "divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }), - "mask_channel": (s._color_channels, ), + "mask_channel": (s._color_channels, {"tooltip": "Channel to use for the mask output"}), + "background_color": ("STRING", { "default": "white", "tooltip": "Color to fill the alpha channel with. Enter a comma-separated RGB value. E.g. 255, 255, 255 for white."}), }, } @@ -1685,11 +1686,25 @@ class LoadAndResizeImage: RETURN_NAMES = ("image", "mask", "width", "height","image_path",) FUNCTION = "load_image" - def load_image(self, image, resize, width, height, repeat, keep_proportion, divisible_by, mask_channel): + def load_image(self, image, resize, width, height, repeat, keep_proportion, divisible_by, mask_channel, background_color): + from PIL import ImageColor image_path = folder_paths.get_annotated_filepath(image) import node_helpers img = node_helpers.pillow(Image.open, image_path) + + # Process the background_color + try: + # Try to parse as RGB tuple + bg_color_rgba = tuple(int(x.strip()) for x in background_color.split(',')) + except ValueError: + # If parsing fails, it might be a hex color or named color + if background_color.startswith('#') or background_color.lower() in ImageColor.colormap: + bg_color_rgba = ImageColor.getrgb(background_color) + else: + raise ValueError(f"Invalid background color: {background_color}") + + bg_color_rgba += (255,) # Add alpha channel output_images = [] output_masks = [] @@ -1715,12 +1730,28 @@ class LoadAndResizeImage: else: width, height = W, H - for i in ImageSequence.Iterator(img): - i = node_helpers.pillow(ImageOps.exif_transpose, i) + for frame in ImageSequence.Iterator(img): + frame = node_helpers.pillow(ImageOps.exif_transpose, frame) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - image = i.convert("RGB") + if frame.mode == 'I': + frame = frame.point(lambda i: i * (1 / 255)) + + if frame.mode == 'P': + frame = frame.convert("RGBA") + elif 'A' in frame.getbands(): + frame = frame.convert("RGBA") + + # Extract alpha channel if it exists + if 'A' in frame.getbands(): + alpha_mask = np.array(frame.getchannel('A')).astype(np.float32) / 255.0 + alpha_mask = 1. - torch.from_numpy(alpha_mask) + bg_image = Image.new("RGBA", frame.size, bg_color_rgba) + # Composite the frame onto the background + frame = Image.alpha_composite(bg_image, frame) + else: + alpha_mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") + + image = frame.convert("RGB") if len(output_images) == 0: w = image.size[0] @@ -1733,17 +1764,17 @@ class LoadAndResizeImage: image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] - mask = None + c = mask_channel[0].upper() - if c in i.getbands(): + if c in frame.getbands(): if resize: - i = i.resize((width, height), Image.Resampling.BILINEAR) - mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 + frame = frame.resize((width, height), Image.Resampling.BILINEAR) + mask = np.array(frame.getchannel(c)).astype(np.float32) / 255.0 mask = torch.from_numpy(mask) if c == 'A': - mask = 1. - mask + mask = alpha_mask else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") output_images.append(image) output_masks.append(mask.unsqueeze(0)) @@ -1758,7 +1789,6 @@ class LoadAndResizeImage: output_image = output_image.repeat(repeat, 1, 1, 1) output_mask = output_mask.repeat(repeat, 1, 1) - return (output_image, output_mask, width, height, image_path)