bbox to point editor

This commit is contained in:
kijai 2024-08-02 18:52:05 +03:00
parent 5ccd427bd8
commit dee4e8f1eb
3 changed files with 114 additions and 41 deletions

View File

@ -710,28 +710,48 @@ Visualizes the specified bbox on the image.
image_list = []
for image, bbox in zip(images, bboxes):
x_min, y_min, width, height = bbox
# Ensure bbox coordinates are integers
x_min = int(x_min)
y_min = int(y_min)
width = int(width)
height = int(height)
# Permute the image dimensions
image = image.permute(2, 0, 1)
# Clone the image to draw bounding boxes
img_with_bbox = image.clone()
# Define the color for the bbox, e.g., red
color = torch.tensor([1, 0, 0], dtype=torch.float32)
# Ensure color tensor matches the image channels
if color.shape[0] != img_with_bbox.shape[0]:
color = color.unsqueeze(1).expand(-1, line_width)
# Draw lines for each side of the bbox with the specified line width
for lw in range(line_width):
# Top horizontal line
if y_min + lw < img_with_bbox.shape[1]:
img_with_bbox[:, y_min + lw, x_min:x_min + width] = color[:, None]
# Bottom horizontal line
if y_min + height - lw < img_with_bbox.shape[1]:
img_with_bbox[:, y_min + height - lw, x_min:x_min + width] = color[:, None]
# Left vertical line
if x_min + lw < img_with_bbox.shape[2]:
img_with_bbox[:, y_min:y_min + height, x_min + lw] = color[:, None]
# Right vertical line
if x_min + width - lw < img_with_bbox.shape[2]:
img_with_bbox[:, y_min:y_min + height, x_min + width - lw] = color[:, None]
# Permute the image dimensions back
img_with_bbox = img_with_bbox.permute(1, 2, 0).unsqueeze(0)
image_list.append(img_with_bbox)
return (torch.cat(image_list, dim=0),)
return (torch.cat(image_list, dim=0),)

View File

@ -1254,13 +1254,15 @@ class PointsEditor:
"required": {
"points_store": ("STRING", {"multiline": False}),
"coordinates": ("STRING", {"multiline": False}),
"bbox_store": ("STRING", {"multiline": False}),
"bboxes": ("STRING", {"multiline": False}),
"width": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
"height": ("INT", {"default": 512, "min": 8, "max": 4096, "step": 8}),
},
}
RETURN_TYPES = ("STRING", "STRING",)
RETURN_NAMES = ("coord_str", "normalized_str",)
RETURN_TYPES = ("STRING", "STRING", "BBOX", "MASK")
RETURN_NAMES = ("coord_str", "normalized_str", "bbox", "bbox_mask")
FUNCTION = "splinedata"
CATEGORY = "KJNodes/weights"
DESCRIPTION = """
@ -1293,7 +1295,7 @@ output types:
example compatible nodes: unknown
"""
def splinedata(self, points_store, width, height, coordinates):
def splinedata(self, points_store, bbox_store, width, height, coordinates, bboxes):
coordinates = json.loads(coordinates)
normalized = []
@ -1306,5 +1308,22 @@ output types:
normalized_y_values.append(norm_y)
normalized.append({'x':norm_x, 'y':norm_y})
bboxes = json.loads(bboxes)
bboxes = [(int(bboxes["x"]), int(bboxes["y"]), int(bboxes["width"]), int(bboxes["height"]))]
return (json.dumps(coordinates), json.dumps(normalized))
# Create a blank mask
mask = np.zeros((height, width), dtype=np.uint8)
# Draw the bounding box on the mask
for bbox in bboxes:
x_min, y_min, w, h = bbox
x_max = x_min + w
y_max = y_min + h
mask[y_min:y_max, x_min:x_max] = 1 # Fill the bounding box area with 1s
mask_tensor = torch.from_numpy(mask)
mask_tensor = mask_tensor.unsqueeze(0).float().cpu()
#mask_tensor = mask_tensor[:,:,0]
print(mask_tensor.shape)
return (json.dumps(coordinates), json.dumps(normalized), bboxes, mask_tensor)

View File

@ -106,6 +106,7 @@ app.registerExtension({
chainCallback(nodeType.prototype, "onNodeCreated", function () {
hideWidgetForGood(this, this.widgets.find(w => w.name === "coordinates"))
hideWidgetForGood(this, this.widgets.find(w => w.name === "bboxes"))
var element = document.createElement("div");
this.uuid = makeUUID()
@ -277,7 +278,7 @@ function createPointsEditor(context, reset=false) {
context.setSize([context.size[0], h + 230]);
vis.width(w);
vis.height(h);
updatePath();
updateData();
backgroundImage.url(imageUrl).visible(true).root.render();
};
}
@ -319,7 +320,7 @@ function createPointsEditor(context, reset=false) {
context.setSize([context.size[0], this.height + 230]);
vis.width(w);
vis.height(h);
updatePath();
updateData();
backgroundImage.url(imageUrl).visible(true).root.render();
};
}
@ -341,33 +342,23 @@ function createPointsEditor(context, reset=false) {
}
createContextMenu();
function updatePath() {
function updateData() {
if (points.length == 0) {
console.log("no points")
return
}
let coords = points
// if (pointsLayer) {
// // Update the data of the existing points layer
// pointsLayer.data(coords);
// } else {
// // Create the points layer if it doesn't exist
// pointsLayer = vis.add(pv.Dot)
// .data(coords)
// .left(function(d) { return d.x; })
// .top(function(d) { return d.y; })
// .radius(3) // Adjust the radius as needed
// .fillStyle("red") // Change the color as needed
// .strokeStyle("black") // Change the stroke color as needed
// .lineWidth(1); // Adjust the line width as needed
// }
let coordsString = JSON.stringify(coords);
let bbox = calculateBBox(box_startX, box_startY, box_endX, box_endY)
let bboxString = JSON.stringify(bbox);
pointsStoreWidget.value = JSON.stringify(points);
bboxStoreWidget.value = JSON.stringify(bboxString);
if (coordWidget) {
coordWidget.value = coordsString;
}
if (bboxWidget) {
bboxWidget.value = bboxString;
}
vis.render();
}
@ -378,8 +369,8 @@ function createPointsEditor(context, reset=false) {
const pointsStoreWidget = context.widgets.find(w => w.name === "points_store");
const widthWidget = context.widgets.find(w => w.name === "width");
const heightWidget = context.widgets.find(w => w.name === "height");
var pointsLayer = null;
const bboxStoreWidget = context.widgets.find(w => w.name === "bbox_store");
const bboxWidget = context.widgets.find(w => w.name === "bboxes");
widthWidget.callback = () => {
w = widthWidget.value;
@ -387,17 +378,17 @@ function createPointsEditor(context, reset=false) {
context.setSize([w + 45, context.size[1]]);
}
vis.width(w);
updatePath();
updateData();
}
heightWidget.callback = () => {
h = heightWidget.value
vis.height(h)
context.setSize([context.size[0], h + 430]);
updatePath();
updateData();
}
pointsStoreWidget.callback = () => {
points = JSON.parse(pointsStoreWidget.value);
updatePath();
updateData();
}
// Initialize or reset points array
@ -407,6 +398,9 @@ function createPointsEditor(context, reset=false) {
var h = heightWidget.value;
var i = 3;
let points = [];
let bbox = [];
var box_startX, box_startY, box_endX, box_endY;
var drawing = false;
if (!reset && pointsStoreWidget.value != "") {
points = JSON.parse(pointsStoreWidget.value);
@ -436,11 +430,14 @@ function createPointsEditor(context, reset=false) {
y: this.mouse().y / app.canvas.ds.scale
};
i = points.push(scaledMouse) - 1;
updatePath();
updateData();
return this;
}
else if (pv.event.ctrlKey) {
console.log("start drawing at " + this.mouse().x / app.canvas.ds.scale + ", " + this.mouse().y / app.canvas.ds.scale);
drawing = true;
box_startX = this.mouse().x / app.canvas.ds.scale;
box_startY = this.mouse().y / app.canvas.ds.scale;
}
else if (pv.event.button === 2) {
context.contextMenu.style.display = 'block';
@ -448,9 +445,38 @@ function createPointsEditor(context, reset=false) {
context.contextMenu.style.top = `${pv.event.clientY}px`;
}
})
.event("mousemove", function() {
if (drawing) {
box_endX = this.mouse().x / app.canvas.ds.scale;
box_endY = this.mouse().y / app.canvas.ds.scale;
vis.render();
}
})
.event("mouseup", function() {
console.log("end drawing at " + this.mouse().x / app.canvas.ds.scale + ", " + this.mouse().y / app.canvas.ds.scale);
drawing = false;
updateData();
});
var backgroundImage = vis.add(pv.Image)
.visible(false)
vis.add(pv.Area)
.data(function() {
return drawing || bbox ? [box_startX, box_endX] : [];
})
.bottom(function() {
return h - Math.max(box_startY, box_endY);
})
.left(function(d) {
return d;
})
.height(function() {
return Math.abs(box_startY - box_endY);
})
.fillStyle("rgba(70, 130, 180, 0.5)")
.strokeStyle("steelblue");
vis.add(pv.Dot)
.data(() => points)
.left(d => d.x)
@ -475,7 +501,7 @@ function createPointsEditor(context, reset=false) {
points.splice(i--, 1);
}
updatePath();
updateData();
isDragging = false;
})
@ -539,7 +565,7 @@ function createPointsEditor(context, reset=false) {
context.setSize([w + 45, context.size[1]]);
}
context.setSize([context.size[0], h + 430]);
updatePath();
updateData();
if (context.properties.imgData && context.properties.imgData.base64) {
const base64String = context.properties.imgData.base64;
@ -573,3 +599,11 @@ export function hideWidgetForGood(node, widget, suffix = '') {
}
}
}
/**
 * Build an axis-aligned bounding box from two opposite corner points.
 *
 * The corners may be given in any order; the result is normalised so that
 * (x, y) is the top-left corner and width/height are non-negative.
 *
 * If any coordinate is missing or not a finite number (for example when no
 * box has been drawn yet and the corner variables are still undefined), an
 * empty box at the origin is returned. Without this guard every field would
 * be NaN, which JSON.stringify serialises as null.
 *
 * @param {number} x1 - X coordinate of the first corner.
 * @param {number} y1 - Y coordinate of the first corner.
 * @param {number} x2 - X coordinate of the opposite corner.
 * @param {number} y2 - Y coordinate of the opposite corner.
 * @returns {{x: number, y: number, width: number, height: number}}
 */
function calculateBBox(x1, y1, x2, y2) {
  const corners = [x1, y1, x2, y2];
  // Number.isFinite rejects undefined, null, NaN and +/-Infinity without coercion.
  if (!corners.every(Number.isFinite)) {
    return { x: 0, y: 0, width: 0, height: 0 };
  }
  return {
    x: Math.min(x1, x2),
    y: Math.min(y1, y2),
    width: Math.abs(x2 - x1),
    height: Math.abs(y2 - y1),
  };
}