PointsEditor updates

This commit is contained in:
kijai 2024-08-04 20:08:13 +03:00
parent 2e7129fdb9
commit c3f6dcd850
2 changed files with 118 additions and 47 deletions

View File

@ -5,6 +5,7 @@ from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter
import numpy as np import numpy as np
from ..utility.utility import pil2tensor from ..utility.utility import pil2tensor
import folder_paths import folder_paths
from comfy.utils import common_upscale
def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt): def plot_coordinates_to_tensor(coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
import matplotlib import matplotlib
@ -1265,16 +1266,15 @@ class PointsEditor:
), ),
"width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}), "width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
"height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}), "height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
"normalize": ("BOOLEAN", {"default": True}), "normalize": ("BOOLEAN", {"default": False}),
}, },
"optional": { "optional": {
"bg_image": ("IMAGE", ), "bg_image": ("IMAGE", ),
}, },
} }
RETURN_TYPES = ("STRING", "STRING", "BBOX", "MASK") RETURN_TYPES = ("STRING", "STRING", "BBOX", "MASK", "IMAGE")
RETURN_NAMES = ("positive_coords", "negative_coords", "bbox", "bbox_mask") RETURN_NAMES = ("positive_coords", "negative_coords", "bbox", "bbox_mask", "cropped_image")
FUNCTION = "pointdata" FUNCTION = "pointdata"
CATEGORY = "KJNodes/weights" CATEGORY = "KJNodes/weights"
DESCRIPTION = """ DESCRIPTION = """
@ -1332,28 +1332,48 @@ you can clear the image from the context menu by right clicking on the canvas
mask = np.zeros((height, width), dtype=np.uint8) mask = np.zeros((height, width), dtype=np.uint8)
bboxes = json.loads(bboxes) bboxes = json.loads(bboxes)
print(bboxes) print(bboxes)
if bboxes["x"] is None or bboxes["y"] is None or bboxes["width"] is None or bboxes["height"] is None: valid_bboxes = []
bboxes = [] for bbox in bboxes:
else: if (bbox.get("startX") is None or
bboxes = [(int(bboxes["x"]), int(bboxes["y"]), int(bboxes["width"]), int(bboxes["height"]))] bbox.get("startY") is None or
bbox.get("endX") is None or
bbox.get("endY") is None):
continue # Skip this bounding box if any value is None
else:
# Ensure that endX and endY are greater than startX and startY
x_min = min(int(bbox["startX"]), int(bbox["endX"]))
y_min = min(int(bbox["startY"]), int(bbox["endY"]))
x_max = max(int(bbox["startX"]), int(bbox["endX"]))
y_max = max(int(bbox["startY"]), int(bbox["endY"]))
valid_bboxes.append((x_min, y_min, x_max, y_max))
bboxes_xyxy = [] bboxes_xyxy = []
# Draw the bounding box on the mask for bbox in valid_bboxes:
for bbox in bboxes: x_min, y_min, x_max, y_max = bbox
x_min, y_min, w, h = bbox
x_max = x_min + w
y_max = y_min + h
bboxes_xyxy.append((x_min, y_min, x_max, y_max)) bboxes_xyxy.append((x_min, y_min, x_max, y_max))
mask[y_min:y_max, x_min:x_max] = 1 # Fill the bounding box area with 1s mask[y_min:y_max, x_min:x_max] = 1 # Fill the bounding box area with 1s
if bbox_format == "xyxy": if bbox_format == "xywh":
bboxes = bboxes_xyxy bboxes_xywh = []
for bbox in valid_bboxes:
x_min, y_min, x_max, y_max = bbox
width = x_max - x_min
height = y_max - y_min
bboxes_xywh.append((x_min, y_min, width, height))
bboxes = bboxes_xywh
else:
bboxes = bboxes_xyxy
mask_tensor = torch.from_numpy(mask) mask_tensor = torch.from_numpy(mask)
mask_tensor = mask_tensor.unsqueeze(0).float().cpu() mask_tensor = mask_tensor.unsqueeze(0).float().cpu()
#mask_tensor = mask_tensor[:,:,0]
print(mask_tensor.shape) if bg_image is not None and len(valid_bboxes) > 0:
x_min, y_min, x_max, y_max = bboxes[0]
cropped_image = bg_image[:, y_min:y_max, x_min:x_max, :]
elif bg_image is not None:
cropped_image = bg_image
if bg_image is None: if bg_image is None:
return (json.dumps(pos_coordinates), json.dumps(neg_coordinates), bboxes, mask_tensor) return (json.dumps(pos_coordinates), json.dumps(neg_coordinates), bboxes, mask_tensor)
@ -1369,5 +1389,5 @@ you can clear the image from the context menu by right clicking on the canvas
return { return {
"ui": {"bg_image": [img_base64]}, "ui": {"bg_image": [img_base64]},
"result": (json.dumps(coordinates), json.dumps(neg_coordinates), json.dumps(bboxes), bboxes, mask_tensor) "result": (json.dumps(pos_coordinates), json.dumps(neg_coordinates), bboxes, mask_tensor, cropped_image)
} }

View File

@ -214,6 +214,7 @@ class PointsEditor {
constructor(context, reset = false) { constructor(context, reset = false) {
this.node = context; this.node = context;
this.reset = reset; this.reset = reset;
const self = this; // Keep a reference to the main class context
console.log("creatingPointEditor") console.log("creatingPointEditor")
@ -248,7 +249,6 @@ class PointsEditor {
// context menu // context menu
this.createContextMenu(); this.createContextMenu();
if (reset && context.pointsEditor.element) { if (reset && context.pointsEditor.element) {
context.pointsEditor.element.innerHTML = ''; // Clear the container context.pointsEditor.element.innerHTML = ''; // Clear the container
} }
@ -280,13 +280,17 @@ class PointsEditor {
this.neg_points = JSON.parse(pointsStoreWidget.value).negative; this.neg_points = JSON.parse(pointsStoreWidget.value).negative;
this.updateData(); this.updateData();
} }
this.bboxStoreWidget.callback = () => {
this.bbox = JSON.parse(bboxStoreWidget.value)
this.updateData();
}
var w = this.widthWidget.value; var w = this.widthWidget.value;
var h = this.heightWidget.value; var h = this.heightWidget.value;
var i = 3; var i = 3;
this.points = []; this.points = [];
this.neg_points = []; this.neg_points = [];
this.bbox = []; this.bbox = [{}];
var drawing = false; var drawing = false;
// Initialize or reset points array // Initialize or reset points array
@ -294,6 +298,7 @@ class PointsEditor {
this.points = JSON.parse(this.pointsStoreWidget.value).positive; this.points = JSON.parse(this.pointsStoreWidget.value).positive;
this.neg_points = JSON.parse(this.pointsStoreWidget.value).negative; this.neg_points = JSON.parse(this.pointsStoreWidget.value).negative;
this.bbox = JSON.parse(this.bboxStoreWidget.value); this.bbox = JSON.parse(this.bboxStoreWidget.value);
console.log(this.bbox)
} else { } else {
this.points = [ this.points = [
{ {
@ -314,7 +319,8 @@ class PointsEditor {
this.pointsStoreWidget.value = JSON.stringify(combinedPoints); this.pointsStoreWidget.value = JSON.stringify(combinedPoints);
this.bboxStoreWidget.value = JSON.stringify(this.bbox); this.bboxStoreWidget.value = JSON.stringify(this.bbox);
} }
const self = this; // Keep a reference to the main class context
//create main canvas panel
this.vis = new pv.Panel() this.vis = new pv.Panel()
.width(w) .width(w)
.height(h) .height(h)
@ -323,8 +329,7 @@ class PointsEditor {
.lineWidth(2) .lineWidth(2)
.antialias(false) .antialias(false)
.margin(10) .margin(10)
.event("mousedown", function () { .event("mousedown", function () {
if (pv.event.shiftKey && pv.event.button === 2) { // Use pv.event to access the event object if (pv.event.shiftKey && pv.event.button === 2) { // Use pv.event to access the event object
let scaledMouse = { let scaledMouse = {
x: this.mouse().x / app.canvas.ds.scale, x: this.mouse().x / app.canvas.ds.scale,
@ -334,7 +339,7 @@ class PointsEditor {
self.updateData(); self.updateData();
return this; return this;
} }
else if (pv.event.shiftKey) { // Use pv.event to access the event object else if (pv.event.shiftKey) {
let scaledMouse = { let scaledMouse = {
x: this.mouse().x / app.canvas.ds.scale, x: this.mouse().x / app.canvas.ds.scale,
y: this.mouse().y / app.canvas.ds.scale y: this.mouse().y / app.canvas.ds.scale
@ -346,8 +351,8 @@ class PointsEditor {
else if (pv.event.ctrlKey) { else if (pv.event.ctrlKey) {
console.log("start drawing at " + this.mouse().x / app.canvas.ds.scale + ", " + this.mouse().y / app.canvas.ds.scale); console.log("start drawing at " + this.mouse().x / app.canvas.ds.scale + ", " + this.mouse().y / app.canvas.ds.scale);
drawing = true; drawing = true;
self.box_startX = this.mouse().x / app.canvas.ds.scale; self.bbox[0].startX = this.mouse().x / app.canvas.ds.scale;
self.box_startY = this.mouse().y / app.canvas.ds.scale; self.bbox[0].startY = this.mouse().y / app.canvas.ds.scale;
} }
else if (pv.event.button === 2) { else if (pv.event.button === 2) {
self.node.contextMenu.style.display = 'block'; self.node.contextMenu.style.display = 'block';
@ -357,8 +362,8 @@ class PointsEditor {
}) })
.event("mousemove", function () { .event("mousemove", function () {
if (drawing) { if (drawing) {
self.box_endX = this.mouse().x / app.canvas.ds.scale; self.bbox[0].endX = this.mouse().x / app.canvas.ds.scale;
self.box_endY = this.mouse().y / app.canvas.ds.scale; self.bbox[0].endY = this.mouse().y / app.canvas.ds.scale;
self.vis.render(); self.vis.render();
} }
}) })
@ -369,15 +374,66 @@ class PointsEditor {
}); });
this.backgroundImage = this.vis.add(pv.Image).visible(false) this.backgroundImage = this.vis.add(pv.Image).visible(false)
//create bounding box
this.vis.add(pv.Area) this.vis.add(pv.Area)
.data(function () {return drawing || self.bbox ? [self.box_startX, self.box_endX] : []; }) .data(function () {
.bottom(function () {return h - Math.max(self.box_startY, self.box_endY); }) if (drawing || (self.bbox && self.bbox[0] && Object.keys(self.bbox[0]).length > 0)) {
return [self.bbox[0].startX, self.bbox[0].endX];
} else {
return [];
}
})
.bottom(function () {return h - Math.max(self.bbox[0].startY, self.bbox[0].endY); })
.left(function (d) {return d; }) .left(function (d) {return d; })
.height(function () {return Math.abs(self.box_startY - self.box_endY);}) .height(function () {return Math.abs(self.bbox[0].startY - self.bbox[0].endY);})
.fillStyle("rgba(70, 130, 180, 0.5)") .fillStyle("rgba(70, 130, 180, 0.5)")
.strokeStyle("steelblue"); .strokeStyle("steelblue")
.visible(function () {return drawing || Object.keys(self.bbox[0]).length > 0; })
.add(pv.Dot)
.visible(function () {return drawing || Object.keys(self.bbox[0]).length > 0; })
.data(() => {
if (self.bbox && Object.keys(self.bbox[0]).length > 0) {
return [{
x: self.bbox[0].endX,
y: self.bbox[0].endY
}];
} else {
return [];
}
})
.left(d => d.x)
.top(d => d.y)
.radius(Math.log(Math.min(w, h)) * 1)
.shape("square")
.cursor("move")
.strokeStyle("steelblue")
.lineWidth(2)
.fillStyle(function () { return "rgba(100, 100, 100, 0.6)"; })
.event("mousedown", pv.Behavior.drag())
.event("dragstart", function () {
i = this.index;
})
.event("mousedown", pv.Behavior.drag())
.event("dragstart", function () {
i = this.index;
})
.event("drag", function () {
let adjustedX = this.mouse().x / app.canvas.ds.scale; // Adjust the new position by the inverse of the scale factor
let adjustedY = this.mouse().y / app.canvas.ds.scale;
// Adjust the new position if it would place the dot outside the bounds of the vis.Panel
adjustedX = Math.max(0, Math.min(self.vis.width(), adjustedX));
adjustedY = Math.max(0, Math.min(self.vis.height(), adjustedY));
self.bbox[0].endX = this.mouse().x / app.canvas.ds.scale;
self.bbox[0].endY = this.mouse().y / app.canvas.ds.scale;
self.vis.render();
})
.event("dragend", function () {
self.updateData();
});
//create positive points
this.vis.add(pv.Dot) this.vis.add(pv.Dot)
.data(() => this.points) .data(() => this.points)
.left(d => d.x) .left(d => d.x)
@ -431,6 +487,7 @@ class PointsEditor {
.fillStyle("red") // Color for the center point .fillStyle("red") // Color for the center point
.lineWidth(1); // Stroke thickness for the center point .lineWidth(1); // Stroke thickness for the center point
//create negative points
this.vis.add(pv.Dot) this.vis.add(pv.Dot)
.data(() => this.neg_points) .data(() => this.neg_points)
.left(d => d.x) .left(d => d.x)
@ -466,7 +523,6 @@ class PointsEditor {
self.neg_points[this.index] = { x: adjustedX, y: adjustedY }; // Update the point's position self.neg_points[this.index] = { x: adjustedX, y: adjustedY }; // Update the point's position
self.vis.render(); // Re-render the visualization to reflect the new position self.vis.render(); // Re-render the visualization to reflect the new position
}) })
.anchor("center") .anchor("center")
.add(pv.Label) .add(pv.Label)
.left(d => d.x < w / 2 ? d.x + 30 : d.x - 35) // Shift label to right if on left half, otherwise shift to left .left(d => d.x < w / 2 ? d.x + 30 : d.x - 35) // Shift label to right if on left half, otherwise shift to left
@ -507,17 +563,20 @@ class PointsEditor {
console.log("no points") console.log("no points")
return return
} }
let bbox = calculateBBox(this.box_startX, this.box_startY, this.box_endX, this.box_endY);
let bboxString = JSON.stringify(bbox);
const combinedPoints = { const combinedPoints = {
positive: this.points, positive: this.points,
negative: this.neg_points, negative: this.neg_points,
}; };
this.pointsStoreWidget.value = JSON.stringify(combinedPoints); this.pointsStoreWidget.value = JSON.stringify(combinedPoints);
this.bboxStoreWidget.value = JSON.stringify(bboxString);
this.pos_coordWidget.value = JSON.stringify(this.points); this.pos_coordWidget.value = JSON.stringify(this.points);
this.neg_coordWidget.value = JSON.stringify(this.neg_points); this.neg_coordWidget.value = JSON.stringify(this.neg_points);
this.bboxWidget.value = bboxString;
if (this.bbox.length != 0) {
let bboxString = JSON.stringify(this.bbox);
this.bboxStoreWidget.value = bboxString;
this.bboxWidget.value = bboxString;
}
this.vis.render(); this.vis.render();
}; };
@ -673,12 +732,4 @@ export function hideWidgetForGood(node, widget, suffix = '') {
hideWidgetForGood(node, w, ':' + widget.name) hideWidgetForGood(node, w, ':' + widget.name)
} }
} }
}
function calculateBBox(x1, y1, x2, y2) {
var x = Math.min(x1, x2);
var y = Math.min(y1, y2);
var width = Math.abs(x2 - x1);
var height = Math.abs(y2 - y1);
return { x: x, y: y, width: width, height: height };
} }