diff --git a/nodes.py b/nodes.py index 5e0200d..00110ad 100644 --- a/nodes.py +++ b/nodes.py @@ -3248,9 +3248,11 @@ class AddLabel: ], { "default": 'up' - }), }, + "optional":{ + "caption": ("STRING", {"default": "", "forceInput": True}), + } } RETURN_TYPES = ("IMAGE",) @@ -3264,7 +3266,7 @@ Fonts are loaded from this folder: ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts """ - def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction): + def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction, caption=""): batch_size = image.shape[0] width = image.shape[2] @@ -3272,18 +3274,37 @@ ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts font_path = os.path.join(script_directory, "fonts", "TTNorms-Black.otf") else: font_path = folder_paths.get_full_path("kjnodes_fonts", font) - label_image = Image.new("RGB", (width, height), label_color) - draw = ImageDraw.Draw(label_image) - font = ImageFont.truetype(font_path, font_size) - try: - draw.text((text_x, text_y), text, font=font, fill=font_color, features=['-liga']) - except: - draw.text((text_x, text_y), text, font=font, fill=font_color) + + if caption == "": + label_image = Image.new("RGB", (width, height), label_color) + draw = ImageDraw.Draw(label_image) + font = ImageFont.truetype(font_path, font_size) + try: + draw.text((text_x, text_y), text, font=font, fill=font_color, features=['-liga']) + except: + draw.text((text_x, text_y), text, font=font, fill=font_color) - label_image = np.array(label_image).astype(np.float32) / 255.0 - label_image = torch.from_numpy(label_image)[None, :, :, :] - # Duplicate the label image for the entire batch - label_batch = label_image.repeat(batch_size, 1, 1, 1) + label_image = np.array(label_image).astype(np.float32) / 255.0 + label_image = torch.from_numpy(label_image)[None, :, :, :] + # Duplicate the label image for the entire batch + label_batch = label_image.repeat(batch_size, 1, 1, 1) + else: + label_list = [] + assert len(caption) == batch_size, "Number of captions does not match number of images" + for cap in caption: + label_image = Image.new("RGB", (width, height), label_color) + draw = ImageDraw.Draw(label_image) + font = ImageFont.truetype(font_path, font_size) + try: + draw.text((text_x, text_y), cap, font=font, fill=font_color, features=['-liga']) + except: + draw.text((text_x, text_y), cap, font=font, fill=font_color) + + label_image = np.array(label_image).astype(np.float32) / 255.0 + label_image = torch.from_numpy(label_image) + label_list.append(label_image) + label_batch = torch.stack(label_list) + print(label_batch.shape) if direction == 'down': combined_images = torch.cat((image, label_batch), dim=1)