Caption input for AddLabel

This commit is contained in:
Kijai 2024-04-23 16:33:15 +03:00
parent 48d5a18bd4
commit 1c3e7d0df7

View File

@ -3248,9 +3248,11 @@ class AddLabel:
], ],
{ {
"default": 'up' "default": 'up'
}), }),
}, },
"optional":{
"caption": ("STRING", {"default": "", "forceInput": True}),
}
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
@ -3264,7 +3266,7 @@ Fonts are loaded from this folder:
ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts
""" """
def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction): def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction, caption=""):
batch_size = image.shape[0] batch_size = image.shape[0]
width = image.shape[2] width = image.shape[2]
@ -3272,18 +3274,37 @@ ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts
font_path = os.path.join(script_directory, "fonts", "TTNorms-Black.otf") font_path = os.path.join(script_directory, "fonts", "TTNorms-Black.otf")
else: else:
font_path = folder_paths.get_full_path("kjnodes_fonts", font) font_path = folder_paths.get_full_path("kjnodes_fonts", font)
label_image = Image.new("RGB", (width, height), label_color)
draw = ImageDraw.Draw(label_image) if caption == "":
font = ImageFont.truetype(font_path, font_size) label_image = Image.new("RGB", (width, height), label_color)
try: draw = ImageDraw.Draw(label_image)
draw.text((text_x, text_y), text, font=font, fill=font_color, features=['-liga']) font = ImageFont.truetype(font_path, font_size)
except: try:
draw.text((text_x, text_y), text, font=font, fill=font_color) draw.text((text_x, text_y), text, font=font, fill=font_color, features=['-liga'])
except:
draw.text((text_x, text_y), text, font=font, fill=font_color)
label_image = np.array(label_image).astype(np.float32) / 255.0 label_image = np.array(label_image).astype(np.float32) / 255.0
label_image = torch.from_numpy(label_image)[None, :, :, :] label_image = torch.from_numpy(label_image)[None, :, :, :]
# Duplicate the label image for the entire batch # Duplicate the label image for the entire batch
label_batch = label_image.repeat(batch_size, 1, 1, 1) label_batch = label_image.repeat(batch_size, 1, 1, 1)
else:
label_list = []
assert len(caption) == batch_size, "Number of captions does not match number of images"
for cap in caption:
label_image = Image.new("RGB", (width, height), label_color)
draw = ImageDraw.Draw(label_image)
font = ImageFont.truetype(font_path, font_size)
try:
draw.text((text_x, text_y), cap, font=font, fill=font_color, features=['-liga'])
except:
draw.text((text_x, text_y), cap, font=font, fill=font_color)
label_image = np.array(label_image).astype(np.float32) / 255.0
label_image = torch.from_numpy(label_image)
label_list.append(label_image)
label_batch = torch.stack(label_list)
print(label_batch.shape)
if direction == 'down': if direction == 'down':
combined_images = torch.cat((image, label_batch), dim=1) combined_images = torch.cat((image, label_batch), dim=1)