Add DrawInstanceDiffusionTracking node

This commit is contained in:
kijai 2024-05-05 12:20:52 +03:00
parent 5208980fc9
commit a976bac701
2 changed files with 86 additions and 3 deletions

View File

@ -113,8 +113,10 @@ NODE_CONFIG = {
"Superprompt": {"class": Superprompt, "name": "Superprompt"}, "Superprompt": {"class": Superprompt, "name": "Superprompt"},
"GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords}, "GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords},
"Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"}, "Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
"AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking}, "AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking},
"DrawInstanceDiffusionTracking": {"class": DrawInstanceDiffusionTracking},
} }
def generate_node_mappings(node_config): def generate_node_mappings(node_config):

View File

@ -1,8 +1,10 @@
import torch import torch
from torchvision import transforms
import json import json
from PIL import Image, ImageDraw from PIL import Image, ImageDraw, ImageFont
import numpy as np import numpy as np
from ..utility.utility import pil2tensor from ..utility.utility import pil2tensor
import folder_paths
def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt): def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
import matplotlib import matplotlib
@ -690,7 +692,7 @@ class CreateInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",) RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",) RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
FUNCTION = "tracking" FUNCTION = "tracking"
CATEGORY = "KJNodes/experimental" CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """ DESCRIPTION = """
Creates tracking data to be used with InstanceDiffusion: Creates tracking data to be used with InstanceDiffusion:
https://github.com/logtd/ComfyUI-InstanceDiffusion https://github.com/logtd/ComfyUI-InstanceDiffusion
@ -768,7 +770,7 @@ class AppendInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING", "STRING",) RETURN_TYPES = ("TRACKING", "STRING",)
RETURN_NAMES = ("tracking", "prompt",) RETURN_NAMES = ("tracking", "prompt",)
FUNCTION = "append" FUNCTION = "append"
CATEGORY = "KJNodes/experimental" CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """ DESCRIPTION = """
Appends tracking data to be used with InstanceDiffusion: Appends tracking data to be used with InstanceDiffusion:
https://github.com/logtd/ComfyUI-InstanceDiffusion https://github.com/logtd/ComfyUI-InstanceDiffusion
@ -865,3 +867,82 @@ Interpolates coordinates based on a curve.
return (interpolated_coords_str, ) return (interpolated_coords_str, )
class DrawInstanceDiffusionTracking:
    """ComfyUI node that overlays tracking bounding boxes on an image batch.

    The ``tracking`` input is read as a nested mapping
    ``{class_name: {class_id: [bbox_for_frame_0, bbox_for_frame_1, ...]}}``
    where every bbox is a list/tuple whose first four entries are
    ``x1, y1, x2, y2``. Any further entries are ignored here.
    """

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image", )
    FUNCTION = "draw"
    CATEGORY = "KJNodes/InstanceDiffusion"
    DESCRIPTION = """
Draws the tracking data from
CreateInstanceDiffusionTracking -node.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs; fonts are listed from the "kjnodes_fonts" folder."""
        return {
            "required": {
                "image": ("IMAGE", ),
                "tracking": ("TRACKING", {"forceInput": True}),
                "box_line_width": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
                "draw_text": ("BOOLEAN", {"default": True}),
                "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
                "font_size": ("INT", {"default": 20}),
            },
        }

    @staticmethod
    def _label_position(x1, y1, text_height):
        """Return the top-left corner for a label sitting just above a box.

        The y coordinate is clamped to 0 so that labels of boxes touching the
        top edge stay on the canvas instead of being drawn at a negative y.
        """
        return (x1, max(y1 - text_height, 0))

    def draw(self, image, tracking, box_line_width, draw_text, font, font_size):
        """Draw every tracked box (and optionally its label) onto each frame.

        Args:
            image: Batched image tensor indexed as (batch, H, W, C); each frame
                is permuted to (C, H, W) before conversion to a PIL image.
            tracking: Nested mapping described in the class docstring.
            box_line_width: Outline width passed to ``ImageDraw.rectangle``.
            draw_text: When true, write ``"{class_id}.{class_name}"`` above
                each box.
            font: File name of a font inside the "kjnodes_fonts" folder.
            font_size: Size handed to ``ImageFont.truetype``.

        Returns:
            A one-element tuple holding the stacked float CPU tensor of the
            annotated frames, laid out as (batch, H, W, C).
        """
        import matplotlib

        num_classes = len(tracking)
        try:
            # Registry API; `cm.get_cmap` no longer exists in recent matplotlib.
            colormap = matplotlib.colormaps["rainbow"].resampled(num_classes)
        except AttributeError:
            # Older matplotlib without `colormaps` / `Colormap.resampled`.
            import matplotlib.cm as cm
            colormap = cm.get_cmap("rainbow", num_classes)

        # One RGB colour per class name, computed once instead of per box.
        class_colors = [
            tuple(int(255 * channel) for channel in colormap(index / num_classes))[:3]
            for index in range(num_classes)
        ]

        label_font = None
        if draw_text:
            # Load the font the user actually selected rather than a
            # hard-coded system font that may not be installed.
            font_path = folder_paths.get_full_path("kjnodes_fonts", font)
            label_font = ImageFont.truetype(font_path, font_size)

        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        modified_images = []
        for frame_index in range(image.shape[0]):
            # ToPILImage expects channels first.
            pil_image = to_pil(image[frame_index].permute(2, 0, 1))
            canvas = ImageDraw.Draw(pil_image)

            for color, (class_name, class_data) in zip(class_colors, tracking.items()):
                for class_id, bbox_list in class_data.items():
                    # This instance has no box for the current frame.
                    if frame_index >= len(bbox_list):
                        continue

                    bbox = bbox_list[frame_index]
                    if not isinstance(bbox, (list, tuple)):
                        print(f"Unexpected data type for bbox: {type(bbox)}")
                        continue

                    # Only the first four entries are coordinates.
                    x1, y1, x2, y2 = (int(value) for value in bbox[:4])
                    canvas.rectangle([x1, y1, x2, y2], outline=color, width=box_line_width)

                    if label_font is None:
                        continue

                    text = f"{class_id}.{class_name}"
                    # Measured from the origin, so the bottom edge is the text height.
                    _, _, _, text_height = canvas.textbbox((0, 0), text=text, font=label_font)
                    canvas.text(
                        self._label_position(x1, y1, text_height),
                        text,
                        fill=color,
                        font=label_font,
                    )

            # Back to a tensor, restoring the (H, W, C) layout.
            modified_images.append(to_tensor(pil_image).permute(1, 2, 0))

        image_tensor_batch = torch.stack(modified_images).cpu().float()
        return image_tensor_batch,