Instancediffusion updates

This commit is contained in:
kijai 2024-05-04 23:37:19 +03:00
parent 8ccc080bf4
commit 5208980fc9
2 changed files with 50 additions and 28 deletions

View File

@ -687,8 +687,8 @@ bounding boxes.
class CreateInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING", "INT", "INT", "INT", "INT",)
RETURN_NAMES = ("tracking", "width", "height", "bbox_width", "bbox_height",)
RETURN_TYPES = ("TRACKING", "STRING", "INT", "INT", "INT", "INT",)
RETURN_NAMES = ("tracking", "prompt", "width", "height", "bbox_width", "bbox_height",)
FUNCTION = "tracking"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
@ -712,50 +712,61 @@ for example:
"bbox_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"class_name": ("STRING", {"default": "class_name"}),
"class_id": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
"prompt": ("STRING", {"default": "prompt", "multiline": True}),
},
"optional": {
"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
}
}
def tracking(self, coordinates, class_name, class_id, width, height, bbox_width, bbox_height, size_multiplier=[1.0]):
def tracking(self, coordinates, class_name, class_id, width, height, bbox_width, bbox_height, prompt, size_multiplier=[1.0]):
# Define the number of images in the batch
coordinates = coordinates.replace("'", '"')
coordinates = json.loads(coordinates)
tracked = {}
tracked[class_name] = {}
batch_size = len(coordinates)
# Initialize a list to hold the coordinates for the current ID
id_coordinates = []
for coord in coordinates:
if len(size_multiplier) != batch_size:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i, coord in enumerate(coordinates):
x = coord['x']
y = coord['y']
adjusted_bbox_width = bbox_width * size_multiplier[i]
adjusted_bbox_height = bbox_height * size_multiplier[i]
# Calculate the top left and bottom right coordinates
top_left_x = x - bbox_width // 2
top_left_y = y - bbox_height // 2
bottom_right_x = x + bbox_width // 2
bottom_right_y = y + bbox_height // 2
top_left_x = x - adjusted_bbox_width // 2
top_left_y = y - adjusted_bbox_height // 2
bottom_right_x = x + adjusted_bbox_width // 2
bottom_right_y = y + adjusted_bbox_height // 2
# Append the top left and bottom right coordinates to the list for the current ID
id_coordinates.append([top_left_x, top_left_y, bottom_right_x, bottom_right_y, width, height])
# Append the 'x' and 'y' coordinates along with bbox width/height and frame width/height to the list for the current ID
#id_coordinates.append([coord['x'] - bbox_width // 2 , coord['y'] - bbox_height // 2, bbox_width, bbox_height, width, height])
class_id = int(class_id)
print(class_id)
# Assign the list of coordinates to the specified ID within the class_id dictionary
tracked[class_name][class_id] = id_coordinates
print(tracked)
return (tracked, width, height, bbox_width, bbox_height,)
prompt_string = ""
for class_name, class_data in tracked.items():
for class_id in class_data.keys():
class_id_str = str(class_id)
# Use the incoming prompt for each class name and ID
prompt_string += f'"{class_id_str}.{class_name}": "({prompt})",\n'
# Remove the last comma and newline
prompt_string = prompt_string.rstrip(",\n")
print(prompt_string)
return (tracked, prompt_string, width, height, bbox_width, bbox_height)
class AppendInstanceDiffusionTracking:
RETURN_TYPES = ("TRACKING",)
RETURN_NAMES = ("tracking",)
RETURN_TYPES = ("TRACKING", "STRING",)
RETURN_NAMES = ("tracking", "prompt",)
FUNCTION = "append"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
@ -771,20 +782,24 @@ https://github.com/logtd/ComfyUI-InstanceDiffusion
"tracking_1": ("TRACKING", {"forceInput": True}),
"tracking_2": ("TRACKING", {"forceInput": True}),
},
"optional": {
"prompt_1": ("STRING", {"default": "", "forceInput": True}),
"prompt_2": ("STRING", {"default": "", "forceInput": True}),
}
}
def append(self, tracking_1, tracking_2):
def append(self, tracking_1, tracking_2, prompt_1="", prompt_2=""):
tracking_copy = tracking_1.copy()
# Check for existing class names and class IDs, and raise an error if they exist
for class_name, class_data in tracking_2.items():
if class_name in tracking_copy:
for class_id in class_data.keys():
if class_id in tracking_copy[class_name]:
raise ValueError(f"Class ID {class_id} already exists for class name {class_name}. Cannot append tracking data.")
# If class name does not exist, add it
tracking_copy[class_name] = class_data
return (tracking_copy, )
if class_name not in tracking_copy:
tracking_copy[class_name] = class_data
else:
# If the class name exists, merge the class data from tracking_2 into tracking_copy
# This will add new class IDs under the same class name without raising an error
tracking_copy[class_name].update(class_data)
prompt_string = prompt_1 + "," + prompt_2
return (tracking_copy, prompt_string)
class InterpolateCoords:

View File

@ -314,12 +314,19 @@ function createSplineEditor(context, reset=false) {
var pointsLayer = null;
var samplingMethod = samplingMethodWidget.value
if (samplingMethod == "path") {
dotShape = "triangle"
}
interpolationWidget.callback = () => {
interpolation = interpolationWidget.value
updatePath();
}
samplingMethodWidget.callback = () => {
samplingMethod = samplingMethodWidget.value
if (samplingMethod == "path") {
dotShape = "triangle"
}
updatePath();
}
tensionWidget.callback = () => {
@ -359,7 +366,7 @@ function createSplineEditor(context, reset=false) {
var h = heightWidget.value;
var i = 3;
let points = [];
if (!reset && pointsStoreWidget.value != "") {
points = JSON.parse(pointsStoreWidget.value);
} else {