diff --git a/__init__.py b/__init__.py index af75c4b..70eb417 100644 --- a/__init__.py +++ b/__init__.py @@ -113,8 +113,10 @@ NODE_CONFIG = { "Superprompt": {"class": Superprompt, "name": "Superprompt"}, "GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords}, "Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"}, + #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, "AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking}, + "DrawInstanceDiffusionTracking": {"class": DrawInstanceDiffusionTracking}, } def generate_node_mappings(node_config): diff --git a/nodes/curve_nodes.py b/nodes/curve_nodes.py index ab90589..8c8e9b2 100644 --- a/nodes/curve_nodes.py +++ b/nodes/curve_nodes.py @@ -1,8 +1,10 @@ import torch +from torchvision import transforms import json -from PIL import Image, ImageDraw +from PIL import Image, ImageDraw, ImageFont import numpy as np from ..utility.utility import pil2tensor +import folder_paths def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt): import matplotlib @@ -690,7 +692,7 @@ class CreateInstanceDiffusionTracking: RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",) RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",) FUNCTION = "tracking" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/InstanceDiffusion" DESCRIPTION = """ Creates tracking data to be used with InstanceDiffusion: https://github.com/logtd/ComfyUI-InstanceDiffusion @@ -768,7 +770,7 @@ class AppendInstanceDiffusionTracking: RETURN_TYPES = ("TRACKING", "STRING",) RETURN_NAMES = ("tracking", "prompt",) FUNCTION = "append" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/InstanceDiffusion" DESCRIPTION = """ Appends tracking data to be used with InstanceDiffusion: https://github.com/logtd/ComfyUI-InstanceDiffusion @@ -865,3 +867,82 @@ Interpolates coordinates based on a curve. return (interpolated_coords_str, ) +class DrawInstanceDiffusionTracking: + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("image", ) + FUNCTION = "draw" + CATEGORY = "KJNodes/InstanceDiffusion" + DESCRIPTION = """ +Draws the tracking data from +CreateInstanceDiffusionTracking -node. + +""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE", ), + "tracking": ("TRACKING", {"forceInput": True}), + "box_line_width": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + "draw_text": ("BOOLEAN", {"default": True}), + "font": (folder_paths.get_filename_list("kjnodes_fonts"), ), + "font_size": ("INT", {"default": 20}), + }, + } + + def draw(self, image, tracking, box_line_width, draw_text, font, font_size): + import matplotlib.cm as cm + print(image.shape) + + modified_images = [] + + colormap = cm.get_cmap('rainbow', len(tracking)) + if draw_text: + #font = ImageFont.load_default() + font = ImageFont.truetype("arial.ttf", font_size) + + # Iterate over each image in the batch + for i in range(image.shape[0]): + # Extract the current image and convert it to a PIL image + # Adjust the tensor to (C, H, W) for ToPILImage + current_image = image[i, :, :, :].permute(2, 0, 1) + pil_image = transforms.ToPILImage()(current_image) + + draw = ImageDraw.Draw(pil_image) + + # Iterate over the bounding boxes for the current image + for j, (class_name, class_data) in enumerate(tracking.items()): + for class_id, bbox_list in class_data.items(): + # Check if the current index is within the bounds of the bbox_list + if i < len(bbox_list): + bbox = bbox_list[i] + # Ensure bbox is a list or tuple before unpacking + if isinstance(bbox, (list, tuple)): + x1, y1, x2, y2, _, _ = bbox + # Convert coordinates to integers + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + # Generate a color from the rainbow colormap + color = tuple(int(255 * x) for x in colormap(j / len(tracking)))[:3] + # Draw the bounding box on the image with the generated color + draw.rectangle([x1, y1, x2, y2], outline=color, width=box_line_width) + if draw_text: + # Draw the class name and ID as text above the box with the generated color + text = f"{class_id}.{class_name}" + # Calculate the width and height of the text + _, _, text_width, text_height = draw.textbbox((0, 0), text=text, font=font) + # Position the text above the top-left corner of the box + text_position = (x1, y1 - text_height) + draw.text(text_position, text, fill=color, font=font) + else: + print(f"Unexpected data type for bbox: {type(bbox)}") + + # Convert the drawn image back to a torch tensor and adjust back to (H, W, C) + modified_image_tensor = transforms.ToTensor()(pil_image).permute(1, 2, 0) + modified_images.append(modified_image_tensor) + + # Stack the modified images back into a batch + image_tensor_batch = torch.stack(modified_images).cpu().float() + + return image_tensor_batch, \ No newline at end of file