Merge branch 'main' into develop

This commit is contained in:
kijai 2024-05-05 23:57:57 +03:00
commit 62a704fc14
28 changed files with 6912 additions and 5179 deletions

3
.gitignore vendored

@@ -4,4 +4,5 @@ __pycache__
*.ckpt
*.pth
types
models
models
jsconfig.json

View File

@@ -2,21 +2,35 @@
Various quality-of-life and masking-related nodes and scripts for ComfyUI, made by combining functionality of existing nodes.
I know I'm bad at documentation, especially for this project, which has grown from random practice nodes to... too many lines in one file.
I have, however, started to add descriptions to the nodes themselves; there's a small ? you can click for info on what the node does.
This is still a work in progress, like everything else.
# Installation
1. Clone this repo into the `custom_nodes` folder.
2. Install dependencies: pip install -r requirements.txt
2. Install dependencies: `pip install -r requirements.txt`
or, if you use the portable install, run this in the ComfyUI_windows_portable folder:
`python_embeded\python.exe -m pip install -r ComfyUI\custom_nodes\ComfyUI-KJNodes\requirements.txt`
## Javascript
### browserstatus.js
Sets the favicon to a green circle when not processing anything, sets it to red when processing, and shows the progress percentage and the length of your queue. Might clash with other scripts that affect the page title; delete this file to disable (until I figure out how to add options).
Sets the favicon to a green circle when not processing anything, sets it to red when processing, and shows the progress percentage and the length of your queue.
Default off; needs to be enabled from the options; overrides the Custom-Scripts favicon when enabled.
## Nodes:
### Set/Get
JavaScript nodes to set and get constants to reduce unnecessary connection lines. They take in and return anything; these are purely visual nodes.
Could still be buggy, especially when loading a workflow with missing nodes; use with caution.
The right-click menu of these nodes now has an option to visualize the paths, as well as an option to jump to the corresponding node on the other end.
**Known limitations**:
- Will not work with any node that dynamically sets its outputs, such as a reroute or another Set/Get node
- Will not work when directly connected to a bypassed node
- Other possible conflicts with JavaScript-based nodes.
### ColorToMask
@@ -34,14 +48,6 @@ Mask and combine two sets of conditions, saves space.
Grows or shrinks (with negative values) a mask, with an option to invert the input; returns the mask and the inverted mask. Additionally blurs the mask; this is a slow operation, especially with big batches.
### CreateFadeMask
This node creates a batch of single-color images by interpolating between white/black levels. Useful to control mask strengths or QR code controlnet input weight when combined with the MaskComposite node.
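A minimal sketch of the idea (not the node's actual code; the helper name and the plain linear ramp are just for illustration):

```python
import torch

def fade_mask_batch(frames: int, start_level: float = 1.0, end_level: float = 0.0,
                    width: int = 512, height: int = 512) -> torch.Tensor:
    # one gray level per frame, broadcast to a full-size mask
    levels = torch.linspace(start_level, end_level, frames)
    return levels.view(frames, 1, 1).expand(frames, height, width).clone()

masks = fade_mask_batch(frames=16)  # shape (16, 512, 512), fading from white to black
print(masks[0].mean().item(), masks[-1].mean().item())  # 1.0 ... 0.0
```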
### CreateAudioMask
Work in progress; currently creates a sphere whose size is synced with the audio input.
### RoundMask
![image](https://github.com/kijai/ComfyUI-KJNodes/assets/40791699/52c85202-f74e-4b96-9dac-c8bda5ddcc40)


@@ -1,4 +1,136 @@
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
from .nodes.nodes import *
from .nodes.curve_nodes import *
from .nodes.batchcrop_nodes import *
from .nodes.audioscheduler_nodes import *
from .nodes.image_nodes import *
from .nodes.intrinsic_lora_nodes import *
from .nodes.mask_nodes import *
NODE_CONFIG = {
#constants
"INTConstant": {"class": INTConstant, "name": "INT Constant"},
"FloatConstant": {"class": FloatConstant, "name": "Float Constant"},
"StringConstant": {"class": StringConstant, "name": "String Constant"},
"StringConstantMultiline": {"class": StringConstantMultiline, "name": "String Constant Multiline"},
#conditioning
"ConditioningMultiCombine": {"class": ConditioningMultiCombine, "name": "Conditioning Multi Combine"},
"ConditioningSetMaskAndCombine": {"class": ConditioningSetMaskAndCombine, "name": "ConditioningSetMaskAndCombine"},
"ConditioningSetMaskAndCombine3": {"class": ConditioningSetMaskAndCombine3, "name": "ConditioningSetMaskAndCombine3"},
"ConditioningSetMaskAndCombine4": {"class": ConditioningSetMaskAndCombine4, "name": "ConditioningSetMaskAndCombine4"},
"ConditioningSetMaskAndCombine5": {"class": ConditioningSetMaskAndCombine5, "name": "ConditioningSetMaskAndCombine5"},
"CondPassThrough": {"class": CondPassThrough},
#masking
"BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
"ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
"CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
"CreateTextMask": {"class": CreateTextMask, "name": "Create Text Mask"},
"CreateAudioMask": {"class": CreateAudioMask, "name": "Create Audio Mask"},
"CreateFadeMask": {"class": CreateFadeMask, "name": "Create Fade Mask"},
"CreateFadeMaskAdvanced": {"class": CreateFadeMaskAdvanced, "name": "Create Fade Mask Advanced"},
"CreateFluidMask": {"class": CreateFluidMask, "name": "Create Fluid Mask"},
"CreateShapeMask": {"class": CreateShapeMask, "name": "Create Shape Mask"},
"CreateVoronoiMask": {"class": CreateVoronoiMask, "name": "Create Voronoi Mask"},
"CreateMagicMask": {"class": CreateMagicMask, "name": "Create Magic Mask"},
"GetMaskSizeAndCount": {"class": GetMaskSizeAndCount, "name": "Get Mask Size & Count"},
"GrowMaskWithBlur": {"class": GrowMaskWithBlur, "name": "Grow Mask With Blur"},
"MaskBatchMulti": {"class": MaskBatchMulti, "name": "Mask Batch Multi"},
"OffsetMask": {"class": OffsetMask, "name": "Offset Mask"},
"RemapMaskRange": {"class": RemapMaskRange, "name": "Remap Mask Range"},
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
"RoundMask": {"class": RoundMask, "name": "Round Mask"},
#images
"AddLabel": {"class": AddLabel, "name": "Add Label"},
"ColorMatch": {"class": ColorMatch, "name": "Color Match"},
"CrossFadeImages": {"class": CrossFadeImages, "name": "Cross Fade Images"},
"GetImageRangeFromBatch": {"class": GetImageRangeFromBatch, "name": "Get Image Range From Batch"},
"GetImageSizeAndCount": {"class": GetImageSizeAndCount, "name": "Get Image Size & Count"},
"ImageAndMaskPreview": {"class": ImageAndMaskPreview},
"ImageBatchMulti": {"class": ImageBatchMulti, "name": "Image Batch Multi"},
"ImageBatchRepeatInterleaving": {"class": ImageBatchRepeatInterleaving},
"ImageBatchTestPattern": {"class": ImageBatchTestPattern, "name": "Image Batch Test Pattern"},
"ImageConcanate": {"class": ImageConcanate, "name": "Image Concatenate"},
"ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"},
"ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"},
"ImageGridComposite3x3": {"class": ImageGridComposite3x3, "name": "Image Grid Composite 3x3"},
"ImageNormalize_Neg1_To_1": {"class": ImageNormalize_Neg1_To_1, "name": "Image Normalize -1 to 1"},
"ImagePass": {"class": ImagePass},
"ImagePadForOutpaintMasked": {"class": ImagePadForOutpaintMasked, "name": "Image Pad For Outpaint Masked"},
"ImageUpscaleWithModelBatched": {"class": ImageUpscaleWithModelBatched, "name": "Image Upscale With Model Batched"},
"InsertImagesToBatchIndexed": {"class": InsertImagesToBatchIndexed, "name": "Insert Images To Batch Indexed"},
"MergeImageChannels": {"class": MergeImageChannels, "name": "Merge Image Channels"},
"RemapImageRange": {"class": RemapImageRange, "name": "Remap Image Range"},
"ReverseImageBatch": {"class": ReverseImageBatch, "name": "Reverse Image Batch"},
"ReplaceImagesInBatch": {"class": ReplaceImagesInBatch, "name": "Replace Images In Batch"},
"SaveImageWithAlpha": {"class": SaveImageWithAlpha, "name": "Save Image With Alpha"},
"SplitImageChannels": {"class": SplitImageChannels, "name": "Split Image Channels"},
#batch cropping
"BatchCropFromMask": {"class": BatchCropFromMask, "name": "Batch Crop From Mask"},
"BatchCropFromMaskAdvanced": {"class": BatchCropFromMaskAdvanced, "name": "Batch Crop From Mask Advanced"},
"FilterZeroMasksAndCorrespondingImages": {"class": FilterZeroMasksAndCorrespondingImages},
"InsertImageBatchByIndexes": {"class": InsertImageBatchByIndexes, "name": "Insert Image Batch By Indexes"},
"BatchUncrop": {"class": BatchUncrop, "name": "Batch Uncrop"},
"BatchUncropAdvanced": {"class": BatchUncropAdvanced, "name": "Batch Uncrop Advanced"},
"SplitBboxes": {"class": SplitBboxes, "name": "Split Bboxes"},
"BboxToInt": {"class": BboxToInt, "name": "Bbox To Int"},
"BboxVisualize": {"class": BboxVisualize, "name": "Bbox Visualize"},
#noise
"GenerateNoise": {"class": GenerateNoise, "name": "Generate Noise"},
"FlipSigmasAdjusted": {"class": FlipSigmasAdjusted, "name": "Flip Sigmas Adjusted"},
"InjectNoiseToLatent": {"class": InjectNoiseToLatent, "name": "Inject Noise To Latent"},
"CustomSigmas": {"class": CustomSigmas, "name": "Custom Sigmas"},
#utility
"WidgetToString": {"class": WidgetToString, "name": "Widget To String"},
"DummyLatentOut": {"class": DummyLatentOut, "name": "Dummy Latent Out"},
"GetLatentsFromBatchIndexed": {"class": GetLatentsFromBatchIndexed, "name": "Get Latents From Batch Indexed"},
"ScaleBatchPromptSchedule": {"class": ScaleBatchPromptSchedule, "name": "Scale Batch Prompt Schedule"},
"CameraPoseVisualizer": {"class": CameraPoseVisualizer, "name": "Camera Pose Visualizer"},
"JoinStrings": {"class": JoinStrings, "name": "Join Strings"},
"JoinStringMulti": {"class": JoinStringMulti, "name": "Join String Multi"},
"Sleep": {"class": Sleep, "name": "Sleep"},
"VRAM_Debug": {"class": VRAM_Debug, "name": "VRAM Debug"},
"SomethingToString": {"class": SomethingToString, "name": "Something To String"},
"EmptyLatentImagePresets": {"class": EmptyLatentImagePresets, "name": "Empty Latent Image Presets"},
#audioscheduler stuff
"NormalizedAmplitudeToMask": {"class": NormalizedAmplitudeToMask},
"NormalizedAmplitudeToFloatList": {"class": NormalizedAmplitudeToFloatList},
"OffsetMaskByNormalizedAmplitude": {"class": OffsetMaskByNormalizedAmplitude},
"ImageTransformByNormalizedAmplitude": {"class": ImageTransformByNormalizedAmplitude},
#curve nodes
"SplineEditor": {"class": SplineEditor, "name": "Spline Editor"},
"CreateShapeMaskOnPath": {"class": CreateShapeMaskOnPath, "name": "Create Shape Mask On Path"},
"WeightScheduleExtend": {"class": WeightScheduleExtend, "name": "Weight Schedule Extend"},
"MaskOrImageToWeight": {"class": MaskOrImageToWeight, "name": "Mask Or Image To Weight"},
"WeightScheduleConvert": {"class": WeightScheduleConvert, "name": "Weight Schedule Convert"},
"FloatToMask": {"class": FloatToMask, "name": "Float To Mask"},
"FloatToSigmas": {"class": FloatToSigmas, "name": "Float To Sigmas"},
"PlotCoordinates": {"class": PlotCoordinates, "name": "Plot Coordinates"},
"InterpolateCoords": {"class": InterpolateCoords, "name": "Interpolate Coords"},
#experimental
"StabilityAPI_SD3": {"class": StabilityAPI_SD3, "name": "Stability API SD3"},
"SoundReactive": {"class": SoundReactive, "name": "Sound Reactive"},
"StableZero123_BatchSchedule": {"class": StableZero123_BatchSchedule, "name": "Stable Zero123 Batch Schedule"},
"SV3D_BatchSchedule": {"class": SV3D_BatchSchedule, "name": "SV3D Batch Schedule"},
"LoadResAdapterNormalization": {"class": LoadResAdapterNormalization},
"Superprompt": {"class": Superprompt, "name": "Superprompt"},
"GLIGENTextBoxApplyBatchCoords": {"class": GLIGENTextBoxApplyBatchCoords},
"Intrinsic_lora_sampling": {"class": Intrinsic_lora_sampling, "name": "Intrinsic Lora Sampling"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
"AppendInstanceDiffusionTracking": {"class": AppendInstanceDiffusionTracking},
"DrawInstanceDiffusionTracking": {"class": DrawInstanceDiffusionTracking},
}
def generate_node_mappings(node_config):
node_class_mappings = {}
node_display_name_mappings = {}
for node_name, node_info in node_config.items():
node_class_mappings[node_name] = node_info["class"]
node_display_name_mappings[node_name] = node_info.get("name", node_info["class"].__name__)
return node_class_mappings, node_display_name_mappings
NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
WEB_DIRECTORY = "./web"
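The pattern above keeps registration in one place: each entry maps a node's internal name to its class and an optional display name, and generate_node_mappings builds the two dictionaries ComfyUI expects. A standalone sketch of the same idea (the ExampleNode class is hypothetical, used only for illustration):

```python
class ExampleNode:
    pass

NODE_CONFIG = {
    "ExampleNode": {"class": ExampleNode, "name": "Example Node"},
    "ExampleNodeNoName": {"class": ExampleNode},  # no "name": falls back to the class __name__
}

def generate_node_mappings(node_config):
    node_class_mappings = {}
    node_display_name_mappings = {}
    for node_name, node_info in node_config.items():
        node_class_mappings[node_name] = node_info["class"]
        node_display_name_mappings[node_name] = node_info.get("name", node_info["class"].__name__)
    return node_class_mappings, node_display_name_mappings

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)
print(NODE_DISPLAY_NAME_MAPPINGS)  # {'ExampleNode': 'Example Node', 'ExampleNodeNoName': 'ExampleNode'}
```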

BIN
audio.wav

Binary file not shown.

Binary file not shown. Before: 1.0 KiB

Binary file not shown. Before: 1006 B

File diff suppressed because one or more lines are too long

5038
nodes.py

File diff suppressed because it is too large


@@ -0,0 +1,251 @@
# to be used with https://github.com/a1lazydog/ComfyUI-AudioScheduler
import torch
from torchvision.transforms import functional as TF
from PIL import Image, ImageDraw
import numpy as np
from ..utility.utility import pil2tensor
from nodes import MAX_RESOLUTION
class NormalizedAmplitudeToMask:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
"width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_offset": ("INT", {"default": 0,"min": -255, "max": 255, "step": 1}),
"location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
"location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
"size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
"shape": (
[
'none',
'circle',
'square',
'triangle',
],
{
"default": 'none'
}),
"color": (
[
'white',
'amplitude',
],
{
"default": 'amplitude'
}),
},}
CATEGORY = "KJNodes/audio"
RETURN_TYPES = ("MASK",)
FUNCTION = "convert"
DESCRIPTION = """
Works as a bridge to the AudioScheduler nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Creates masks based on the normalized amplitude.
"""
def convert(self, normalized_amp, width, height, frame_offset, shape, location_x, location_y, size, color):
# Ensure normalized_amp is an array and within the range [0, 1]
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
# Offset the amplitude values by rolling the array
normalized_amp = np.roll(normalized_amp, frame_offset)
# Initialize an empty list to hold the image tensors
out = []
# Iterate over each amplitude value to create an image
for amp in normalized_amp:
# Scale the amplitude value to cover the full range of grayscale values
if color == 'amplitude':
grayscale_value = int(amp * 255)
elif color == 'white':
grayscale_value = 255
# Convert the grayscale value to an RGB format
gray_color = (grayscale_value, grayscale_value, grayscale_value)
finalsize = size * amp
if shape == 'none':
shapeimage = Image.new("RGB", (width, height), gray_color)
else:
shapeimage = Image.new("RGB", (width, height), "black")
draw = ImageDraw.Draw(shapeimage)
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - finalsize, location_y - finalsize)
right_down_point = (location_x + finalsize,location_y + finalsize)
two_points = [left_up_point, right_down_point]
if shape == 'circle':
draw.ellipse(two_points, fill=gray_color)
elif shape == 'square':
draw.rectangle(two_points, fill=gray_color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - finalsize, location_y + finalsize) # bottom left
right_down_point = (location_x + finalsize, location_y + finalsize) # bottom right
top_point = (location_x, location_y) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=gray_color)
shapeimage = pil2tensor(shapeimage)
mask = shapeimage[:, :, :, 0]
out.append(mask)
return (torch.cat(out, dim=0),)
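A rough standalone sketch of the per-frame behaviour (illustrative only, not the node itself): one grayscale mask per amplitude value, with the circle's brightness and radius following the amplitude.

```python
import numpy as np
import torch
from PIL import Image, ImageDraw

def amplitude_circle_masks(normalized_amp, width=256, height=256, size=64):
    normalized_amp = np.clip(np.asarray(normalized_amp, dtype=np.float32), 0.0, 1.0)
    frames = []
    for amp in normalized_amp:
        gray = int(amp * 255)          # brightness follows amplitude
        radius = size * amp            # radius follows amplitude
        img = Image.new("L", (width, height), 0)
        draw = ImageDraw.Draw(img)
        cx, cy = width // 2, height // 2
        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=gray)
        frames.append(torch.from_numpy(np.array(img, dtype=np.float32) / 255.0))
    return torch.stack(frames)  # (frames, H, W), same layout as a ComfyUI MASK batch

print(amplitude_circle_masks([0.1, 0.5, 1.0]).shape)  # torch.Size([3, 256, 256])
```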
class NormalizedAmplitudeToFloatList:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
},}
CATEGORY = "KJNodes/audio"
RETURN_TYPES = ("FLOAT",)
FUNCTION = "convert"
DESCRIPTION = """
Works as a bridge to the AudioScheduler nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Creates a list of floats from the normalized amplitude.
"""
def convert(self, normalized_amp):
# Ensure normalized_amp is an array and within the range [0, 1]
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
return (normalized_amp.tolist(),)
class OffsetMaskByNormalizedAmplitude:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
"mask": ("MASK",),
"x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"rotate": ("BOOLEAN", { "default": False }),
"angle_multiplier": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
}
}
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("mask",)
FUNCTION = "offset"
CATEGORY = "KJNodes/audio"
DESCRIPTION = """
Works as a bridge to the AudioScheduler nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Offsets masks based on the normalized amplitude.
"""
def offset(self, mask, x, y, angle_multiplier, rotate, normalized_amp):
# Ensure normalized_amp is an array and within the range [0, 1]
offsetmask = mask.clone()
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
batch_size, height, width = mask.shape
if rotate:
for i in range(batch_size):
rotation_amp = int(normalized_amp[i] * (360 * angle_multiplier))
rotation_angle = rotation_amp
offsetmask[i] = TF.rotate(offsetmask[i].unsqueeze(0), rotation_angle).squeeze(0)
if x != 0 or y != 0:
for i in range(batch_size):
offset_amp = normalized_amp[i] * 10
shift_x = min(x*offset_amp, width-1)
shift_y = min(y*offset_amp, height-1)
if shift_x != 0:
offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_x), dims=1)
if shift_y != 0:
offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_y), dims=0)
return offsetmask,
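Sketch of the shift logic in isolation, assuming a plain mask batch tensor (illustrative only): each frame is rolled by x (and y) scaled by ten times that frame's amplitude, so louder frames move further.

```python
import torch

def roll_mask_by_amplitude(mask_batch, amps, x=0, y=0):
    out = mask_batch.clone()
    _, height, width = mask_batch.shape
    for i, amp in enumerate(amps):
        shift_x = int(min(x * amp * 10, width - 1))
        shift_y = int(min(y * amp * 10, height - 1))
        if shift_x:
            out[i] = torch.roll(out[i], shifts=shift_x, dims=1)
        if shift_y:
            out[i] = torch.roll(out[i], shifts=shift_y, dims=0)
    return out

masks = torch.zeros(3, 64, 64)
masks[:, 28:36, 28:36] = 1.0  # small square in the middle of each frame
shifted = roll_mask_by_amplitude(masks, [0.0, 0.5, 1.0], x=4)
# the square drifts right as the amplitude grows (and wraps around, since roll is circular)
print([m.nonzero()[:, 1].min().item() for m in shifted])  # [28, 48, 4]
```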
class ImageTransformByNormalizedAmplitude:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
"zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
"x_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"y_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"cumulative": ("BOOLEAN", { "default": False }),
"image": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "amptransform"
CATEGORY = "KJNodes/audio"
DESCRIPTION = """
Works as a bridge to the AudioScheduler nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Transforms image based on the normalized amplitude.
"""
def amptransform(self, image, normalized_amp, zoom_scale, cumulative, x_offset, y_offset):
# Ensure normalized_amp is an array and within the range [0, 1]
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
transformed_images = []
# Initialize the cumulative zoom factor
prev_amp = 0.0
for i in range(image.shape[0]):
img = image[i] # Get the i-th image in the batch
amp = normalized_amp[i] # Get the corresponding amplitude value
# Incrementally increase the cumulative zoom factor
if cumulative:
prev_amp += amp
amp += prev_amp
# Convert the image tensor from BxHxWxC to CxHxW format expected by torchvision
img = img.permute(2, 0, 1)
# Convert PyTorch tensor to PIL Image for processing
pil_img = TF.to_pil_image(img)
# Calculate the crop size based on the amplitude
width, height = pil_img.size
crop_size = int(min(width, height) * (1 - amp * zoom_scale))
crop_size = max(crop_size, 1)
# Calculate the crop box coordinates (centered crop)
left = (width - crop_size) // 2
top = (height - crop_size) // 2
right = (width + crop_size) // 2
bottom = (height + crop_size) // 2
# Crop and resize back to original size
cropped_img = TF.crop(pil_img, top, left, crop_size, crop_size)
resized_img = TF.resize(cropped_img, (height, width))
# Convert back to tensor in CxHxW format
tensor_img = TF.to_tensor(resized_img)
# Convert the tensor back to BxHxWxC format
tensor_img = tensor_img.permute(1, 2, 0)
# Offset the image based on the amplitude
offset_amp = amp * 10 # Calculate the offset magnitude based on the amplitude
shift_x = min(x_offset * offset_amp, img.shape[1] - 1) # Calculate the shift in x direction
shift_y = min(y_offset * offset_amp, img.shape[0] - 1) # Calculate the shift in y direction
# Apply the offset to the image tensor
if shift_x != 0:
tensor_img = torch.roll(tensor_img, shifts=int(shift_x), dims=1)
if shift_y != 0:
tensor_img = torch.roll(tensor_img, shifts=int(shift_y), dims=0)
# Add to the list
transformed_images.append(tensor_img)
# Stack all transformed images into a batch
transformed_batch = torch.stack(transformed_images)
return (transformed_batch,)
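For clarity: the zoom is just a centered crop whose side shrinks with the amplitude, resized back to the original resolution afterwards. A tiny sketch of that crop-box math (illustrative only):

```python
def zoom_crop_box(width, height, amp, zoom_scale):
    # side of the centered crop; clamped so it never collapses to zero
    crop_size = max(int(min(width, height) * (1 - amp * zoom_scale)), 1)
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    return left, top, crop_size

print(zoom_crop_box(512, 512, amp=0.0, zoom_scale=0.5))  # (0, 0, 512)     no zoom
print(zoom_crop_box(512, 512, amp=1.0, zoom_scale=0.5))  # (128, 128, 256) 2x zoom at peak amplitude
```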

737
nodes/batchcrop_nodes.py Normal file

@@ -0,0 +1,737 @@
from ..utility.utility import tensor2pil, pil2tensor
from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import torch
from torchvision.transforms import Resize, CenterCrop, InterpolationMode
import math
#based on nodes from mtb https://github.com/melMass/comfy_mtb
def bbox_to_region(bbox, target_size=None):
bbox = bbox_check(bbox, target_size)
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
def bbox_check(bbox, target_size=None):
if not target_size:
return bbox
new_bbox = (
bbox[0],
bbox[1],
min(target_size[0] - bbox[0], bbox[2]),
min(target_size[1] - bbox[1], bbox[3]),
)
return new_bbox
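For orientation: bboxes in this file are (x, y, width, height) tuples, and the helpers above convert and clamp them. A quick usage sketch, assuming the two functions above are in scope:

```python
# assumes bbox_to_region / bbox_check from above are in scope
bbox = (100, 50, 300, 300)           # (x, y, width, height)
print(bbox_to_region(bbox))          # (100, 50, 400, 350): PIL-style (left, top, right, bottom)
print(bbox_check(bbox, (256, 256)))  # (100, 50, 156, 206): width/height clamped to the target size
```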
class BatchCropFromMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"masks": ("MASK",),
"crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
RETURN_TYPES = (
"IMAGE",
"IMAGE",
"BBOX",
"INT",
"INT",
)
RETURN_NAMES = (
"original_images",
"cropped_images",
"bboxes",
"width",
"height",
)
FUNCTION = "crop"
CATEGORY = "KJNodes/masking"
def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
if alpha == 0:
return prev_bbox_size
return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
def smooth_center(self, prev_center, curr_center, alpha=0.5):
if alpha == 0:
return prev_center
return (
round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
round(alpha * curr_center[1] + (1 - alpha) * prev_center[1])
)
def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
bounding_boxes = []
cropped_images = []
self.max_bbox_width = 0
self.max_bbox_height = 0
# First, calculate the maximum bounding box size across all masks
curr_max_bbox_width = 0
curr_max_bbox_height = 0
for mask in masks:
_mask = tensor2pil(mask)[0]
non_zero_indices = np.nonzero(np.array(_mask))
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
width = max_x - min_x
height = max_y - min_y
curr_max_bbox_width = max(curr_max_bbox_width, width)
curr_max_bbox_height = max(curr_max_bbox_height, height)
# Smooth the changes in the bounding box size
self.max_bbox_width = self.smooth_bbox_size(self.max_bbox_width, curr_max_bbox_width, bbox_smooth_alpha)
self.max_bbox_height = self.smooth_bbox_size(self.max_bbox_height, curr_max_bbox_height, bbox_smooth_alpha)
# Apply the crop size multiplier
self.max_bbox_width = round(self.max_bbox_width * crop_size_mult)
self.max_bbox_height = round(self.max_bbox_height * crop_size_mult)
bbox_aspect_ratio = self.max_bbox_width / self.max_bbox_height
# Then, for each mask and corresponding image...
for i, (mask, img) in enumerate(zip(masks, original_images)):
_mask = tensor2pil(mask)[0]
non_zero_indices = np.nonzero(np.array(_mask))
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
# Calculate center of bounding box
center_x = np.mean(non_zero_indices[1])
center_y = np.mean(non_zero_indices[0])
curr_center = (round(center_x), round(center_y))
# If this is the first frame, initialize prev_center with curr_center
if not hasattr(self, 'prev_center'):
self.prev_center = curr_center
# Smooth the changes in the center coordinates from the second frame onwards
if i > 0:
center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
else:
center = curr_center
# Update prev_center for the next frame
self.prev_center = center
# Create bounding box using max_bbox_width and max_bbox_height
half_box_width = round(self.max_bbox_width / 2)
half_box_height = round(self.max_bbox_height / 2)
min_x = max(0, center[0] - half_box_width)
max_x = min(img.shape[1], center[0] + half_box_width)
min_y = max(0, center[1] - half_box_height)
max_y = min(img.shape[0], center[1] + half_box_height)
# Append bounding box coordinates
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
# Crop the image from the bounding box
cropped_img = img[min_y:max_y, min_x:max_x, :]
# Calculate the new dimensions while maintaining the aspect ratio
new_height = min(cropped_img.shape[0], self.max_bbox_height)
new_width = round(new_height * bbox_aspect_ratio)
# Resize the image
resize_transform = Resize((new_height, new_width))
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
# Perform the center crop to the desired size
crop_transform = CenterCrop((self.max_bbox_height, self.max_bbox_width)) # swap the order here if necessary
cropped_resized_img = crop_transform(resized_img)
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
cropped_out = torch.stack(cropped_images, dim=0)
return (original_images, cropped_out, bounding_boxes, self.max_bbox_width, self.max_bbox_height, )
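The bbox_smooth_alpha parameter is a simple exponential blend between the previous and current bounding-box size (and center): alpha 1.0 follows every frame exactly, values near 0 keep the old value and damp jitter. A minimal sketch of just the smoothing step:

```python
def smooth_bbox_size(prev_size, curr_size, alpha):
    if alpha == 0:
        return prev_size
    return round(alpha * curr_size + (1 - alpha) * prev_size)

sizes = [200, 260, 180, 240]
prev = sizes[0]
for s in sizes[1:]:
    prev = smooth_bbox_size(prev, s, alpha=0.5)
    print(prev)  # 230, 205, 222 -> frame-to-frame jitter in the crop size is damped
```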
class BatchUncrop:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"cropped_images": ("IMAGE",),
"bboxes": ("BBOX",),
"border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
"crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"border_top": ("BOOLEAN", {"default": True}),
"border_bottom": ("BOOLEAN", {"default": True}),
"border_left": ("BOOLEAN", {"default": True}),
"border_right": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "uncrop"
CATEGORY = "KJNodes/masking"
def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale, border_top, border_bottom, border_left, border_right):
def inset_border(image, border_width, border_color, border_top, border_bottom, border_left, border_right):
draw = ImageDraw.Draw(image)
width, height = image.size
if border_top:
draw.rectangle((0, 0, width, border_width), fill=border_color)
if border_bottom:
draw.rectangle((0, height - border_width, width, height), fill=border_color)
if border_left:
draw.rectangle((0, 0, border_width, height), fill=border_color)
if border_right:
draw.rectangle((width - border_width, 0, width, height), fill=border_color)
return image
if len(original_images) != len(cropped_images):
raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")
# Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
if len(bboxes) > len(original_images):
print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
bboxes = bboxes[:len(original_images)]
elif len(bboxes) < len(original_images):
raise ValueError("There should be at least as many bboxes as there are original and cropped images")
input_images = tensor2pil(original_images)
crop_imgs = tensor2pil(cropped_images)
out_images = []
for i in range(len(input_images)):
img = input_images[i]
crop = crop_imgs[i]
bbox = bboxes[i]
# uncrop the image based on the bounding box
bb_x, bb_y, bb_width, bb_height = bbox
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
# scale factors
scale_x = crop_rescale
scale_y = crop_rescale
# scaled paste_region
paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))
# rescale the crop image to fit the paste_region
crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
crop_img = crop.convert("RGB")
if border_blending > 1.0:
border_blending = 1.0
elif border_blending < 0.0:
border_blending = 0.0
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
blend = img.convert("RGBA")
mask = Image.new("L", img.size, 0)
mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
mask_block = inset_border(mask_block, round(blend_ratio / 2), (0), border_top, border_bottom, border_left, border_right)
mask.paste(mask_block, paste_region)
blend.paste(crop_img, paste_region)
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
blend.putalpha(mask)
img = Image.alpha_composite(img.convert("RGBA"), blend)
out_images.append(img.convert("RGB"))
return (pil2tensor(out_images),)
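The border blending pastes the crop back through a feathered alpha mask: a white block over the paste region with an inset dark border, blurred so the seam fades out instead of cutting hard. A standalone sketch of building such a mask (illustrative, not the exact helper above):

```python
from PIL import Image, ImageDraw, ImageFilter

def feathered_paste_mask(full_size, paste_region, blend_ratio):
    left, top, right, bottom = paste_region
    block = Image.new("L", (right - left, bottom - top), 255)
    draw = ImageDraw.Draw(block)
    bw = max(1, round(blend_ratio / 2))
    # darken the block's edges so the blur pulls the alpha toward 0 at the seam
    draw.rectangle((0, 0, block.width, bw), fill=0)                             # top
    draw.rectangle((0, block.height - bw, block.width, block.height), fill=0)   # bottom
    draw.rectangle((0, 0, bw, block.height), fill=0)                            # left
    draw.rectangle((block.width - bw, 0, block.width, block.height), fill=0)    # right
    mask = Image.new("L", full_size, 0)
    mask.paste(block, (left, top))
    return mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))

m = feathered_paste_mask((512, 512), (100, 100, 300, 300), blend_ratio=32)
print(m.getpixel((200, 200)), m.getpixel((100, 100)))  # opaque in the middle, faded at the seam
```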
class BatchCropFromMaskAdvanced:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"masks": ("MASK",),
"crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
RETURN_TYPES = (
"IMAGE",
"IMAGE",
"MASK",
"IMAGE",
"MASK",
"BBOX",
"BBOX",
"INT",
"INT",
)
RETURN_NAMES = (
"original_images",
"cropped_images",
"cropped_masks",
"combined_crop_image",
"combined_crop_masks",
"bboxes",
"combined_bounding_box",
"bbox_width",
"bbox_height",
)
FUNCTION = "crop"
CATEGORY = "KJNodes/masking"
def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
def smooth_center(self, prev_center, curr_center, alpha=0.5):
return (round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
round(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))
def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
bounding_boxes = []
combined_bounding_box = []
cropped_images = []
cropped_masks = []
cropped_masks_out = []
combined_crop_out = []
combined_cropped_images = []
combined_cropped_masks = []
def calculate_bbox(mask):
non_zero_indices = np.nonzero(np.array(mask))
# handle empty masks
min_x, max_x, min_y, max_y = 0, 0, 0, 0
if len(non_zero_indices[1]) > 0 and len(non_zero_indices[0]) > 0:
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
width = max_x - min_x
height = max_y - min_y
bbox_size = max(width, height)
return min_x, max_x, min_y, max_y, bbox_size
combined_mask = torch.max(masks, dim=0)[0]
_mask = tensor2pil(combined_mask)[0]
new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask)
center_x = (new_min_x + new_max_x) / 2
center_y = (new_min_y + new_max_y) / 2
half_box_size = round(combined_bbox_size // 2)
new_min_x = max(0, round(center_x - half_box_size))
new_max_x = min(original_images[0].shape[1], round(center_x + half_box_size))
new_min_y = max(0, round(center_y - half_box_size))
new_max_y = min(original_images[0].shape[0], round(center_y + half_box_size))
combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y))
self.max_bbox_size = 0
# First, calculate the maximum bounding box size across all masks
curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks)
# Smooth the changes in the bounding box size
self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
# Apply the crop size multiplier
self.max_bbox_size = round(self.max_bbox_size * crop_size_mult)
# Make sure max_bbox_size is divisible by 16, if not, round it upwards so it is
self.max_bbox_size = math.ceil(self.max_bbox_size / 16) * 16
if self.max_bbox_size > original_images[0].shape[0] or self.max_bbox_size > original_images[0].shape[1]:
# max_bbox_size can only be as big as our input's width or height, and it has to be even
self.max_bbox_size = math.floor(min(original_images[0].shape[0], original_images[0].shape[1]) / 2) * 2
# Then, for each mask and corresponding image...
for i, (mask, img) in enumerate(zip(masks, original_images)):
_mask = tensor2pil(mask)[0]
non_zero_indices = np.nonzero(np.array(_mask))
# check for empty masks
if len(non_zero_indices[0]) > 0 and len(non_zero_indices[1]) > 0:
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
# Calculate center of bounding box
center_x = np.mean(non_zero_indices[1])
center_y = np.mean(non_zero_indices[0])
curr_center = (round(center_x), round(center_y))
# If this is the first frame, initialize prev_center with curr_center
if not hasattr(self, 'prev_center'):
self.prev_center = curr_center
# Smooth the changes in the center coordinates from the second frame onwards
if i > 0:
center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
else:
center = curr_center
# Update prev_center for the next frame
self.prev_center = center
# Create bounding box using max_bbox_size
half_box_size = self.max_bbox_size // 2
min_x = max(0, center[0] - half_box_size)
max_x = min(img.shape[1], center[0] + half_box_size)
min_y = max(0, center[1] - half_box_size)
max_y = min(img.shape[0], center[1] + half_box_size)
# Append bounding box coordinates
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
# Crop the image from the bounding box
cropped_img = img[min_y:max_y, min_x:max_x, :]
cropped_mask = mask[min_y:max_y, min_x:max_x]
# Resize the cropped image to a fixed size
new_size = max(cropped_img.shape[0], cropped_img.shape[1])
resize_transform = Resize(new_size, interpolation=InterpolationMode.NEAREST, max_size=max(img.shape[0], img.shape[1]))
resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
# Perform the center crop to the desired size
# Constrain the crop to the smaller of our bbox or our image so we don't expand past the image dimensions.
crop_transform = CenterCrop((min(self.max_bbox_size, resized_img.shape[1]), min(self.max_bbox_size, resized_img.shape[2])))
cropped_resized_img = crop_transform(resized_img)
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
cropped_resized_mask = crop_transform(resized_mask)
cropped_masks.append(cropped_resized_mask)
combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
combined_cropped_images.append(combined_cropped_img)
combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
combined_cropped_masks.append(combined_cropped_mask)
else:
bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
cropped_images.append(img)
cropped_masks.append(mask)
combined_cropped_images.append(img)
combined_cropped_masks.append(mask)
cropped_out = torch.stack(cropped_images, dim=0)
combined_crop_out = torch.stack(combined_cropped_images, dim=0)
cropped_masks_out = torch.stack(cropped_masks, dim=0)
combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)
return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
class FilterZeroMasksAndCorrespondingImages:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"masks": ("MASK",),
},
"optional": {
"original_images": ("IMAGE",),
},
}
RETURN_TYPES = ("MASK", "IMAGE", "IMAGE", "INDEXES",)
RETURN_NAMES = ("non_zero_masks_out", "non_zero_mask_images_out", "zero_mask_images_out", "zero_mask_images_out_indexes",)
FUNCTION = "filter"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Filters out all the empty (i.e. all-zero) masks in masks.
Also filters out the corresponding images in original_images by index, if provided.
original_images (optional): if provided, must have the same length as masks.
"""
def filter(self, masks, original_images=None):
non_zero_masks = []
non_zero_mask_images = []
zero_mask_images = []
zero_mask_images_indexes = []
masks_num = len(masks)
also_process_images = False
if original_images is not None:
imgs_num = len(original_images)
if len(original_images) == masks_num:
also_process_images = True
else:
print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})")
for i in range(masks_num):
non_zero_num = np.count_nonzero(np.array(masks[i]))
if non_zero_num > 0:
non_zero_masks.append(masks[i])
if also_process_images:
non_zero_mask_images.append(original_images[i])
else:
zero_mask_images.append(original_images[i])
zero_mask_images_indexes.append(i)
non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None
if also_process_images:
non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
if len(zero_mask_images) > 0:
zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
zero_mask_images_out_indexes = zero_mask_images_indexes
return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes)
class InsertImageBatchByIndexes:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"images_to_insert": ("IMAGE",),
"insert_indexes": ("INDEXES",),
},
}
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images_after_insert", )
FUNCTION = "insert"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
This node is designed to be used with the FilterZeroMasksAndCorrespondingImages node.
It inserts the images_to_insert into images according to insert_indexes.
Returns:
images_after_insert: updated original images in their original sequence order
"""
def insert(self, images, images_to_insert, insert_indexes):
images_after_insert = images
if images_to_insert is not None and insert_indexes is not None:
images_to_insert_num = len(images_to_insert)
insert_indexes_num = len(insert_indexes)
if images_to_insert_num == insert_indexes_num:
images_after_insert = []
i_images = 0
for i in range(len(images) + images_to_insert_num):
if i in insert_indexes:
images_after_insert.append(images_to_insert[insert_indexes.index(i)])
else:
images_after_insert.append(images[i_images])
i_images += 1
images_after_insert = torch.stack(images_after_insert, dim=0)
else:
print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})")
return (images_after_insert, )
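The intended round-trip: FilterZeroMasksAndCorrespondingImages pulls out the frames whose masks are empty (and remembers their indexes), only the remaining frames go through the heavy processing, and this node puts the untouched frames back at their original positions. The index bookkeeping, sketched with plain lists (illustrative only):

```python
frames = ["f0", "f1", "f2", "f3", "f4"]
empty_mask_indexes = [1, 3]                   # frames whose masks were all zero

kept = [f for i, f in enumerate(frames) if i not in empty_mask_indexes]
skipped = [frames[i] for i in empty_mask_indexes]

processed = [f + "_out" for f in kept]        # stand-in for the heavy per-frame work

restored, it = [], iter(processed)            # reinsert skipped frames at their original indexes
for i in range(len(frames)):
    restored.append(skipped[empty_mask_indexes.index(i)] if i in empty_mask_indexes else next(it))
print(restored)  # ['f0_out', 'f1', 'f2_out', 'f3', 'f4_out']
```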
class BatchUncropAdvanced:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"cropped_images": ("IMAGE",),
"cropped_masks": ("MASK",),
"combined_crop_mask": ("MASK",),
"bboxes": ("BBOX",),
"border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
"crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"use_combined_mask": ("BOOLEAN", {"default": False}),
"use_square_mask": ("BOOLEAN", {"default": True}),
},
"optional": {
"combined_bounding_box": ("BBOX", {"default": None}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "uncrop"
CATEGORY = "KJNodes/masking"
def uncrop(self, original_images, cropped_images, cropped_masks, combined_crop_mask, bboxes, border_blending, crop_rescale, use_combined_mask, use_square_mask, combined_bounding_box = None):
def inset_border(image, border_width=20, border_color=(0)):
width, height = image.size
bordered_image = Image.new(image.mode, (width, height), border_color)
bordered_image.paste(image, (0, 0))
draw = ImageDraw.Draw(bordered_image)
draw.rectangle((0, 0, width - 1, height - 1), outline=border_color, width=border_width)
return bordered_image
if len(original_images) != len(cropped_images):
raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")
# Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
if len(bboxes) > len(original_images):
print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
bboxes = bboxes[:len(original_images)]
elif len(bboxes) < len(original_images):
raise ValueError("There should be at least as many bboxes as there are original and cropped images")
crop_imgs = tensor2pil(cropped_images)
input_images = tensor2pil(original_images)
out_images = []
for i in range(len(input_images)):
img = input_images[i]
crop = crop_imgs[i]
bbox = bboxes[i]
if use_combined_mask:
bb_x, bb_y, bb_width, bb_height = combined_bounding_box[0]
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
mask = combined_crop_mask[i]
else:
bb_x, bb_y, bb_width, bb_height = bbox
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
mask = cropped_masks[i]
# scale paste_region
scale_x = scale_y = crop_rescale
paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))
# rescale the crop image to fit the paste_region
crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
crop_img = crop.convert("RGB")
#border blending
if border_blending > 1.0:
border_blending = 1.0
elif border_blending < 0.0:
border_blending = 0.0
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
blend = img.convert("RGBA")
if use_square_mask:
mask = Image.new("L", img.size, 0)
mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
mask_block = inset_border(mask_block, round(blend_ratio / 2), (0))
mask.paste(mask_block, paste_region)
else:
original_mask = tensor2pil(mask)[0]
original_mask = original_mask.resize((paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]))
mask = Image.new("L", img.size, 0)
mask.paste(original_mask, paste_region)
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
blend.paste(crop_img, paste_region)
blend.putalpha(mask)
img = Image.alpha_composite(img.convert("RGBA"), blend)
out_images.append(img.convert("RGB"))
return (pil2tensor(out_images),)
class SplitBboxes:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"bboxes": ("BBOX",),
"index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
},
}
RETURN_TYPES = ("BBOX","BBOX",)
RETURN_NAMES = ("bboxes_a","bboxes_b",)
FUNCTION = "splitbbox"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Splits the specified bbox list at the given index into two lists.
"""
def splitbbox(self, bboxes, index):
bboxes_a = bboxes[:index] # Sub-list from the start of bboxes up to (but not including) the index
bboxes_b = bboxes[index:] # Sub-list from the index to the end of bboxes
return (bboxes_a, bboxes_b,)
class BboxToInt:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"bboxes": ("BBOX",),
"index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
},
}
RETURN_TYPES = ("INT","INT","INT","INT","INT","INT",)
RETURN_NAMES = ("x_min","y_min","width","height", "center_x","center_y",)
FUNCTION = "bboxtoint"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Returns selected index from bounding box list as integers.
"""
def bboxtoint(self, bboxes, index):
x_min, y_min, width, height = bboxes[index]
center_x = int(x_min + width / 2)
center_y = int(y_min + height / 2)
return (x_min, y_min, width, height, center_x, center_y,)
class BboxVisualize:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"bboxes": ("BBOX",),
"line_width": ("INT", {"default": 1,"min": 1, "max": 10, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "visualizebbox"
DESCRIPTION = """
Visualizes the specified bbox on the image.
"""
CATEGORY = "KJNodes/masking"
def visualizebbox(self, bboxes, images, line_width):
image_list = []
for image, bbox in zip(images, bboxes):
x_min, y_min, width, height = bbox
image = image.permute(2, 0, 1)
img_with_bbox = image.clone()
# Define the color for the bbox, e.g., red
color = torch.tensor([1, 0, 0], dtype=torch.float32)
# Draw lines for each side of the bbox with the specified line width
for lw in range(line_width):
# Top horizontal line
img_with_bbox[:, y_min + lw, x_min:x_min + width] = color[:, None]
# Bottom horizontal line
img_with_bbox[:, y_min + height - lw, x_min:x_min + width] = color[:, None]
# Left vertical line
img_with_bbox[:, y_min:y_min + height, x_min + lw] = color[:, None]
# Right vertical line
img_with_bbox[:, y_min:y_min + height, x_min + width - lw] = color[:, None]
img_with_bbox = img_with_bbox.permute(1, 2, 0).unsqueeze(0)
image_list.append(img_with_bbox)
return (torch.cat(image_list, dim=0),)

945
nodes/curve_nodes.py Normal file

@@ -0,0 +1,945 @@
import torch
from torchvision import transforms
import json
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from ..utility.utility import pil2tensor
import folder_paths
def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
import matplotlib
matplotlib.use('Agg')
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
text_color = '#999999'
bg_color = '#353535'
matplotlib.pyplot.rcParams['text.color'] = text_color
fig, ax = matplotlib.pyplot.subplots(figsize=(width/100, height/100), dpi=100)
fig.patch.set_facecolor(bg_color)
ax.set_facecolor(bg_color)
ax.grid(color=text_color, linestyle='-', linewidth=0.5)
ax.set_xlabel('x', color=text_color)
ax.set_ylabel('y', color=text_color)
for text in ax.get_xticklabels() + ax.get_yticklabels():
text.set_color(text_color)
ax.set_title('position for: ' + prompt)
ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate')
#ax.legend().remove()
ax.set_xlim(0, width) # Set the x-axis to match the input latent width
ax.set_ylim(height, 0) # Set the y-axis to match the input latent height, with (0,0) at top-left
# Adjust the margins of the subplot
matplotlib.pyplot.subplots_adjust(left=0.08, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.2)
cmap = matplotlib.pyplot.get_cmap('rainbow')
image_batch = []
canvas = FigureCanvas(fig)
width, height = fig.get_size_inches() * fig.get_dpi()
# Draw a box at each coordinate
for i, ((x, y), size) in enumerate(zip(coordinates, size_multiplier)):
color_index = i / (len(coordinates) - 1)
color = cmap(color_index)
draw_height = bbox_height * size
draw_width = bbox_width * size
rect = matplotlib.patches.Rectangle((x - draw_width/2, y - draw_height/2), draw_width, draw_height,
linewidth=1, edgecolor=color, facecolor='none', alpha=0.5)
ax.add_patch(rect)
# Check if there is a next coordinate to draw an arrow to
if i < len(coordinates) - 1:
x1, y1 = coordinates[i]
x2, y2 = coordinates[i + 1]
ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
arrowprops=dict(arrowstyle="->",
linestyle="-",
lw=1,
color=color,
mutation_scale=20))
canvas.draw()
image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3).copy()
image_tensor = torch.from_numpy(image_np).float() / 255.0
image_tensor = image_tensor.unsqueeze(0)
image_batch.append(image_tensor)
matplotlib.pyplot.close(fig)
image_batch_tensor = torch.cat(image_batch, dim=0)
return image_batch_tensor
class PlotCoordinates:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"coordinates": ("STRING", {"forceInput": True}),
"text": ("STRING", {"default": 'title', "multiline": False}),
"width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
"height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
"bbox_width": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
"bbox_height": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
},
"optional": {"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True})},
}
RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT",)
RETURN_NAMES = ("images", "width", "height", "bbox_width", "bbox_height",)
FUNCTION = "append"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
Plots coordinates to a sequence of images using Matplotlib.
"""
def append(self, coordinates, text, width, height, bbox_width, bbox_height, size_multiplier=[1.0]):
coordinates = json.loads(coordinates.replace("'", '"'))
coordinates = [(coord['x'], coord['y']) for coord in coordinates]
batch_size = len(coordinates)
if len(size_multiplier) != batch_size:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
plot_image_tensor = plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, text)
return (plot_image_tensor, width, height, bbox_width, bbox_height)
class SplineEditor:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"points_store": ("STRING", {"multiline": False}),
"coordinates": ("STRING", {"multiline": False}),
"mask_width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
"mask_height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
"points_to_sample": ("INT", {"default": 16, "min": 2, "max": 1000, "step": 1}),
"sampling_method": (
[
'path',
'time',
],
{
"default": 'time'
}),
"interpolation": (
[
'cardinal',
'monotone',
'basis',
'linear',
'step-before',
'step-after',
'polar',
'polar-reverse',
],
{
"default": 'cardinal'
}),
"tension": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"repeat_output": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
"float_output_type": (
[
'list',
'pandas series',
'tensor',
],
{
"default": 'list'
}),
},
"optional": {
"min_value": ("FLOAT", {"default": 0.0, "min": -10000.0, "max": 10000.0, "step": 0.01}),
"max_value": ("FLOAT", {"default": 1.0, "min": -10000.0, "max": 10000.0, "step": 0.01}),
}
}
RETURN_TYPES = ("MASK", "STRING", "FLOAT", "INT")
RETURN_NAMES = ("mask", "coord_str", "float", "count")
FUNCTION = "splinedata"
CATEGORY = "KJNodes/weights"
DESCRIPTION = """
# WORK IN PROGRESS
Do not count on this as part of your workflow yet,
probably contains lots of bugs and stability is not
guaranteed!!
## Graphical editor to create values for various
## schedules and/or mask batches.
**Shift + click** to add control point at end.
**Ctrl + click** to add control point (subdivide) between two points.
**Right click on a point** to delete it.
Note that you can't delete from start/end.
Right click on canvas for context menu:
These are purely visual options; they don't affect the output:
- Toggle handles visibility
- Display sample points: display the points to be returned.
**points_to_sample** value sets the number of samples
returned from the **drawn spline itself**; this is independent of the
actual control points, so the interpolation type matters.
sampling_method:
- time: samples along the time axis, used for schedules
- path: samples along the path itself, useful for coordinates
output types:
- mask batch
example compatible nodes: anything that takes masks
- list of floats
example compatible nodes: IPAdapter weights
- pandas series
example compatible nodes: anything that takes Fizz'
nodes Batch Value Schedule
- torch tensor
example compatible nodes: unknown
"""
def splinedata(self, mask_width, mask_height, coordinates, float_output_type, interpolation,
points_to_sample, sampling_method, points_store, tension, repeat_output, min_value=0.0, max_value=1.0):
coordinates = json.loads(coordinates)
for coord in coordinates:
coord['x'] = int(round(coord['x']))
coord['y'] = int(round(coord['y']))
normalized_y_values = [
(1.0 - (point['y'] / mask_height) - 0.0) * (max_value - min_value) + min_value
for point in coordinates
]
if float_output_type == 'list':
out_floats = normalized_y_values * repeat_output
elif float_output_type == 'pandas series':
try:
import pandas as pd
except:
raise Exception("MaskOrImageToWeight: pandas is not installed. Please install pandas to use this output_type")
out_floats = pd.Series(normalized_y_values * repeat_output),
elif float_output_type == 'tensor':
out_floats = torch.tensor(normalized_y_values * repeat_output, dtype=torch.float32)
# Create a color map for grayscale intensities
color_map = lambda y: torch.full((mask_height, mask_width, 3), y, dtype=torch.float32)
# Create image tensors for each normalized y value
mask_tensors = [color_map(y) for y in normalized_y_values]
masks_out = torch.stack(mask_tensors)
masks_out = masks_out.repeat(repeat_output, 1, 1, 1)
masks_out = masks_out.mean(dim=-1)
return (masks_out, str(coordinates), out_floats, len(out_floats))
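Each sampled point's y coordinate is turned into an output float by flipping the axis (canvas y grows downward) and mapping it into [min_value, max_value]. That mapping in isolation:

```python
def y_to_weight(y, mask_height, min_value=0.0, max_value=1.0):
    return (1.0 - y / mask_height) * (max_value - min_value) + min_value

print(y_to_weight(0, 512))              # 1.0 -> top of the editor gives the maximum
print(y_to_weight(512, 512))            # 0.0 -> bottom gives the minimum
print(y_to_weight(128, 512, 0.0, 2.0))  # 1.5
```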
class CreateShapeMaskOnPath:
RETURN_TYPES = ("MASK", "MASK",)
RETURN_NAMES = ("mask", "mask_inverted",)
FUNCTION = "createshapemask"
CATEGORY = "KJNodes/masking/generate"
DESCRIPTION = """
Creates a mask or batch of masks with the specified shape.
Locations are center locations.
Grow value is the amount to grow the shape on each frame, creating animated masks.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"shape": (
[ 'circle',
'square',
'triangle',
],
{
"default": 'circle'
}),
"coordinates": ("STRING", {"forceInput": True}),
"frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
"shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
},
"optional": {
"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
}
}
def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape, size_multiplier=[1.0]):
# Define the number of images in the batch
coordinates = coordinates.replace("'", '"')
coordinates = json.loads(coordinates)
batch_size = len(coordinates)
out = []
color = "white"
if len(size_multiplier) != batch_size:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i, coord in enumerate(coordinates):
image = Image.new("RGB", (frame_width, frame_height), "black")
draw = ImageDraw.Draw(image)
# Calculate the size for this frame and ensure it's not less than 0
current_width = max(0, shape_width + i * size_multiplier[i])
current_height = max(0, shape_height + i * size_multiplier[i])
location_x = coord['x']
location_y = coord['y']
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
two_points = [left_up_point, right_down_point]
if shape == 'circle':
draw.ellipse(two_points, fill=color)
elif shape == 'square':
draw.rectangle(two_points, fill=color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
top_point = (location_x, location_y - current_height // 2) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=color)
image = pil2tensor(image)
mask = image[:, :, :, 0]
out.append(mask)
outstack = torch.cat(out, dim=0)
return (outstack, 1.0 - outstack,)
class MaskOrImageToWeight:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"output_type": (
[
'list',
'pandas series',
'tensor',
'string'
],
{
"default": 'list'
}),
},
"optional": {
"images": ("IMAGE",),
"masks": ("MASK",),
},
}
RETURN_TYPES = ("FLOAT", "STRING",)
FUNCTION = "execute"
CATEGORY = "KJNodes/weights"
DESCRIPTION = """
Gets the mean values from a mask or image batch
and returns them as the selected output type.
"""
def execute(self, output_type, images=None, masks=None):
mean_values = []
if masks is not None and images is None:
for mask in masks:
mean_values.append(mask.mean().item())
elif masks is None and images is not None:
for image in images:
mean_values.append(image.mean().item())
elif masks is not None and images is not None:
raise Exception("MaskOrImageToWeight: Use either mask or image input only.")
# Convert mean_values to the specified output_type
if output_type == 'list':
out = mean_values,
elif output_type == 'pandas series':
try:
import pandas as pd
except:
raise Exception("MaskOrImageToWeight: pandas is not installed. Please install pandas to use this output_type")
out = pd.Series(mean_values),
elif output_type == 'tensor':
out = torch.tensor(mean_values, dtype=torch.float32),
return (out, [str(value) for value in mean_values],)
class WeightScheduleConvert:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_values": ("FLOAT", {"default": 0.0, "forceInput": True}),
"output_type": (
[
'match_input',
'list',
'pandas series',
'tensor',
],
{
"default": 'list'
}),
"invert": ("BOOLEAN", {"default": False}),
"repeat": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
},
"optional": {
"remap_to_frames": ("INT", {"default": 0}),
"interpolation_curve": ("FLOAT", {"forceInput": True}),
"remap_values": ("BOOLEAN", {"default": False}),
"remap_min": ("FLOAT", {"default": 0.0, "min": -100000, "max": 100000.0, "step": 0.01}),
"remap_max": ("FLOAT", {"default": 1.0, "min": -100000, "max": 100000.0, "step": 0.01}),
},
}
RETURN_TYPES = ("FLOAT", "STRING", "INT",)
FUNCTION = "execute"
CATEGORY = "KJNodes/weights"
DESCRIPTION = """
Converts different value lists/series to another type.
"""
def detect_input_type(self, input_values):
import pandas as pd
if isinstance(input_values, list):
return 'list'
elif isinstance(input_values, pd.Series):
return 'pandas series'
elif isinstance(input_values, torch.Tensor):
return 'tensor'
else:
raise ValueError("Unsupported input type")
def execute(self, input_values, output_type, invert, repeat, remap_to_frames=0, interpolation_curve=None, remap_min=0.0, remap_max=1.0, remap_values=False):
import pandas as pd
input_type = self.detect_input_type(input_values)
        # Normalize the input to a plain Python list so the list operations below work for every input type
        if input_type == 'pandas series':
            float_values = input_values.tolist()
        elif input_type == 'tensor':
            float_values = input_values.tolist()
        else:
            float_values = input_values
if invert:
float_values = [1 - value for value in float_values]
if interpolation_curve is not None:
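            # Each value in the interpolation_curve scales how many frames (remap_to_frames * value)
            # are sampled from the normalized original schedule for that segment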
interpolated_pattern = []
orig_float_values = float_values
for value in interpolation_curve:
min_val = min(orig_float_values)
max_val = max(orig_float_values)
# Normalize the values to [0, 1]
normalized_values = [(value - min_val) / (max_val - min_val) for value in orig_float_values]
# Interpolate the normalized values to the new frame count
remapped_float_values = np.interp(np.linspace(0, 1, int(remap_to_frames * value)), np.linspace(0, 1, len(normalized_values)), normalized_values).tolist()
interpolated_pattern.extend(remapped_float_values)
float_values = interpolated_pattern
else:
# Remap float_values to match target_frame_amount
if remap_to_frames > 0 and remap_to_frames != len(float_values):
min_val = min(float_values)
max_val = max(float_values)
# Normalize the values to [0, 1]
normalized_values = [(value - min_val) / (max_val - min_val) for value in float_values]
# Interpolate the normalized values to the new frame count
float_values = np.interp(np.linspace(0, 1, remap_to_frames), np.linspace(0, 1, len(normalized_values)), normalized_values).tolist()
float_values = float_values * repeat
if remap_values:
float_values = self.remap_values(float_values, remap_min, remap_max)
if output_type == 'list':
out = float_values,
elif output_type == 'pandas series':
out = pd.Series(float_values),
        elif output_type == 'tensor':
            # float_values is a plain list at this point, regardless of the original input type
            out = torch.tensor(float_values, dtype=torch.float32),
        elif output_type == 'match_input':
            # Convert back to the same container type the input used
            if input_type == 'tensor':
                out = torch.tensor(float_values, dtype=torch.float32),
            elif input_type == 'pandas series':
                out = pd.Series(float_values),
            else:
                out = float_values,
return (out, [str(value) for value in float_values], [int(value) for value in float_values])
def remap_values(self, values, target_min, target_max):
# Determine the current range
current_min = min(values)
        current_max = max(values)
        current_range = current_max - current_min
        # Determine the target range
        target_range = target_max - target_min
        # Guard against a zero range (all values equal) to avoid division by zero
        if current_range == 0:
            return [target_min for _ in values]
        # Perform the linear interpolation for each value
        remapped_values = [(value - current_min) / current_range * target_range + target_min for value in values]
return remapped_values
class FloatToMask:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_values": ("FLOAT", {"forceInput": True, "default": 0}),
"width": ("INT", {"default": 100, "min": 1}),
"height": ("INT", {"default": 100, "min": 1}),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "execute"
CATEGORY = "KJNodes/masking/generate"
DESCRIPTION = """
Generates a batch of masks based on the input float values.
The batch size is determined by the length of the input float values.
Each mask is generated with the specified width and height.
"""
def execute(self, input_values, width, height):
import pandas as pd
# Ensure input_values is a list
if isinstance(input_values, (float, int)):
input_values = [input_values]
elif isinstance(input_values, pd.Series):
input_values = input_values.tolist()
elif isinstance(input_values, list) and all(isinstance(item, list) for item in input_values):
input_values = [item for sublist in input_values for item in sublist]
# Generate a batch of masks based on the input_values
masks = []
for value in input_values:
# Assuming value is a float between 0 and 1 representing the mask's intensity
mask = torch.ones((height, width), dtype=torch.float32) * value
masks.append(mask)
masks_out = torch.stack(masks, dim=0)
return(masks_out,)
class WeightScheduleExtend:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_values_1": ("FLOAT", {"default": 0.0, "forceInput": True}),
"input_values_2": ("FLOAT", {"default": 0.0, "forceInput": True}),
"output_type": (
[
'match_input',
'list',
'pandas series',
'tensor',
],
{
"default": 'match_input'
}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
CATEGORY = "KJNodes/weights"
DESCRIPTION = """
Extends, and converts if needed, different value lists/series
"""
def detect_input_type(self, input_values):
import pandas as pd
if isinstance(input_values, list):
return 'list'
elif isinstance(input_values, pd.Series):
return 'pandas series'
elif isinstance(input_values, torch.Tensor):
return 'tensor'
else:
raise ValueError("Unsupported input type")
def execute(self, input_values_1, input_values_2, output_type):
import pandas as pd
input_type_1 = self.detect_input_type(input_values_1)
input_type_2 = self.detect_input_type(input_values_2)
# Convert input_values_2 to the same format as input_values_1 if they do not match
if not input_type_1 == input_type_2:
print("Converting input_values_2 to the same format as input_values_1")
if input_type_1 == 'pandas series':
# Convert input_values_2 to a pandas Series
float_values_2 = pd.Series(input_values_2)
elif input_type_1 == 'tensor':
# Convert input_values_2 to a tensor
float_values_2 = torch.tensor(input_values_2, dtype=torch.float32)
else:
print("Input types match, no conversion needed")
# If the types match, no conversion is needed
float_values_2 = input_values_2
        # "Extend" means concatenate, so handle each container type explicitly
        if input_type_1 == 'tensor':
            float_values = torch.cat((input_values_1, float_values_2), dim=0)
        elif input_type_1 == 'pandas series':
            float_values = pd.concat([input_values_1, float_values_2], ignore_index=True)
        else:
            float_values = input_values_1 + float_values_2
if output_type == 'list':
return float_values,
elif output_type == 'pandas series':
return pd.Series(float_values),
elif output_type == 'tensor':
if input_type_1 == 'pandas series':
return torch.tensor(float_values.values, dtype=torch.float32),
else:
return torch.tensor(float_values, dtype=torch.float32),
elif output_type == 'match_input':
return float_values,
else:
raise ValueError(f"Unsupported output_type: {output_type}")
class FloatToSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"float_list": ("FLOAT", {"default": 0.0, "forceInput": True}),
}
}
RETURN_TYPES = ("SIGMAS",)
RETURN_NAMES = ("SIGMAS",)
CATEGORY = "KJNodes/noise"
FUNCTION = "customsigmas"
DESCRIPTION = """
Creates a sigmas tensor from list of float values.
"""
def customsigmas(self, float_list):
return torch.tensor(float_list, dtype=torch.float32),
class GLIGENTextBoxApplyBatchCoords:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ),
"latents": ("LATENT", ),
"clip": ("CLIP", ),
"gligen_textbox_model": ("GLIGEN", ),
"coordinates": ("STRING", {"forceInput": True}),
"text": ("STRING", {"multiline": True}),
"width": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
"height": ("INT", {"default": 128, "min": 8, "max": 4096, "step": 8}),
},
"optional": {"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True})},
}
RETURN_TYPES = ("CONDITIONING", "IMAGE", )
RETURN_NAMES = ("conditioning", "coord_preview", )
FUNCTION = "append"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
This node allows scheduling GLIGEN text box positions in a batch,
to be used with AnimateDiff-Evolved. Intended to pair with the
Spline Editor -node.
GLIGEN model can be downloaded through the Manage's "Install Models" menu.
Or directly from here:
https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/tree/main
Inputs:
- **latents** input is used to calculate batch size
- **clip** is your standard text encoder, use same as for the main prompt
- **gligen_textbox_model** connects to GLIGEN Loader
- **coordinates** takes a json string of points, directly compatible
with the spline editor node.
- **text** is the part of the prompt to set position for
- **width** and **height** are the size of the GLIGEN bounding box
Outputs:
- **conditioning** goes between to clip text encode and the sampler
- **coord_preview** is an optional preview of the coordinates and
bounding boxes.
"""
def append(self, latents, coordinates, conditioning_to, clip, gligen_textbox_model, text, width, height, size_multiplier=[1.0]):
coordinates = json.loads(coordinates.replace("'", '"'))
coordinates = [(coord['x'], coord['y']) for coord in coordinates]
batch_size = sum(tensor.size(0) for tensor in latents.values())
if len(coordinates) != batch_size:
print("GLIGENTextBoxApplyBatchCoords WARNING: The number of coordinates does not match the number of latents")
c = []
_, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
for t in conditioning_to:
n = [t[0], t[1].copy()]
position_params_batch = [[] for _ in range(batch_size)] # Initialize a list of empty lists for each batch item
if len(size_multiplier) != batch_size:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i in range(batch_size):
x_position, y_position = coordinates[i]
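                # GLIGEN position params are given in latent-space units (pixels // 8):
                # (pooled text embedding, box_height, box_width, top_y, left_x)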
position_param = (cond_pooled, int((height // 8) * size_multiplier[i]), int((width // 8) * size_multiplier[i]), (y_position - height // 2) // 8, (x_position - width // 2) // 8)
position_params_batch[i].append(position_param) # Append position_param to the correct sublist
prev = []
if "gligen" in n[1]:
prev = n[1]['gligen'][2]
else:
prev = [[] for _ in range(batch_size)]
# Concatenate prev and position_params_batch, ensuring both are lists of lists
# and each sublist corresponds to a batch item
combined_position_params = [prev_item + batch_item for prev_item, batch_item in zip(prev, position_params_batch)]
n[1]['gligen'] = ("position_batched", gligen_textbox_model, combined_position_params)
c.append(n)
image_height = latents['samples'].shape[-2] * 8
image_width = latents['samples'].shape[-1] * 8
plot_image_tensor = plot_coordinates_to_tensor(coordinates, image_height, image_width, height, width, size_multiplier, text)
return (c, plot_image_tensor,)
class CreateInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
FUNCTION = "tracking"
CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """
Creates tracking data to be used with InstanceDiffusion:
https://github.com/logtd/ComfyUI-InstanceDiffusion
InstanceDiffusion prompt format:
"class_id.class_name": "prompt",
for example:
"1.head": "((head))",
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"coordinates": ("STRING", {"forceInput": True}),
"width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"bbox_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"bbox_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"class_name": ("STRING", {"default": "class_name"}),
"class_id": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
"prompt": ("STRING", {"default": "prompt", "multiline": True}),
},
"optional": {
"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
}
}
def tracking(self, coordinates, class_name, class_id, width, height, bbox_width, bbox_height, prompt, size_multiplier=[1.0]):
# Define the number of images in the batch
coordinates = coordinates.replace("'", '"')
coordinates = json.loads(coordinates)
tracked = {}
tracked[class_name] = {}
batch_size = len(coordinates)
# Initialize a list to hold the coordinates for the current ID
id_coordinates = []
if len(size_multiplier) != batch_size:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i, coord in enumerate(coordinates):
x = coord['x']
y = coord['y']
adjusted_bbox_width = bbox_width * size_multiplier[i]
adjusted_bbox_height = bbox_height * size_multiplier[i]
# Calculate the top left and bottom right coordinates
top_left_x = x - adjusted_bbox_width // 2
top_left_y = y - adjusted_bbox_height // 2
bottom_right_x = x + adjusted_bbox_width // 2
bottom_right_y = y + adjusted_bbox_height // 2
# Append the top left and bottom right coordinates to the list for the current ID
id_coordinates.append([top_left_x, top_left_y, bottom_right_x, bottom_right_y, width, height])
class_id = int(class_id)
# Assign the list of coordinates to the specified ID within the class_id dictionary
tracked[class_name][class_id] = id_coordinates
prompt_string = ""
for class_name, class_data in tracked.items():
for class_id in class_data.keys():
class_id_str = str(class_id)
# Use the incoming prompt for each class name and ID
prompt_string += f'"{class_id_str}.{class_name}": "({prompt})",\n'
# Remove the last comma and newline
prompt_string = prompt_string.rstrip(",\n")
return (tracked, prompt_string, width, height, bbox_width, bbox_height)
class AppendInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING", "STRING",)
RETURN_NAMES = ("tracking", "prompt",)
FUNCTION = "append"
CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """
Appends tracking data to be used with InstanceDiffusion:
https://github.com/logtd/ComfyUI-InstanceDiffusion
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"tracking_1": ("TRACKING", {"forceInput": True}),
"tracking_2": ("TRACKING", {"forceInput": True}),
},
"optional": {
"prompt_1": ("STRING", {"default": "", "forceInput": True}),
"prompt_2": ("STRING", {"default": "", "forceInput": True}),
}
}
def append(self, tracking_1, tracking_2, prompt_1="", prompt_2=""):
tracking_copy = tracking_1.copy()
# Check for existing class names and class IDs, and raise an error if they exist
for class_name, class_data in tracking_2.items():
if class_name not in tracking_copy:
tracking_copy[class_name] = class_data
else:
# If the class name exists, merge the class data from tracking_2 into tracking_copy
# This will add new class IDs under the same class name without raising an error
tracking_copy[class_name].update(class_data)
prompt_string = prompt_1 + "," + prompt_2
return (tracking_copy, prompt_string)
class InterpolateCoords:
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("coordinates",)
FUNCTION = "interpolate"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
Interpolates coordinates based on a curve.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"coordinates": ("STRING", {"forceInput": True}),
"interpolation_curve": ("FLOAT", {"forceInput": True}),
},
}
def interpolate(self, coordinates, interpolation_curve):
# Parse the JSON string to get the list of coordinates
coordinates = json.loads(coordinates.replace("'", '"'))
# Convert the list of dictionaries to a list of (x, y) tuples for easier processing
coordinates = [(coord['x'], coord['y']) for coord in coordinates]
# Calculate the total length of the original path
path_length = sum(np.linalg.norm(np.array(coordinates[i]) - np.array(coordinates[i-1])) for i in range(1, len(coordinates)))
# Normalize the interpolation curve
normalized_curve = [x / path_length for x in interpolation_curve]
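        # Note: the division here is undone below ("target_length *= path_length"),
        # so in practice each curve value is used directly as a distance along the path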
# Initialize variables for interpolation
interpolated_coords = []
current_length = 0
current_index = 1
# Iterate over the normalized curve
for target_length in normalized_curve:
target_length *= path_length # Convert back to the original scale
while current_length < target_length and current_index < len(coordinates):
segment_length = np.linalg.norm(np.array(coordinates[current_index]) - np.array(coordinates[current_index-1]))
current_length += segment_length
current_index += 1
# Interpolate between the last two points
if current_index == 1:
interpolated_coords.append(coordinates[0])
else:
p1, p2 = np.array(coordinates[current_index-2]), np.array(coordinates[current_index-1])
segment_length = np.linalg.norm(p2 - p1)
if segment_length > 0:
t = (target_length - (current_length - segment_length)) / segment_length
interpolated_point = p1 + t * (p2 - p1)
interpolated_coords.append(interpolated_point.tolist())
else:
interpolated_coords.append(p1.tolist())
# Convert back to string format if necessary
interpolated_coords_str = "[" + ", ".join([f"{{'x': {round(coord[0])}, 'y': {round(coord[1])}}}" for coord in interpolated_coords]) + "]"
return (interpolated_coords_str, )
class DrawInstanceDiffusionTracking:
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image", )
FUNCTION = "draw"
CATEGORY = "KJNodes/InstanceDiffusion"
DESCRIPTION = """
Draws the tracking data from
CreateInstanceDiffusionTracking -node.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE", ),
"tracking": ("TRACKING", {"forceInput": True}),
"box_line_width": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
"draw_text": ("BOOLEAN", {"default": True}),
"font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
"font_size": ("INT", {"default": 20}),
},
}
def draw(self, image, tracking, box_line_width, draw_text, font, font_size):
import matplotlib.cm as cm
modified_images = []
colormap = cm.get_cmap('rainbow', len(tracking))
if draw_text:
            # Load the font selected on the node from the kjnodes_fonts folder
            font = ImageFont.truetype(folder_paths.get_full_path("kjnodes_fonts", font), font_size)
# Iterate over each image in the batch
for i in range(image.shape[0]):
# Extract the current image and convert it to a PIL image
# Adjust the tensor to (C, H, W) for ToPILImage
current_image = image[i, :, :, :].permute(2, 0, 1)
pil_image = transforms.ToPILImage()(current_image)
draw = ImageDraw.Draw(pil_image)
# Iterate over the bounding boxes for the current image
for j, (class_name, class_data) in enumerate(tracking.items()):
for class_id, bbox_list in class_data.items():
# Check if the current index is within the bounds of the bbox_list
if i < len(bbox_list):
bbox = bbox_list[i]
# Ensure bbox is a list or tuple before unpacking
if isinstance(bbox, (list, tuple)):
x1, y1, x2, y2, _, _ = bbox
# Convert coordinates to integers
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
# Generate a color from the rainbow colormap
color = tuple(int(255 * x) for x in colormap(j / len(tracking)))[:3]
# Draw the bounding box on the image with the generated color
draw.rectangle([x1, y1, x2, y2], outline=color, width=box_line_width)
if draw_text:
# Draw the class name and ID as text above the box with the generated color
text = f"{class_id}.{class_name}"
# Calculate the width and height of the text
_, _, text_width, text_height = draw.textbbox((0, 0), text=text, font=font)
# Position the text above the top-left corner of the box
text_position = (x1, y1 - text_height)
draw.text(text_position, text, fill=color, font=font)
else:
print(f"Unexpected data type for bbox: {type(bbox)}")
# Convert the drawn image back to a torch tensor and adjust back to (H, W, C)
modified_image_tensor = transforms.ToTensor()(pil_image).permute(1, 2, 0)
modified_images.append(modified_image_tensor)
# Stack the modified images back into a batch
image_tensor_batch = torch.stack(modified_images).cpu().float()
return image_tensor_batch,

1076
nodes/image_nodes.py Normal file

File diff suppressed because it is too large


@ -0,0 +1,115 @@
import folder_paths
import os
import torch
import torch.nn.functional as F
from comfy.utils import ProgressBar, load_torch_file
import comfy.sample
from nodes import CLIPTextEncode
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
folder_paths.add_model_folder_path("intristic_loras", os.path.join(script_directory, "intristic_loras"))
class Intrinsic_lora_sampling:
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("intristic_loras"), ),
"task": (
[
'depth map',
'surface normals',
'albedo',
'shading',
],
{
"default": 'depth map'
}),
"text": ("STRING", {"multiline": True, "default": ""}),
"clip": ("CLIP", ),
"vae": ("VAE", ),
"per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
},
"optional": {
"image": ("IMAGE",),
"optional_latent": ("LATENT",),
},
}
RETURN_TYPES = ("IMAGE", "LATENT",)
FUNCTION = "onestepsample"
CATEGORY = "KJNodes"
DESCRIPTION = """
Sampler to use the intrinsic loras:
https://github.com/duxiaodan/intrinsic-lora
These LoRAs are tiny and thus included
with this node pack.
"""
def onestepsample(self, model, lora_name, clip, vae, text, task, per_batch, image=None, optional_latent=None):
pbar = ProgressBar(3)
if optional_latent is None:
image_list = []
for start_idx in range(0, image.shape[0], per_batch):
sub_pixels = vae.vae_encode_crop_pixels(image[start_idx:start_idx+per_batch])
image_list.append(vae.encode(sub_pixels[:,:,:,:3]))
sample = torch.cat(image_list, dim=0)
else:
sample = optional_latent["samples"]
noise = torch.zeros(sample.size(), dtype=sample.dtype, layout=sample.layout, device="cpu")
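        # With zero noise and the X0 pass-through sampling defined below, the single sampling
        # step returns the model's prediction directly, which is the intrinsic map for the chosen task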
prompt = task + "," + text
positive, = CLIPTextEncode.encode(self, clip, prompt)
negative = positive #negative shouldn't do anything in this scenario
pbar.update(1)
#custom model sampling to pass latent through as it is
class X0_PassThrough(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
def calculate_input(self, sigma, noise):
return noise
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
sampling_type = X0_PassThrough
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
#load lora
model_clone = model.clone()
lora_path = folder_paths.get_full_path("intristic_loras", lora_name)
lora = load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
model_clone_with_lora = comfy.sd.load_lora_for_models(model_clone, None, lora, 1.0, 0)[0]
model_clone_with_lora.add_object_patch("model_sampling", model_sampling)
samples = {"samples": comfy.sample.sample(model_clone_with_lora, noise, 1, 1.0, "euler", "simple", positive, negative, sample,
denoise=1.0, disable_noise=True, start_step=0, last_step=1,
force_full_denoise=True, noise_mask=None, callback=None, disable_pbar=True, seed=None)}
pbar.update(1)
decoded = []
for start_idx in range(0, samples["samples"].shape[0], per_batch):
decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
image_out = torch.cat(decoded, dim=0)
pbar.update(1)
if task == 'depth map':
imax = image_out.max()
imin = image_out.min()
image_out = (image_out-imin)/(imax-imin)
image_out = torch.max(image_out, dim=3, keepdim=True)[0].repeat(1, 1, 1, 3)
elif task == 'surface normals':
image_out = F.normalize(image_out * 2 - 1, dim=3) / 2 + 0.5
image_out = 1.0 - image_out
else:
image_out = image_out.clamp(-1.,1.)
return (image_out, samples,)

1166
nodes/mask_nodes.py Normal file

File diff suppressed because it is too large

1632
nodes/nodes.py Normal file

File diff suppressed because it is too large


@ -1,27 +1,23 @@
import { app } from "../../../scripts/app.js";
app.registerExtension({
name: "KJNodes.appearance",
nodeCreated(node) {
    switch (node.comfyClass) {
      case "INTConstant":
node.setSize([200, 58]);
node.color = "#1b4669";
node.bgcolor = "#29699c";
break;
case "Float Constant":
case "FloatConstant":
node.setSize([200, 58]);
node.color = LGraphCanvas.node_colors.green.color;
node.bgcolor = LGraphCanvas.node_colors.green.bgcolor;
break;
case "ConditioningMultiCombine":
node.color = LGraphCanvas.node_colors.brown.color;
node.bgcolor = LGraphCanvas.node_colors.brown.bgcolor;
break;
}
}
});


@ -4,6 +4,9 @@ import { app } from "../../../scripts/app.js";
app.registerExtension({
name: "KJNodes.browserstatus",
setup() {
if (!app.ui.settings.getSettingValue("KJNodes.browserStatus")) {
return;
}
api.addEventListener("status", ({ detail }) => {
let title = "ComfyUI";
let favicon = "green";
@ -11,7 +14,6 @@ app.registerExtension({
if (queueRemaining) {
favicon = "red";
title = `00% - ${queueRemaining} | ${title}`;
}
let link = document.querySelector("link[rel~='icon']");
@ -22,9 +24,8 @@ app.registerExtension({
}
link.href = new URL(`../${favicon}.png`, import.meta.url);
document.title = title;
});
    //add progress to the title
api.addEventListener("progress", ({ detail }) => {
const { value, max } = detail;
const progress = Math.floor((value / max) * 100);
@ -34,8 +35,19 @@ app.registerExtension({
const paddedProgress = String(progress).padStart(2, '0');
title = `${paddedProgress}% ${title.replace(/^\d+%\s/, '')}`;
}
document.title = title;
});
},
init() {
if (!app.ui.settings.getSettingValue("KJNodes.browserStatus")) {
return;
}
const pythongossFeed = app.extensions.find(
(e) => e.name === 'pysssss.FaviconStatus',
)
if (pythongossFeed) {
console.warn("KJNodes - Overriding pysssss.FaviconStatus")
app.extensions = app.extensions.filter(item => item !== pythongossFeed);
}
},
});


@ -1,7 +1,5 @@
import { app } from "../../../scripts/app.js";
var nodeAutoColor = true
// Adds context menu entries, code partly from pysssss-custom-scripts
function addMenuHandler(nodeType, cb) {
@ -45,13 +43,9 @@ app.registerExtension({
content: "Add SetNode",
callback: () => {addNode("SetNode", this, { side:"right", offset: 30 });
},
});
});
}
},
async setup(app) {
const onChange = (value) => {
@ -144,5 +138,15 @@ app.registerExtension({
{ value: false, text: "Off", selected: value === false },
],
});
app.ui.settings.addSetting({
id: "KJNodes.browserStatus",
name: "🦛 KJNodes: 🟢 Stoplight browser status icon 🔴",
defaultValue: false,
type: "boolean",
options: (value) => [
{ value: true, text: "On", selected: value === true },
{ value: false, text: "Off", selected: value === false },
],
});
}
});


@ -45,6 +45,7 @@ loadScript('/kjweb_async/purify.min.js').catch((e) => {
console.log(e)
})
const categories = ["KJNodes", "SUPIR", "VoiceCraft", "Marigold"];
app.registerExtension({
name: "KJNodes.HelpPopup",
async beforeRegisterNodeDef(nodeType, nodeData) {
@ -52,13 +53,12 @@ app.registerExtension({
if (app.ui.settings.getSettingValue("KJNodes.helpPopup") === false) {
return;
}
const categories = ["KJNodes", "SUPIR", "VoiceCraft", "Marigold"];
try {
categories.forEach(category => {
if (nodeData?.category?.startsWith(category)) {
addDocumentation(nodeData, nodeType);
}
else return
});
} catch (error) {
console.error("Error in registering KJNodes.HelpPopup", error);
@ -182,13 +182,16 @@ const create_documentation_stylesheet = () => {
let startX, startY, startWidth, startHeight
resizeHandle.addEventListener('mousedown', function (e) {
e.preventDefault();
e.stopPropagation();
isResizing = true;
startX = e.clientX;
startY = e.clientY;
startWidth = parseInt(document.defaultView.getComputedStyle(docElement).width, 10);
startHeight = parseInt(document.defaultView.getComputedStyle(docElement).height, 10);
});
},
{ signal: this.docCtrl.signal },
);
// close button
const closeButton = document.createElement('div');
@ -208,19 +211,30 @@ const create_documentation_stylesheet = () => {
this.show_doc = !this.show_doc
docElement.parentNode.removeChild(docElement)
docElement = null
});
if (contentWrapper) {
contentWrapper.remove()
contentWrapper = null
}
},
{ signal: this.docCtrl.signal },
);
document.addEventListener('mousemove', function (e) {
if (!isResizing) return;
        const scale = app.canvas.ds.scale;
        const newWidth = startWidth + (e.clientX - startX) / scale;
        const newHeight = startHeight + (e.clientY - startY) / scale;
docElement.style.width = `${newWidth}px`;
docElement.style.height = `${newHeight}px`;
});
},
{ signal: this.docCtrl.signal },
);
document.addEventListener('mouseup', function () {
isResizing = false
})
},
{ signal: this.docCtrl.signal },
)
document.body.appendChild(docElement)
}
@ -238,7 +252,7 @@ const create_documentation_stylesheet = () => {
const transform = new DOMMatrix()
.scaleSelf(scaleX, scaleY)
.multiplySelf(ctx.getTransform())
      .translateSelf(this.size[0] * scaleX * Math.max(1.0,window.devicePixelRatio) , 0)
.translateSelf(10, -32)
const scale = new DOMMatrix()
@ -283,8 +297,29 @@ const create_documentation_stylesheet = () => {
} else {
this.show_doc = !this.show_doc
}
if (this.show_doc) {
this.docCtrl = new AbortController()
} else {
this.docCtrl.abort()
}
return true;
}
return r;
}
const onRem = nodeType.prototype.onRemoved
nodeType.prototype.onRemoved = function () {
const r = onRem ? onRem.apply(this, []) : undefined
if (docElement) {
docElement.remove()
docElement = null
}
if (contentWrapper) {
contentWrapper.remove()
contentWrapper = null
}
return r
}
}


@ -3,10 +3,12 @@ import { app } from "../../../scripts/app.js";
app.registerExtension({
name: "KJNodes.jsnodes",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if(!nodeData?.category?.startsWith("KJNodes")) {
return;
}
switch (nodeData.name) {
case "ConditioningMultiCombine":
nodeType.prototype.onNodeCreated = function () {
//this.inputs_offset = nodeData.name.includes("selective")?1:0
this.cond_type = "CONDITIONING"
this.inputs_offset = nodeData.name.includes("selective")?1:0
this.addWidget("button", "Update inputs", null, () => {
@ -24,9 +26,133 @@ app.registerExtension({
for(let i = this.inputs.length+1-this.inputs_offset; i <= target_number_of_inputs; ++i)
this.addInput(`conditioning_${i}`, this.cond_type)
}
});
});
}
break;
case "ImageBatchMulti":
nodeType.prototype.onNodeCreated = function () {
this._type = "IMAGE"
this.inputs_offset = nodeData.name.includes("selective")?1:0
this.addWidget("button", "Update inputs", null, () => {
if (!this.inputs) {
this.inputs = [];
}
const target_number_of_inputs = this.widgets.find(w => w.name === "inputcount")["value"];
if(target_number_of_inputs===this.inputs.length)return; // already set, do nothing
if(target_number_of_inputs < this.inputs.length){
for(let i = this.inputs.length; i>=this.inputs_offset+target_number_of_inputs; i--)
this.removeInput(i)
}
else{
for(let i = this.inputs.length+1-this.inputs_offset; i <= target_number_of_inputs; ++i)
this.addInput(`image_${i}`, this._type)
}
});
}
break;
case "MaskBatchMulti":
nodeType.prototype.onNodeCreated = function () {
this._type = "MASK"
this.inputs_offset = nodeData.name.includes("selective")?1:0
this.addWidget("button", "Update inputs", null, () => {
if (!this.inputs) {
this.inputs = [];
}
const target_number_of_inputs = this.widgets.find(w => w.name === "inputcount")["value"];
if(target_number_of_inputs===this.inputs.length)return; // already set, do nothing
if(target_number_of_inputs < this.inputs.length){
for(let i = this.inputs.length; i>=this.inputs_offset+target_number_of_inputs; i--)
this.removeInput(i)
}
else{
for(let i = this.inputs.length+1-this.inputs_offset; i <= target_number_of_inputs; ++i)
this.addInput(`mask_${i}`, this._type)
}
});
}
break;
case "GetMaskSizeAndCount":
const onGetMaskSizeConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {
const v = onGetMaskSizeConnectInput?.(this, arguments);
targetSlot.outputs[1]["name"] = "width"
targetSlot.outputs[2]["name"] = "height"
targetSlot.outputs[3]["name"] = "count"
return v;
}
const onGetMaskSizeExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function(message) {
const r = onGetMaskSizeExecuted? onGetMaskSizeExecuted.apply(this,arguments): undefined
let values = message["text"].toString().split('x').map(Number);
this.outputs[1]["name"] = values[1] + " width"
this.outputs[2]["name"] = values[2] + " height"
this.outputs[3]["name"] = values[0] + " count"
return r
}
break;
case "GetImageSizeAndCount":
const onGetImageSizeConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {
const v = onGetImageSizeConnectInput?.(this, arguments);
targetSlot.outputs[1]["name"] = "width"
targetSlot.outputs[2]["name"] = "height"
targetSlot.outputs[3]["name"] = "count"
return v;
}
const onGetImageSizeExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function(message) {
const r = onGetImageSizeExecuted? onGetImageSizeExecuted.apply(this,arguments): undefined
let values = message["text"].toString().split('x').map(Number);
this.outputs[1]["name"] = values[1] + " width"
this.outputs[2]["name"] = values[2] + " height"
this.outputs[3]["name"] = values[0] + " count"
return r
}
break;
case "VRAM_Debug":
const onVRAM_DebugConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {
const v = onVRAM_DebugConnectInput?.(this, arguments);
targetSlot.outputs[3]["name"] = "freemem_before"
targetSlot.outputs[4]["name"] = "freemem_after"
return v;
}
const onVRAM_DebugExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function(message) {
const r = onVRAM_DebugExecuted? onVRAM_DebugExecuted.apply(this,arguments): undefined
let values = message["text"].toString().split('x');
this.outputs[3]["name"] = values[0] + " freemem_before"
this.outputs[4]["name"] = values[1] + " freemem_after"
return r
}
break;
case "JoinStringMulti":
nodeType.prototype.onNodeCreated = function () {
this._type = "STRING"
this.inputs_offset = nodeData.name.includes("selective")?1:0
this.addWidget("button", "Update inputs", null, () => {
if (!this.inputs) {
this.inputs = [];
}
const target_number_of_inputs = this.widgets.find(w => w.name === "inputcount")["value"];
if(target_number_of_inputs===this.inputs.length)return; // already set, do nothing
if(target_number_of_inputs < this.inputs.length){
for(let i = this.inputs.length; i>=this.inputs_offset+target_number_of_inputs; i--)
this.removeInput(i)
}
else{
for(let i = this.inputs.length+1-this.inputs_offset; i <= target_number_of_inputs; ++i)
this.addInput(`string_${i}`, this._type)
}
});
}
break;
case "SoundReactive":
nodeType.prototype.onNodeCreated = function () {
let audioContext;
@ -130,6 +256,21 @@ app.registerExtension({
};
break;
}
}
},
async setup() {
// to keep Set/Get node virtual connections visible when offscreen
const originalComputeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
LGraphCanvas.prototype.computeVisibleNodes = function () {
const visibleNodesSet = new Set(originalComputeVisibleNodes.apply(this, arguments));
for (const node of this.graph._nodes) {
if ((node.type === "SetNode" || node.type === "GetNode") && node.drawConnection) {
visibleNodesSet.add(node);
}
}
return Array.from(visibleNodesSet);
};
}
});


@ -1,30 +0,0 @@
import { app } from "../../../scripts/app.js";
//WIP doesn't do anything
app.registerExtension({
name: "KJNodes.PlotNode",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
switch (nodeData.name) {
case "PlotNode":
nodeType.prototype.onNodeCreated = function () {
this.addWidget("button", "Update", null, () => {
console.log("start x:" + this.pos[0])
console.log("start y:" +this.pos[1])
console.log(this.graph.links);
const toNode = this.graph._nodes.find((otherNode) => otherNode.id == this.graph.links[1].target_id);
console.log("target x:" + toNode.pos[0])
const a = this.pos[0]
const b = toNode.pos[0]
const distance = Math.abs(a - b);
const maxDistance = 1000
const finalDistance = (distance - 0) / (maxDistance - 0);
this.widgets[0].value = finalDistance;
});
}
break;
}
},
});


@ -1,5 +1,5 @@
import { app } from "../../../scripts/app.js";
import { ComfyWidgets } from '../../../scripts/widgets.js';
//based on diffus3's SetGet: https://github.com/diffus3/ComfyUI-extensions
// Nodes that allow you to tunnel connections for cleaner graphs
@ -21,8 +21,6 @@ function setColorAndBgColor(type) {
if (colors) {
this.color = colors.color;
this.bgcolor = colors.bgcolor;
} else {
// Handle the default case if needed
}
}
let isAlertShown = false;
@ -41,6 +39,12 @@ app.registerExtension({
class SetNode {
defaultVisibility = true;
serialize_widgets = true;
drawConnection = false;
currentGetters = null;
slotColor = "#FFF";
canvas = app.canvas;
menuEntry = "Show connections";
constructor() {
if (!this.properties) {
this.properties = {
@ -201,9 +205,11 @@ app.registerExtension({
return graph._nodes.filter(otherNode => otherNode.type === 'GetNode' && otherNode.widgets[0].value === name && name !== '');
}
// This node is purely frontend and does not impact the resulting prompt so should not be serialized
this.isVirtualNode = true;
}
onRemoved() {
const allGetters = this.graph._nodes.filter((otherNode) => otherNode.type == "GetNode");
@ -213,6 +219,136 @@ app.registerExtension({
}
})
}
getExtraMenuOptions(_, options) {
this.menuEntry = this.drawConnection ? "Hide connections" : "Show connections";
options.unshift(
{
content: this.menuEntry,
callback: () => {
this.currentGetters = this.findGetters(this.graph);
if (this.currentGetters.length == 0) return;
let linkType = (this.currentGetters[0].outputs[0].type);
this.slotColor = this.canvas.default_connection_color_byType[linkType]
this.menuEntry = this.drawConnection ? "Hide connections" : "Show connections";
this.drawConnection = !this.drawConnection;
this.canvas.setDirty(true, true);
},
has_submenu: true,
submenu: {
title: "Color",
options: [
{
content: "Highlight",
callback: () => {
this.slotColor = "orange"
this.canvas.setDirty(true, true);
}
}
],
},
},
{
content: "Hide all connections",
callback: () => {
const allGetters = this.graph._nodes.filter(otherNode => otherNode.type === "GetNode" || otherNode.type === "SetNode");
allGetters.forEach(otherNode => {
otherNode.drawConnection = false;
console.log(otherNode);
});
this.menuEntry = "Show connections";
this.drawConnection = false
this.canvas.setDirty(true, true);
},
},
);
// Dynamically add a submenu for all getters
this.currentGetters = this.findGetters(this.graph);
if (this.currentGetters) {
let gettersSubmenu = this.currentGetters.map(getter => ({
content: `${getter.title} id: ${getter.id}`,
callback: () => {
this.canvas.centerOnNode(getter);
this.canvas.selectNode(getter, false);
this.canvas.setDirty(true, true);
},
}));
options.unshift({
content: "Getters",
has_submenu: true,
submenu: {
title: "GetNodes",
options: gettersSubmenu,
}
});
}
}
onDrawForeground(ctx, lGraphCanvas) {
if (this.drawConnection) {
this._drawVirtualLinks(lGraphCanvas, ctx);
}
}
// onDrawCollapsed(ctx, lGraphCanvas) {
// if (this.drawConnection) {
// this._drawVirtualLinks(lGraphCanvas, ctx);
// }
// }
_drawVirtualLinks(lGraphCanvas, ctx) {
if (!this.currentGetters?.length) return;
var title = this.getTitle ? this.getTitle() : this.title;
var title_width = ctx.measureText(title).width;
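            // Start the virtual link at the node's output edge; when collapsed, offset past the title text instead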
if (!this.flags.collapsed) {
var start_node_slotpos = [
this.size[0],
LiteGraph.NODE_TITLE_HEIGHT * 0.5,
];
}
else {
var start_node_slotpos = [
title_width + 55,
-15,
];
}
for (const getter of this.currentGetters) {
if (!this.flags.collapsed) {
var end_node_slotpos = this.getConnectionPos(false, 0);
end_node_slotpos = [
getter.pos[0] - end_node_slotpos[0] + this.size[0],
getter.pos[1] - end_node_slotpos[1]
];
}
else {
var end_node_slotpos = this.getConnectionPos(false, 0);
end_node_slotpos = [
getter.pos[0] - end_node_slotpos[0] + title_width + 50,
getter.pos[1] - end_node_slotpos[1] - 30
];
}
lGraphCanvas.renderLink(
ctx,
start_node_slotpos,
end_node_slotpos,
null,
false,
null,
this.slotColor,
LiteGraph.RIGHT,
LiteGraph.LEFT
);
}
}
}
LiteGraph.registerNodeType(
@ -233,13 +369,16 @@ app.registerExtension({
defaultVisibility = true;
serialize_widgets = true;
drawConnection = false;
slotColor = "#FFF";
currentSetter = null;
canvas = app.canvas;
constructor() {
if (!this.properties) {
this.properties = {};
}
this.properties.showOutputText = GetNode.defaultVisibility;
const node = this;
this.addWidget(
"combo",
@ -266,7 +405,7 @@ app.registerExtension({
) {
this.validateLinks();
}
this.setName = function(name) {
node.widgets[0].value = name;
node.onRename();
@ -315,13 +454,20 @@ app.registerExtension({
this.findSetter = function(graph) {
const name = this.widgets[0].value;
const foundNode = graph._nodes.find(otherNode => otherNode.type === 'SetNode' && otherNode.widgets[0].value === name && name !== '');
return foundNode;
};
this.goToSetter = function() {
const setter = this.findSetter(this.graph);
this.canvas.centerOnNode(setter);
this.canvas.selectNode(setter, false);
};
// This node is purely frontend and does not impact the resulting prompt so should not be serialized
this.isVirtualNode = true;
}
getInputLink(slot) {
const setter = this.findSetter(this.graph);
@ -337,6 +483,60 @@ app.registerExtension({
}
onAdded(graph) {
}
getExtraMenuOptions(_, options) {
let menuEntry = this.drawConnection ? "Hide connections" : "Show connections";
options.unshift(
{
content: "Go to setter",
callback: () => {
this.goToSetter();
},
},
{
content: menuEntry,
callback: () => {
this.currentSetter = this.findSetter(this.graph);
                        if (!this.currentSetter) return;
let linkType = (this.currentSetter.inputs[0].type);
this.drawConnection = !this.drawConnection;
this.slotColor = this.canvas.default_connection_color_byType[linkType]
menuEntry = this.drawConnection ? "Hide connections" : "Show connections";
this.canvas.setDirty(true, true);
},
},
);
}
onDrawForeground(ctx, lGraphCanvas) {
if (this.drawConnection) {
this._drawVirtualLink(lGraphCanvas, ctx);
}
}
// onDrawCollapsed(ctx, lGraphCanvas) {
// if (this.drawConnection) {
// this._drawVirtualLink(lGraphCanvas, ctx);
// }
// }
_drawVirtualLink(lGraphCanvas, ctx) {
if (!this.currentSetter) return;
let start_node_slotpos = this.currentSetter.getConnectionPos(false, 0);
start_node_slotpos = [
start_node_slotpos[0] - this.pos[0],
start_node_slotpos[1] - this.pos[1],
];
let end_node_slotpos = [0, -LiteGraph.NODE_TITLE_HEIGHT * 0.5];
lGraphCanvas.renderLink(
ctx,
start_node_slotpos,
end_node_slotpos,
null,
false,
null,
this.slotColor
);
}
}
LiteGraph.registerNodeType(


@ -101,8 +101,9 @@ app.registerExtension({
name: 'KJNodes.SplineEditor',
async beforeRegisterNodeDef(nodeType, nodeData) {
    if (nodeData?.name === 'SplineEditor') {
chainCallback(nodeType.prototype, "onNodeCreated", function () {
hideWidgetForGood(this, this.widgets.find(w => w.name === "coordinates"))
var element = document.createElement("div");
@ -113,8 +114,64 @@ app.registerExtension({
serialize: false,
hideOnZoom: false,
});
// context menu
this.contextMenu = document.createElement("div");
this.contextMenu.id = "context-menu";
this.contextMenu.style.display = "none";
this.contextMenu.style.position = "absolute";
this.contextMenu.style.backgroundColor = "#202020";
this.contextMenu.style.minWidth = "100px";
this.contextMenu.style.boxShadow = "0px 8px 16px 0px rgba(0,0,0,0.2)";
this.contextMenu.style.zIndex = "100";
this.contextMenu.style.padding = "5px";
function styleMenuItem(menuItem) {
menuItem.style.display = "block";
menuItem.style.padding = "5px";
menuItem.style.color = "#FFF";
menuItem.style.fontFamily = "Arial, sans-serif";
menuItem.style.fontSize = "16px";
menuItem.style.textDecoration = "none";
menuItem.style.marginBottom = "5px";
}
this.menuItem1 = document.createElement("a");
this.menuItem1.href = "#";
this.menuItem1.id = "menu-item-1";
this.menuItem1.textContent = "Toggle handles";
styleMenuItem(this.menuItem1);
this.menuItem2 = document.createElement("a");
this.menuItem2.href = "#";
this.menuItem2.id = "menu-item-2";
this.menuItem2.textContent = "Display sample points";
styleMenuItem(this.menuItem2);
this.menuItem3 = document.createElement("a");
this.menuItem3.href = "#";
this.menuItem3.id = "menu-item-2";
this.menuItem3.textContent = "Switch point shape";
styleMenuItem(this.menuItem3);
const menuItems = [this.menuItem1, this.menuItem2, this.menuItem3];
menuItems.forEach(menuItem => {
menuItem.addEventListener('mouseover', function() {
this.style.backgroundColor = "gray";
});
menuItem.addEventListener('mouseout', function() {
this.style.backgroundColor = "#202020";
});
});
// Append menu items to the context menu
menuItems.forEach(menuItem => {
this.contextMenu.appendChild(menuItem);
});
document.body.appendChild( this.contextMenu);
this.addWidget("button", "New spline", null, () => {
if (!this.properties || !("points" in this.properties)) {
createSplineEditor(this)
this.addProperty("points", this.constructor.type, "string");
@ -123,22 +180,18 @@ app.registerExtension({
createSplineEditor(this, true)
}
});
      this.setSize([550, 920]);
this.resizable = false;
this.splineEditor.parentEl = document.createElement("div");
this.splineEditor.parentEl.className = "spline-editor";
this.splineEditor.parentEl.id = `spline-editor-${this.uuid}`
element.appendChild(this.splineEditor.parentEl);
//disable context menu on right click
document.addEventListener('contextmenu', function(e) {
if (e.button === 2) { // Right mouse button
e.preventDefault();
e.stopPropagation();
}
})
chainCallback(this, "onGraphConfigured", function() {
        createSplineEditor(this);
});
}); // onAfterGraphConfigured
}//node created
} //before register
@ -147,22 +200,173 @@ app.registerExtension({
function createSplineEditor(context, reset=false) {
console.log("creatingSplineEditor")
document.addEventListener('contextmenu', function(e) {
e.preventDefault();
});
document.addEventListener('click', function(e) {
if (!context.contextMenu.contains(e.target)) {
context.contextMenu.style.display = 'none';
}
});
context.menuItem1.addEventListener('click', function(e) {
e.preventDefault();
if (!drawHandles) {
drawHandles = true
vis.add(pv.Line)
.data(() => points.map((point, index) => ({
start: point,
end: [index]
})))
.left(d => d.start.x)
.top(d => d.start.y)
.interpolate("linear")
.tension(0) // Straight lines
.strokeStyle("#ff7f0e") // Same color as control points
.lineWidth(1)
.visible(() => drawHandles);
vis.render();
} else {
drawHandles = false
vis.render();
}
context.contextMenu.style.display = 'none';
});
context.menuItem2.addEventListener('click', function(e) {
e.preventDefault();
drawSamplePoints = !drawSamplePoints;
updatePath();
});
context.menuItem3.addEventListener('click', function(e) {
e.preventDefault();
if (dotShape == "circle"){
dotShape = "triangle"
}
else {
dotShape = "circle"
}
console.log(dotShape)
updatePath();
});
var dotShape = "circle";
var drawSamplePoints = false;
function updatePath() {
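    // Resample the spline, optionally draw the sample points, and write both the control points
    // (points_store) and the sampled coordinates (coordinates widget) back to the node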
let coords = samplePoints(pathElements[0], points_to_sample, samplingMethod, w);
if (drawSamplePoints) {
if (pointsLayer) {
// Update the data of the existing points layer
pointsLayer.data(coords);
} else {
// Create the points layer if it doesn't exist
pointsLayer = vis.add(pv.Dot)
.data(coords)
.left(function(d) { return d.x; })
.top(function(d) { return d.y; })
.radius(5) // Adjust the radius as needed
.fillStyle("red") // Change the color as needed
.strokeStyle("black") // Change the stroke color as needed
.lineWidth(1); // Adjust the line width as needed
}
} else {
if (pointsLayer) {
// Remove the points layer
pointsLayer.data([]);
vis.render();
}
}
let coordsString = JSON.stringify(coords);
pointsStoreWidget.value = JSON.stringify(points);
if (coordWidget) {
coordWidget.value = coordsString;
}
vis.render();
}
if (reset && context.splineEditor.element) {
context.splineEditor.element.innerHTML = ''; // Clear the container
}
}
const coordWidget = context.widgets.find(w => w.name === "coordinates");
const interpolationWidget = context.widgets.find(w => w.name === "interpolation");
const pointsWidget = context.widgets.find(w => w.name === "points_to_sample");
const pointsStoreWidget = context.widgets.find(w => w.name === "points_store");
const tensionWidget = context.widgets.find(w => w.name === "tension");
const segmentedWidget = context.widgets.find(w => w.name === "segmented");
const minValueWidget = context.widgets.find(w => w.name === "min_value");
const maxValueWidget = context.widgets.find(w => w.name === "max_value");
const samplingMethodWidget = context.widgets.find(w => w.name === "sampling_method");
const widthWidget = context.widgets.find(w => w.name === "mask_width");
const heightWidget = context.widgets.find(w => w.name === "mask_height");
//const segmentedWidget = context.widgets.find(w => w.name === "segmented");
var interpolation = interpolationWidget.value
var tension = tensionWidget.value
var points_to_sample = pointsWidget.value
var rangeMin = minValueWidget.value
var rangeMax = maxValueWidget.value
var pointsLayer = null;
var samplingMethod = samplingMethodWidget.value
if (samplingMethod == "path") {
dotShape = "triangle"
}
interpolationWidget.callback = () => {
interpolation = interpolationWidget.value
updatePath();
}
samplingMethodWidget.callback = () => {
samplingMethod = samplingMethodWidget.value
if (samplingMethod == "path") {
dotShape = "triangle"
}
updatePath();
}
tensionWidget.callback = () => {
tension = tensionWidget.value
updatePath();
}
pointsWidget.callback = () => {
points_to_sample = pointsWidget.value
updatePath();
}
minValueWidget.callback = () => {
rangeMin = minValueWidget.value
updatePath();
}
maxValueWidget.callback = () => {
rangeMax = maxValueWidget.value
updatePath();
}
widthWidget.callback = () => {
w = widthWidget.value
vis.width(w)
context.setSize([w + 45, context.size[1]]);
updatePath();
}
heightWidget.callback = () => {
h = heightWidget.value
vis.height(h)
context.setSize([context.size[0], h + 410]);
updatePath();
}
// Initialize or reset points array
var drawHandles = false;
var hoverIndex = -1;
var isDragging = false;
var w = widthWidget.value;
var h = heightWidget.value;
var i = 3;
let points = [];
if (!reset && pointsStoreWidget.value != "") {
points = JSON.parse(pointsStoreWidget.value);
} else {
@ -187,96 +391,246 @@ function createSplineEditor(context, reset=false) {
var vis = new pv.Panel()
.width(w)
.height(h)
.fillStyle("var(--comfy-menu-bg)")
.fillStyle("#222")
.strokeStyle("gray")
.lineWidth(2)
.antialias(false)
.margin(10)
.event("mousedown", function() {
if (pv.event.shiftKey) { // Use pv.event to access the event object
let scaledMouse = {
x: this.mouse().x / app.canvas.ds.scale,
y: this.mouse().y / app.canvas.ds.scale
};
i = points.push(scaledMouse) - 1;
updatePath();
return this;
}
else if (pv.event.ctrlKey) {
// Capture the clicked location
let clickedPoint = {
x: this.mouse().x / app.canvas.ds.scale,
y: this.mouse().y / app.canvas.ds.scale
};
// Find the two closest points to the clicked location
let { point1Index, point2Index } = findClosestPoints(points, clickedPoint);
// Calculate the midpoint between the two closest points
let midpoint = {
x: (points[point1Index].x + points[point2Index].x) / 2,
y: (points[point1Index].y + points[point2Index].y) / 2
};
// Insert the midpoint into the array
points.splice(point2Index, 0, midpoint);
i = point2Index;
updatePath();
}
else if (pv.event.button === 2) {
context.contextMenu.style.display = 'block';
context.contextMenu.style.left = `${pv.event.clientX}px`;
context.contextMenu.style.top = `${pv.event.clientY}px`;
}
})
vis.add(pv.Rule)
      .data(pv.range(0, h, 64))
      .bottom(d => d)
      .strokeStyle("gray")
      .lineWidth(3)
// vis.add(pv.Rule)
// .data(pv.range(0, points_to_sample, 1))
// .left(d => d * 512 / (points_to_sample - 1))
// .strokeStyle("gray")
// .lineWidth(2)
vis.add(pv.Line)
.data(() => points)
.left(d => d.x)
.top(d => d.y)
.interpolate(() => interpolation)
.tension(() => tension)
.segmented(() => false)
.strokeStyle(pv.Colors.category10().by(pv.index))
.lineWidth(3)
vis.add(pv.Dot)
.data(() => points)
.left(d => d.x)
.top(d => d.y)
.radius(10)
.shape(function() {
return dotShape;
})
.angle(function() {
const index = this.index;
let angle = 0;
if (dotShape === "triangle") {
let dxNext = 0, dyNext = 0;
if (index < points.length - 1) {
dxNext = points[index + 1].x - points[index].x;
dyNext = points[index + 1].y - points[index].y;
}
let dxPrev = 0, dyPrev = 0;
if (index > 0) {
dxPrev = points[index].x - points[index - 1].x;
dyPrev = points[index].y - points[index - 1].y;
}
const dx = (dxNext + dxPrev) / 2;
const dy = (dyNext + dyPrev) / 2;
angle = Math.atan2(dy, dx);
angle -= Math.PI / 2;
angle = (angle + 2 * Math.PI) % (2 * Math.PI);
}
return angle;
})
.cursor("move")
.strokeStyle(function() { return i == this.index ? "#ff7f0e" : "#1f77b4"; })
.fillStyle(function() { return "rgba(100, 100, 100, 0.2)"; })
.fillStyle(function() { return "rgba(100, 100, 100, 0.3)"; })
.event("mousedown", pv.Behavior.drag())
.event("dragstart", function() {
i = this.index;
hoverIndex = this.index;
isDragging = true;
if (pv.event.button === 2 && i !== 0 && i !== points.length - 1) {
points.splice(i--, 1);
vis.render();
}
return this;
})
.event("drag", vis)
.anchor("top").add(pv.Label)
.font(d => Math.sqrt(d[2]) * 32 + "px sans-serif")
//.text(d => `(${Math.round(d.x)}, ${Math.round(d.y)})`)
.text(d => {
// Normalize y to range 0.0 to 1.0, considering the inverted y-axis
var normalizedY = 1.0 - (d.y / h);
return `${normalizedY.toFixed(2)}`;
.event("dragend", function() {
if (this.pathElements !== null) {
updatePath();
}
isDragging = false;
})
.event("drag", function() {
let adjustedX = this.mouse().x / app.canvas.ds.scale; // Adjust the new X position by the inverse of the scale factor
let adjustedY = this.mouse().y / app.canvas.ds.scale; // Adjust the new Y position by the inverse of the scale factor
// Determine the bounds of the vis.Panel
const panelWidth = vis.width();
const panelHeight = vis.height();
// Adjust the new position if it would place the dot outside the bounds of the vis.Panel
adjustedX = Math.max(0, Math.min(panelWidth, adjustedX));
adjustedY = Math.max(0, Math.min(panelHeight, adjustedY));
points[this.index] = { x: adjustedX, y: adjustedY }; // Update the point's position
vis.render(); // Re-render the visualization to reflect the new position
})
.event("mouseover", function() {
hoverIndex = this.index; // Set the hover index to the index of the hovered dot
vis.render(); // Re-render the visualization
})
.event("mouseout", function() {
!isDragging && (hoverIndex = -1); // Reset the hover index when the mouse leaves the dot
vis.render(); // Re-render the visualization
})
.anchor("center")
.add(pv.Label)
.visible(function() {
return hoverIndex === this.index; // Only show the label for the hovered dot
})
.left(d => d.x < w / 2 ? d.x + 80 : d.x - 70) // Shift label to right if on left half, otherwise shift to left
.top(d => d.y < h / 2 ? d.y + 20 : d.y - 20) // Shift label down if on top half, otherwise shift up
.font(12 + "px sans-serif")
.text(d => {
if (samplingMethod == "path") {
return `X: ${Math.round(d.x)}, Y: ${Math.round(d.y)}`;
} else {
let frame = Math.round((d.x / w) * points_to_sample);
let normalizedY = (1.0 - (d.y / h) - 0.0) * (rangeMax - rangeMin) + rangeMin;
let normalizedX = (d.x / w);
return `F: ${frame}, X: ${normalizedX.toFixed(2)}, Y: ${normalizedY.toFixed(2)}`;
}
})
.textStyle("orange")
vis.render();
var svgElement = vis.canvas();
svgElement.style['zIndex'] = "2"
svgElement.style['position'] = "relative"
context.splineEditor.element.appendChild(svgElement);
var pathElements = svgElement.getElementsByTagName('path'); // Get all path elements
updatePath();
}
function samplePoints(svgPathElement, numSamples, samplingMethod, width) {
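    // "time" sampling picks evenly spaced x positions across the editor width,
    // "path" sampling picks points at evenly spaced distances along the curve itself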
var svgWidth = width; // Fixed width of the SVG element
var pathLength = svgPathElement.getTotalLength();
var points = [];
for (var i = 0; i < numSamples; i++) {
if (samplingMethod === "time") {
// Calculate the x-coordinate for the current sample based on the SVG's width
var x = (svgWidth / (numSamples - 1)) * i;
// Find the point on the path that intersects the vertical line at the calculated x-coordinate
var point = findPointAtX(svgPathElement, x, pathLength);
}
else if (samplingMethod === "path") {
// Calculate the distance along the path for the current sample
var distance = (pathLength / (numSamples - 1)) * i;
// Get the point at the current distance
var point = svgPathElement.getPointAtLength(distance);
}
// Add the point to the array of points
points.push({ x: point.x, y: point.y });
}
return points;
}
function findClosestPoints(points, clickedPoint) {
// Calculate distances from clickedPoint to each point in the array
let distances = points.map(point => {
let dx = clickedPoint.x - point.x;
let dy = clickedPoint.y - point.y;
return { index: points.indexOf(point), distance: Math.sqrt(dx * dx + dy * dy) };
});
// Sort distances and get the indices of the two closest points
let sortedDistances = distances.sort((a, b) => a.distance - b.distance);
let closestPoint1Index = sortedDistances[0].index;
let closestPoint2Index = sortedDistances[1].index;
// Ensure point1Index is always the smaller index
if (closestPoint1Index > closestPoint2Index) {
[closestPoint1Index, closestPoint2Index] = [closestPoint2Index, closestPoint1Index];
}
return { point1Index: closestPoint1Index, point2Index: closestPoint2Index };
}
function findPointAtX(svgPathElement, targetX, pathLength) {
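  // Binary search along the path length for the point whose x coordinate is closest to targetX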
let low = 0;
let high = pathLength;
let bestPoint = svgPathElement.getPointAtLength(0);
while (low <= high) {
let mid = low + (high - low) / 2;
let point = svgPathElement.getPointAtLength(mid);
if (Math.abs(point.x - targetX) < 1) {
return point; // The point is close enough to the target
}
if (point.x < targetX) {
low = mid + 1;
} else {
high = mid - 1;
}
// Keep track of the closest point found so far
if (Math.abs(point.x - targetX) < Math.abs(bestPoint.x - targetX)) {
bestPoint = point;
}
}
// Return the closest point found
return bestPoint;
}
//from melmass