From df5e0d49e4fdd9e77452ffd2f4023f7ee0d22bc4 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 4 Aug 2024 16:16:56 +0300 Subject: [PATCH] PointsEditor: add negative points --- nodes/curve_nodes.py | 39 ++++++++++----- web/js/point_editor.js | 109 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 124 insertions(+), 24 deletions(-) diff --git a/nodes/curve_nodes.py b/nodes/curve_nodes.py index 3ba8b7c..b72e559 100644 --- a/nodes/curve_nodes.py +++ b/nodes/curve_nodes.py @@ -1254,6 +1254,7 @@ class PointsEditor: "required": { "points_store": ("STRING", {"multiline": False}), "coordinates": ("STRING", {"multiline": False}), + "neg_coordinates": ("STRING", {"multiline": False}), "bbox_store": ("STRING", {"multiline": False}), "bboxes": ("STRING", {"multiline": False}), "bbox_format": ( @@ -1264,14 +1265,16 @@ class PointsEditor: ), "width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}), "height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}), + "normalize": ("BOOLEAN", {"default": True}), }, "optional": { "bg_image": ("IMAGE", ), + }, } RETURN_TYPES = ("STRING", "STRING", "BBOX", "MASK") - RETURN_NAMES = ("coord_str", "normalized_str", "bbox", "bbox_mask") + RETURN_NAMES = ("positive_coords", "negative_coords", "bbox", "bbox_mask") FUNCTION = "pointdata" CATEGORY = "KJNodes/weights" DESCRIPTION = """ @@ -1290,24 +1293,38 @@ Note that you can't delete from start/end of the points array. To add an image select the node and copy/paste or drag in the image. Or from the bg_image input on queue (first frame of the batch). -**THE IMAGE IS SAVED TO THE NODE AND WORKFLOW METADATA** +**THE IMAGE IS SAVED TO THE NODE AND WORKFLOW METADATA** """ - def pointdata(self, points_store, bbox_store, width, height, coordinates, bboxes, bbox_format="xyxy", bg_image=None): + def pointdata(self, points_store, bbox_store, width, height, coordinates, neg_coordinates, normalize, bboxes, bbox_format="xyxy", bg_image=None): import io import base64 coordinates = json.loads(coordinates) - normalized = [] - normalized_y_values = [] + pos_coordinates = [] for coord in coordinates: coord['x'] = int(round(coord['x'])) coord['y'] = int(round(coord['y'])) - norm_x = (1.0 - (coord['x'] / height) - 0.0) - norm_y = (1.0 - (coord['y'] / height) - 0.0) - normalized_y_values.append(norm_y) - normalized.append({'x':norm_x, 'y':norm_y}) + if normalize: + norm_x = coord['x'] / width + norm_y = coord['y'] / height + pos_coordinates.append({'x': norm_x, 'y': norm_y}) + else: + pos_coordinates.append({'x': coord['x'], 'y': coord['y']}) + + if neg_coordinates: + coordinates = json.loads(neg_coordinates) + neg_coordinates = [] + for coord in coordinates: + coord['x'] = int(round(coord['x'])) + coord['y'] = int(round(coord['y'])) + if normalize: + norm_x = coord['x'] / width + norm_y = coord['y'] / height + neg_coordinates.append({'x': norm_x, 'y': norm_y}) + else: + neg_coordinates.append({'x': coord['x'], 'y': coord['y']}) # Create a blank mask mask = np.zeros((height, width), dtype=np.uint8) @@ -1337,7 +1354,7 @@ Or from the bg_image input on queue (first frame of the batch). print(mask_tensor.shape) if bg_image is None: - return (json.dumps(coordinates), json.dumps(normalized), bboxes, mask_tensor) + return (json.dumps(pos_coordinates), json.dumps(neg_coordinates), bboxes, mask_tensor) else: transform = transforms.ToPILImage() image = transform(bg_image[0].permute(2, 0, 1)) @@ -1350,5 +1367,5 @@ Or from the bg_image input on queue (first frame of the batch). return { "ui": {"bg_image": [img_base64]}, - "result": (json.dumps(coordinates), json.dumps(normalized), bboxes, mask_tensor) + "result": (json.dumps(coordinates), json.dumps(neg_coordinates), json.dumps(bboxes), bboxes, mask_tensor) } \ No newline at end of file diff --git a/web/js/point_editor.js b/web/js/point_editor.js index 9fa2ff3..61583fc 100644 --- a/web/js/point_editor.js +++ b/web/js/point_editor.js @@ -105,6 +105,7 @@ app.registerExtension({ chainCallback(nodeType.prototype, "onNodeCreated", function () { hideWidgetForGood(this, this.widgets.find(w => w.name === "coordinates")) + hideWidgetForGood(this, this.widgets.find(w => w.name === "neg_coordinates")) hideWidgetForGood(this, this.widgets.find(w => w.name === "bboxes")) var element = document.createElement("div"); @@ -177,6 +178,7 @@ app.registerExtension({ if (!this.properties || !("points" in this.properties)) { this.editor = new PointsEditor(this); this.addProperty("points", this.constructor.type, "string"); + this.addProperty("neg_points", this.constructor.type, "string"); } else { @@ -250,7 +252,8 @@ class PointsEditor { if (reset && context.pointsEditor.element) { context.pointsEditor.element.innerHTML = ''; // Clear the container } - this.coordWidget = context.widgets.find(w => w.name === "coordinates"); + this.pos_coordWidget = context.widgets.find(w => w.name === "coordinates"); + this.neg_coordWidget = context.widgets.find(w => w.name === "neg_coordinates"); this.pointsStoreWidget = context.widgets.find(w => w.name === "points_store"); this.widthWidget = context.widgets.find(w => w.name === "width"); this.heightWidget = context.widgets.find(w => w.name === "height"); @@ -273,7 +276,8 @@ class PointsEditor { this.updateData(); } this.pointsStoreWidget.callback = () => { - this.points = JSON.parse(pointsStoreWidget.value); + this.points = JSON.parse(pointsStoreWidget.value).positive; + this.neg_points = JSON.parse(pointsStoreWidget.value).negative; this.updateData(); } @@ -281,12 +285,14 @@ class PointsEditor { var h = this.heightWidget.value; var i = 3; this.points = []; + this.neg_points = []; this.bbox = []; var drawing = false; // Initialize or reset points array if (!reset && this.pointsStoreWidget.value != "") { - this.points = JSON.parse(this.pointsStoreWidget.value); + this.points = JSON.parse(this.pointsStoreWidget.value).positive; + this.neg_points = JSON.parse(this.pointsStoreWidget.value).negative; this.bbox = JSON.parse(this.bboxStoreWidget.value); } else { this.points = [ @@ -295,7 +301,17 @@ class PointsEditor { y: h / 2 // Middle point vertically centered } ]; - this.pointsStoreWidget.value = JSON.stringify(this.points); + this.neg_points = [ + { + x: 0, // Middle point horizontally centered + y: 0 // Middle point vertically centered + } + ]; + const combinedPoints = { + positive: this.points, + negative: this.neg_points, + }; + this.pointsStoreWidget.value = JSON.stringify(combinedPoints); this.bboxStoreWidget.value = JSON.stringify(this.bbox); } const self = this; // Keep a reference to the main class context @@ -308,7 +324,17 @@ class PointsEditor { .antialias(false) .margin(10) .event("mousedown", function () { - if (pv.event.shiftKey) { // Use pv.event to access the event object + + if (pv.event.shiftKey && pv.event.button === 2) { // Use pv.event to access the event object + let scaledMouse = { + x: this.mouse().x / app.canvas.ds.scale, + y: this.mouse().y / app.canvas.ds.scale + }; + i = self.neg_points.push(scaledMouse) - 1; + self.updateData(); + return this; + } + else if (pv.event.shiftKey) { // Use pv.event to access the event object let scaledMouse = { x: this.mouse().x / app.canvas.ds.scale, y: this.mouse().y / app.canvas.ds.scale @@ -359,7 +385,7 @@ class PointsEditor { .radius(Math.log(Math.min(w, h)) * 4) .shape("circle") .cursor("move") - .strokeStyle(function () { return i == this.index ? "#ff7f0e" : "#00FFFF"; }) + .strokeStyle(function () { return i == this.index ? "#07f907" : "#139613"; }) .lineWidth(4) .fillStyle(function () { return "rgba(100, 100, 100, 0.6)"; }) .event("mousedown", pv.Behavior.drag()) @@ -394,16 +420,69 @@ class PointsEditor { .top(d => d.y < h / 2 ? d.y + 25 : d.y - 25) // Shift label down if on top half, otherwise shift up .font(25 + "px sans-serif") .text(d => {return this.points.indexOf(d); }) - .textStyle("cyan") + .textStyle("#139613") .textShadow("2px 2px 2px black") .add(pv.Dot) // Add smaller point in the center .data(() => this.points) + .left(d => d.x) + .top(d => d.y) + .radius(2) // Smaller radius for the center point + .shape("circle") + .fillStyle("red") // Color for the center point + .lineWidth(1); // Stroke thickness for the center point + + this.vis.add(pv.Dot) + .data(() => this.neg_points) .left(d => d.x) .top(d => d.y) - .radius(2) // Smaller radius for the center point + .radius(Math.log(Math.min(w, h)) * 4) .shape("circle") - .fillStyle("red") // Color for the center point - .lineWidth(1); // Stroke thickness for the center point + .cursor("move") + .strokeStyle(function () { return i == this.index ? "#f91111" : "#891616"; }) + .lineWidth(4) + .fillStyle(function () { return "rgba(100, 100, 100, 0.6)"; }) + .event("mousedown", pv.Behavior.drag()) + .event("dragstart", function () { + i = this.index; + }) + .event("dragend", function () { + if (pv.event.button === 2 && i !== 0 && i !== self.neg_points.length - 1) { + this.index = i; + self.neg_points.splice(i--, 1); + } + self.updateData(); + + }) + .event("drag", function () { + let adjustedX = this.mouse().x / app.canvas.ds.scale; // Adjust the new X position by the inverse of the scale factor + let adjustedY = this.mouse().y / app.canvas.ds.scale; // Adjust the new Y position by the inverse of the scale factor + // Determine the bounds of the vis.Panel + const panelWidth = self.vis.width(); + const panelHeight = self.vis.height(); + + // Adjust the new position if it would place the dot outside the bounds of the vis.Panel + adjustedX = Math.max(0, Math.min(panelWidth, adjustedX)); + adjustedY = Math.max(0, Math.min(panelHeight, adjustedY)); + self.neg_points[this.index] = { x: adjustedX, y: adjustedY }; // Update the point's position + self.vis.render(); // Re-render the visualization to reflect the new position + }) + + .anchor("center") + .add(pv.Label) + .left(d => d.x < w / 2 ? d.x + 30 : d.x - 35) // Shift label to right if on left half, otherwise shift to left + .top(d => d.y < h / 2 ? d.y + 25 : d.y - 25) // Shift label down if on top half, otherwise shift up + .font(25 + "px sans-serif") + .text(d => {return this.neg_points.indexOf(d); }) + .textStyle("red") + .textShadow("2px 2px 2px black") + .add(pv.Dot) // Add smaller point in the center + .data(() => this.neg_points) + .left(d => d.x) + .top(d => d.y) + .radius(2) // Smaller radius for the center point + .shape("circle") + .fillStyle("red") // Color for the center point + .lineWidth(1); // Stroke thickness for the center point if (this.points.length != 0) { this.vis.render(); @@ -428,12 +507,16 @@ class PointsEditor { console.log("no points") return } - let coordsString = JSON.stringify(this.points); let bbox = calculateBBox(this.box_startX, this.box_startY, this.box_endX, this.box_endY); let bboxString = JSON.stringify(bbox); - this.pointsStoreWidget.value = JSON.stringify(this.points); + const combinedPoints = { + positive: this.points, + negative: this.neg_points, + }; + this.pointsStoreWidget.value = JSON.stringify(combinedPoints); this.bboxStoreWidget.value = JSON.stringify(bboxString); - this.coordWidget.value = coordsString; + this.pos_coordWidget.value = JSON.stringify(this.points); + this.neg_coordWidget.value = JSON.stringify(this.neg_points); this.bboxWidget.value = bboxString; this.vis.render(); };