mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 05:15:05 +08:00
Caption input for AddLabel
This commit is contained in:
parent
48d5a18bd4
commit
1c3e7d0df7
47
nodes.py
47
nodes.py
@ -3248,9 +3248,11 @@ class AddLabel:
|
||||
],
|
||||
{
|
||||
"default": 'up'
|
||||
|
||||
}),
|
||||
},
|
||||
"optional":{
|
||||
"caption": ("STRING", {"default": "", "forceInput": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
@ -3264,7 +3266,7 @@ Fonts are loaded from this folder:
|
||||
ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts
|
||||
"""
|
||||
|
||||
def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction):
|
||||
def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction, caption=""):
|
||||
batch_size = image.shape[0]
|
||||
width = image.shape[2]
|
||||
|
||||
@ -3272,18 +3274,37 @@ ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts
|
||||
font_path = os.path.join(script_directory, "fonts", "TTNorms-Black.otf")
|
||||
else:
|
||||
font_path = folder_paths.get_full_path("kjnodes_fonts", font)
|
||||
label_image = Image.new("RGB", (width, height), label_color)
|
||||
draw = ImageDraw.Draw(label_image)
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
try:
|
||||
draw.text((text_x, text_y), text, font=font, fill=font_color, features=['-liga'])
|
||||
except:
|
||||
draw.text((text_x, text_y), text, font=font, fill=font_color)
|
||||
|
||||
if caption == "":
|
||||
label_image = Image.new("RGB", (width, height), label_color)
|
||||
draw = ImageDraw.Draw(label_image)
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
try:
|
||||
draw.text((text_x, text_y), text, font=font, fill=font_color, features=['-liga'])
|
||||
except:
|
||||
draw.text((text_x, text_y), text, font=font, fill=font_color)
|
||||
|
||||
label_image = np.array(label_image).astype(np.float32) / 255.0
|
||||
label_image = torch.from_numpy(label_image)[None, :, :, :]
|
||||
# Duplicate the label image for the entire batch
|
||||
label_batch = label_image.repeat(batch_size, 1, 1, 1)
|
||||
label_image = np.array(label_image).astype(np.float32) / 255.0
|
||||
label_image = torch.from_numpy(label_image)[None, :, :, :]
|
||||
# Duplicate the label image for the entire batch
|
||||
label_batch = label_image.repeat(batch_size, 1, 1, 1)
|
||||
else:
|
||||
label_list = []
|
||||
assert len(caption) == batch_size, "Number of captions does not match number of images"
|
||||
for cap in caption:
|
||||
label_image = Image.new("RGB", (width, height), label_color)
|
||||
draw = ImageDraw.Draw(label_image)
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
try:
|
||||
draw.text((text_x, text_y), cap, font=font, fill=font_color, features=['-liga'])
|
||||
except:
|
||||
draw.text((text_x, text_y), cap, font=font, fill=font_color)
|
||||
|
||||
label_image = np.array(label_image).astype(np.float32) / 255.0
|
||||
label_image = torch.from_numpy(label_image)
|
||||
label_list.append(label_image)
|
||||
label_batch = torch.stack(label_list)
|
||||
print(label_batch.shape)
|
||||
|
||||
if direction == 'down':
|
||||
combined_images = torch.cat((image, label_batch), dim=1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user