mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-23 20:24:35 +08:00
Add DrawInstanceDiffusionTracking -node
This commit is contained in:
parent
5208980fc9
commit
a976bac701
@ -113,8 +113,10 @@ NODE_CONFIG = {
|
||||
"Superprompt": {"class": Superprompt, "name": "Superprompt"},
|
||||
"GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords},
|
||||
"Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"},
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
"AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking},
|
||||
"DrawInstanceDiffusionTracking": {"class": DrawInstanceDiffusionTracking},
|
||||
}
|
||||
|
||||
def generate_node_mappings(node_config):
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
import json
|
||||
from PIL import Image, ImageDraw
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import numpy as np
|
||||
from ..utility.utility import pil2tensor
|
||||
import folder_paths
|
||||
|
||||
def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
|
||||
import matplotlib
|
||||
@ -690,7 +692,7 @@ class CreateInstanceDiffusionTracking:
|
||||
RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
|
||||
RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
|
||||
FUNCTION = "tracking"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
CATEGORY = "KJNodes/InstanceDiffusion"
|
||||
DESCRIPTION = """
|
||||
Creates tracking data to be used with InstanceDiffusion:
|
||||
https://github.com/logtd/ComfyUI-InstanceDiffusion
|
||||
@ -768,7 +770,7 @@ class AppendInstanceDiffusionTracking:
|
||||
RETURN_TYPES = ("TRACKING", "STRING",)
|
||||
RETURN_NAMES = ("tracking", "prompt",)
|
||||
FUNCTION = "append"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
CATEGORY = "KJNodes/InstanceDiffusion"
|
||||
DESCRIPTION = """
|
||||
Appends tracking data to be used with InstanceDiffusion:
|
||||
https://github.com/logtd/ComfyUI-InstanceDiffusion
|
||||
@ -865,3 +867,82 @@ Interpolates coordinates based on a curve.
|
||||
|
||||
return (interpolated_coords_str, )
|
||||
|
||||
class DrawInstanceDiffusionTracking:
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("image", )
|
||||
FUNCTION = "draw"
|
||||
CATEGORY = "KJNodes/InstanceDiffusion"
|
||||
DESCRIPTION = """
|
||||
Draws the tracking data from
|
||||
CreateInstanceDiffusionTracking -node.
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE", ),
|
||||
"tracking": ("TRACKING", {"forceInput": True}),
|
||||
"box_line_width": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||
"draw_text": ("BOOLEAN", {"default": True}),
|
||||
"font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
|
||||
"font_size": ("INT", {"default": 20}),
|
||||
},
|
||||
}
|
||||
|
||||
def draw(self, image, tracking, box_line_width, draw_text, font, font_size):
|
||||
import matplotlib.cm as cm
|
||||
print(image.shape)
|
||||
|
||||
modified_images = []
|
||||
|
||||
colormap = cm.get_cmap('rainbow', len(tracking))
|
||||
if draw_text:
|
||||
#font = ImageFont.load_default()
|
||||
font = ImageFont.truetype("arial.ttf", font_size)
|
||||
|
||||
# Iterate over each image in the batch
|
||||
for i in range(image.shape[0]):
|
||||
# Extract the current image and convert it to a PIL image
|
||||
# Adjust the tensor to (C, H, W) for ToPILImage
|
||||
current_image = image[i, :, :, :].permute(2, 0, 1)
|
||||
pil_image = transforms.ToPILImage()(current_image)
|
||||
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
|
||||
# Iterate over the bounding boxes for the current image
|
||||
for j, (class_name, class_data) in enumerate(tracking.items()):
|
||||
for class_id, bbox_list in class_data.items():
|
||||
# Check if the current index is within the bounds of the bbox_list
|
||||
if i < len(bbox_list):
|
||||
bbox = bbox_list[i]
|
||||
# Ensure bbox is a list or tuple before unpacking
|
||||
if isinstance(bbox, (list, tuple)):
|
||||
x1, y1, x2, y2, _, _ = bbox
|
||||
# Convert coordinates to integers
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
# Generate a color from the rainbow colormap
|
||||
color = tuple(int(255 * x) for x in colormap(j / len(tracking)))[:3]
|
||||
# Draw the bounding box on the image with the generated color
|
||||
draw.rectangle([x1, y1, x2, y2], outline=color, width=box_line_width)
|
||||
if draw_text:
|
||||
# Draw the class name and ID as text above the box with the generated color
|
||||
text = f"{class_id}.{class_name}"
|
||||
# Calculate the width and height of the text
|
||||
_, _, text_width, text_height = draw.textbbox((0, 0), text=text, font=font)
|
||||
# Position the text above the top-left corner of the box
|
||||
text_position = (x1, y1 - text_height)
|
||||
draw.text(text_position, text, fill=color, font=font)
|
||||
else:
|
||||
print(f"Unexpected data type for bbox: {type(bbox)}")
|
||||
|
||||
# Convert the drawn image back to a torch tensor and adjust back to (H, W, C)
|
||||
modified_image_tensor = transforms.ToTensor()(pil_image).permute(1, 2, 0)
|
||||
modified_images.append(modified_image_tensor)
|
||||
|
||||
# Stack the modified images back into a batch
|
||||
image_tensor_batch = torch.stack(modified_images).cpu().float()
|
||||
|
||||
return image_tensor_batch,
|
||||
Loading…
x
Reference in New Issue
Block a user