mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-03 19:53:35 +08:00
Add DrawInstanceDiffusionTracking -node
This commit is contained in:
parent
5208980fc9
commit
a976bac701
@ -113,8 +113,10 @@ NODE_CONFIG = {
|
|||||||
"Superprompt": {"class": Superprompt, "name": "Superprompt"},
|
"Superprompt": {"class": Superprompt, "name": "Superprompt"},
|
||||||
"GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords},
|
"GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords},
|
||||||
"Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"},
|
"Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"},
|
||||||
|
#instance diffusion
|
||||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||||
"AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking},
|
"AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking},
|
||||||
|
"DrawInstanceDiffusionTracking": {"class": DrawInstanceDiffusionTracking},
|
||||||
}
|
}
|
||||||
|
|
||||||
def generate_node_mappings(node_config):
|
def generate_node_mappings(node_config):
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
import json
|
import json
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..utility.utility import pil2tensor
|
from ..utility.utility import pil2tensor
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
|
def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
|
||||||
import matplotlib
|
import matplotlib
|
||||||
@ -690,7 +692,7 @@ class CreateInstanceDiffusionTracking:
|
|||||||
RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
|
RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
|
||||||
RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
|
RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
|
||||||
FUNCTION = "tracking"
|
FUNCTION = "tracking"
|
||||||
CATEGORY = "KJNodes/experimental"
|
CATEGORY = "KJNodes/InstanceDiffusion"
|
||||||
DESCRIPTION = """
|
DESCRIPTION = """
|
||||||
Creates tracking data to be used with InstanceDiffusion:
|
Creates tracking data to be used with InstanceDiffusion:
|
||||||
https://github.com/logtd/ComfyUI-InstanceDiffusion
|
https://github.com/logtd/ComfyUI-InstanceDiffusion
|
||||||
@ -768,7 +770,7 @@ class AppendInstanceDiffusionTracking:
|
|||||||
RETURN_TYPES = ("TRACKING", "STRING",)
|
RETURN_TYPES = ("TRACKING", "STRING",)
|
||||||
RETURN_NAMES = ("tracking", "prompt",)
|
RETURN_NAMES = ("tracking", "prompt",)
|
||||||
FUNCTION = "append"
|
FUNCTION = "append"
|
||||||
CATEGORY = "KJNodes/experimental"
|
CATEGORY = "KJNodes/InstanceDiffusion"
|
||||||
DESCRIPTION = """
|
DESCRIPTION = """
|
||||||
Appends tracking data to be used with InstanceDiffusion:
|
Appends tracking data to be used with InstanceDiffusion:
|
||||||
https://github.com/logtd/ComfyUI-InstanceDiffusion
|
https://github.com/logtd/ComfyUI-InstanceDiffusion
|
||||||
@ -865,3 +867,82 @@ Interpolates coordinates based on a curve.
|
|||||||
|
|
||||||
return (interpolated_coords_str, )
|
return (interpolated_coords_str, )
|
||||||
|
|
||||||
|
class DrawInstanceDiffusionTracking:
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
RETURN_NAMES = ("image", )
|
||||||
|
FUNCTION = "draw"
|
||||||
|
CATEGORY = "KJNodes/InstanceDiffusion"
|
||||||
|
DESCRIPTION = """
|
||||||
|
Draws the tracking data from
|
||||||
|
CreateInstanceDiffusionTracking -node.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"tracking": ("TRACKING", {"forceInput": True}),
|
||||||
|
"box_line_width": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||||
|
"draw_text": ("BOOLEAN", {"default": True}),
|
||||||
|
"font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
|
||||||
|
"font_size": ("INT", {"default": 20}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def draw(self, image, tracking, box_line_width, draw_text, font, font_size):
|
||||||
|
import matplotlib.cm as cm
|
||||||
|
print(image.shape)
|
||||||
|
|
||||||
|
modified_images = []
|
||||||
|
|
||||||
|
colormap = cm.get_cmap('rainbow', len(tracking))
|
||||||
|
if draw_text:
|
||||||
|
#font = ImageFont.load_default()
|
||||||
|
font = ImageFont.truetype("arial.ttf", font_size)
|
||||||
|
|
||||||
|
# Iterate over each image in the batch
|
||||||
|
for i in range(image.shape[0]):
|
||||||
|
# Extract the current image and convert it to a PIL image
|
||||||
|
# Adjust the tensor to (C, H, W) for ToPILImage
|
||||||
|
current_image = image[i, :, :, :].permute(2, 0, 1)
|
||||||
|
pil_image = transforms.ToPILImage()(current_image)
|
||||||
|
|
||||||
|
draw = ImageDraw.Draw(pil_image)
|
||||||
|
|
||||||
|
# Iterate over the bounding boxes for the current image
|
||||||
|
for j, (class_name, class_data) in enumerate(tracking.items()):
|
||||||
|
for class_id, bbox_list in class_data.items():
|
||||||
|
# Check if the current index is within the bounds of the bbox_list
|
||||||
|
if i < len(bbox_list):
|
||||||
|
bbox = bbox_list[i]
|
||||||
|
# Ensure bbox is a list or tuple before unpacking
|
||||||
|
if isinstance(bbox, (list, tuple)):
|
||||||
|
x1, y1, x2, y2, _, _ = bbox
|
||||||
|
# Convert coordinates to integers
|
||||||
|
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||||
|
# Generate a color from the rainbow colormap
|
||||||
|
color = tuple(int(255 * x) for x in colormap(j / len(tracking)))[:3]
|
||||||
|
# Draw the bounding box on the image with the generated color
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline=color, width=box_line_width)
|
||||||
|
if draw_text:
|
||||||
|
# Draw the class name and ID as text above the box with the generated color
|
||||||
|
text = f"{class_id}.{class_name}"
|
||||||
|
# Calculate the width and height of the text
|
||||||
|
_, _, text_width, text_height = draw.textbbox((0, 0), text=text, font=font)
|
||||||
|
# Position the text above the top-left corner of the box
|
||||||
|
text_position = (x1, y1 - text_height)
|
||||||
|
draw.text(text_position, text, fill=color, font=font)
|
||||||
|
else:
|
||||||
|
print(f"Unexpected data type for bbox: {type(bbox)}")
|
||||||
|
|
||||||
|
# Convert the drawn image back to a torch tensor and adjust back to (H, W, C)
|
||||||
|
modified_image_tensor = transforms.ToTensor()(pil_image).permute(1, 2, 0)
|
||||||
|
modified_images.append(modified_image_tensor)
|
||||||
|
|
||||||
|
# Stack the modified images back into a batch
|
||||||
|
image_tensor_batch = torch.stack(modified_images).cpu().float()
|
||||||
|
|
||||||
|
return image_tensor_batch,
|
||||||
Loading…
x
Reference in New Issue
Block a user