Add DrawInstanceDiffusionTracking -node

This commit is contained in:
kijai 2024-05-05 12:20:52 +03:00
parent 5208980fc9
commit a976bac701
2 changed files with 86 additions and 3 deletions

View File

@ -113,8 +113,10 @@ NODE_CONFIG = {
"Superprompt": {"class": Superprompt, "name": "Superprompt"},
"GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords},
"Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
"AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking},
"DrawInstanceDiffusionTracking": {"class": DrawInstanceDiffusionTracking},
}
def generate_node_mappings(node_config):

View File

@ -1,8 +1,10 @@
import torch
from torchvision import transforms
import json
from PIL import Image, ImageDraw
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from ..utility.utility import pil2tensor
import folder_paths
def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
import matplotlib
@ -690,7 +692,7 @@ class CreateInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
FUNCTION = "tracking"
CATEGORY = "KJNodes/experimental"
CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """
Creates tracking data to be used with InstanceDiffusion:
https://github.com/logtd/ComfyUI-InstanceDiffusion
@ -768,7 +770,7 @@ class AppendInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING", "STRING",)
RETURN_NAMES = ("tracking", "prompt",)
FUNCTION = "append"
CATEGORY = "KJNodes/experimental"
CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """
Appends tracking data to be used with InstanceDiffusion:
https://github.com/logtd/ComfyUI-InstanceDiffusion
@ -865,3 +867,82 @@ Interpolates coordinates based on a curve.
return (interpolated_coords_str, )
class DrawInstanceDiffusionTracking:
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image", )
FUNCTION = "draw"
CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """
Draws the tracking data from
CreateInstanceDiffusionTracking -node.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE", ),
"tracking": ("TRACKING", {"forceInput": True}),
"box_line_width": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
"draw_text": ("BOOLEAN", {"default": True}),
"font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
"font_size": ("INT", {"default": 20}),
},
}
def draw(self, image, tracking, box_line_width, draw_text, font, font_size):
import matplotlib.cm as cm
print(image.shape)
modified_images = []
colormap = cm.get_cmap('rainbow', len(tracking))
if draw_text:
#font = ImageFont.load_default()
font = ImageFont.truetype("arial.ttf", font_size)
# Iterate over each image in the batch
for i in range(image.shape[0]):
# Extract the current image and convert it to a PIL image
# Adjust the tensor to (C, H, W) for ToPILImage
current_image = image[i, :, :, :].permute(2, 0, 1)
pil_image = transforms.ToPILImage()(current_image)
draw = ImageDraw.Draw(pil_image)
# Iterate over the bounding boxes for the current image
for j, (class_name, class_data) in enumerate(tracking.items()):
for class_id, bbox_list in class_data.items():
# Check if the current index is within the bounds of the bbox_list
if i < len(bbox_list):
bbox = bbox_list[i]
# Ensure bbox is a list or tuple before unpacking
if isinstance(bbox, (list, tuple)):
x1, y1, x2, y2, _, _ = bbox
# Convert coordinates to integers
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
# Generate a color from the rainbow colormap
color = tuple(int(255 * x) for x in colormap(j / len(tracking)))[:3]
# Draw the bounding box on the image with the generated color
draw.rectangle([x1, y1, x2, y2], outline=color, width=box_line_width)
if draw_text:
# Draw the class name and ID as text above the box with the generated color
text = f"{class_id}.{class_name}"
# Calculate the width and height of the text
_, _, text_width, text_height = draw.textbbox((0, 0), text=text, font=font)
# Position the text above the top-left corner of the box
text_position = (x1, y1 - text_height)
draw.text(text_position, text, fill=color, font=font)
else:
print(f"Unexpected data type for bbox: {type(bbox)}")
# Convert the drawn image back to a torch tensor and adjust back to (H, W, C)
modified_image_tensor = transforms.ToTensor()(pil_image).permute(1, 2, 0)
modified_images.append(modified_image_tensor)
# Stack the modified images back into a batch
image_tensor_batch = torch.stack(modified_images).cpu().float()
return image_tensor_batch,