From e914839605d3ecb8f8894b8c8594f1e8168fc2aa Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 2 May 2024 09:37:01 +0300 Subject: [PATCH] Update curve_nodes.py --- nodes/curve_nodes.py | 55 +++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/nodes/curve_nodes.py b/nodes/curve_nodes.py index d44ca32..60deb32 100644 --- a/nodes/curve_nodes.py +++ b/nodes/curve_nodes.py @@ -154,15 +154,17 @@ Grow value is the amount to grow the shape on each frame, creating animated mask "default": 'circle' }), "coordinates": ("STRING", {"forceInput": True}), - "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}), "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}), "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}), }, + "optional": { + "size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}), + } } - def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, grow, shape): + def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape, size_multiplier=[1.0]): # Define the number of images in the batch coordinates = coordinates.replace("'", '"') coordinates = json.loads(coordinates) @@ -170,14 +172,15 @@ Grow value is the amount to grow the shape on each frame, creating animated mask batch_size = len(coordinates) out = [] color = "white" - + if len(size_multiplier) != batch_size: + size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)] for i, coord in enumerate(coordinates): image = Image.new("RGB", (frame_width, frame_height), "black") draw = ImageDraw.Draw(image) # Calculate the size for this frame and ensure it's not less than 0 - current_width = max(0, shape_width + i*grow) - current_height = max(0, shape_height + i*grow) + current_width = max(0, shape_width + i * size_multiplier[i]) + current_height = max(0, shape_height + i * size_multiplier[i]) location_x = coord['x'] location_y = coord['y'] @@ -575,7 +578,9 @@ bounding boxes. import matplotlib matplotlib.use('Agg') from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas - + text_color = '#999999' + bg_color = '#353535' + matplotlib.pyplot.rcParams['text.color'] = text_color fig, ax = matplotlib.pyplot.subplots(figsize=(width/100, height/100), dpi=100) cmap = matplotlib.pyplot.get_cmap('rainbow') @@ -585,29 +590,27 @@ bounding boxes. color = cmap(color_index) box_size = bbox_height * size rect = matplotlib.patches.Rectangle((x - box_size/2, y - box_size/2), box_size, box_size, - linewidth=1, edgecolor=color, facecolor='none', alpha=0.5) + linewidth=1, edgecolor=color, facecolor='none', alpha=0.5) ax.add_patch(rect) - # Draw arrows from one point to another to indicate direction - for i in range(len(coordinates) - 1): - color_index = i / (len(coordinates) - 1) - color = cmap(color_index) - x1, y1 = coordinates[i] - x2, y2 = coordinates[i + 1] - ax.annotate("", xy=(x2, y2), xytext=(x1, y1), - arrowprops=dict(arrowstyle="->", - linestyle="-", - lw=1, - color=color, - mutation_scale=10)) - matplotlib.pyplot.rcParams['text.color'] = '#999999' - fig.patch.set_facecolor('#353535') - ax.set_facecolor('#353535') - ax.grid(color='#999999', linestyle='-', linewidth=0.5) - ax.set_xlabel('x', color='#999999') - ax.set_ylabel('y', color='#999999') + # Check if there is a next coordinate to draw an arrow to + if i < len(coordinates) - 1: + x1, y1 = coordinates[i] + x2, y2 = coordinates[i + 1] + ax.annotate("", xy=(x2, y2), xytext=(x1, y1), + arrowprops=dict(arrowstyle="->", + linestyle="-", + lw=1, + color=color, + mutation_scale=10)) + + fig.patch.set_facecolor(bg_color) + ax.set_facecolor(bg_color) + ax.grid(color=text_color, linestyle='-', linewidth=0.5) + ax.set_xlabel('x', color=text_color) + ax.set_ylabel('y', color=text_color) for text in ax.get_xticklabels() + ax.get_yticklabels(): - text.set_color('#999999') + text.set_color(text_color) ax.set_title('Gligen pos for: ' + prompt) ax.set_xlabel('X Coordinate') ax.set_ylabel('Y Coordinate')