Update curve_nodes.py

This commit is contained in:
kijai 2024-05-02 09:37:01 +03:00
parent d61894be21
commit e914839605

View File

@ -154,15 +154,17 @@ Grow value is the amount to grow the shape on each frame, creating animated mask
"default": 'circle' "default": 'circle'
}), }),
"coordinates": ("STRING", {"forceInput": True}), "coordinates": ("STRING", {"forceInput": True}),
"grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
"frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}), "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
"shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}), "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
}, },
"optional": {
"size_multiplier": ("FLOAT", {"default": [1.0], "forceInput": True}),
}
} }
def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, grow, shape): def createshapemask(self, coordinates, frame_width, frame_height, shape_width, shape_height, shape, size_multiplier=[1.0]):
# Define the number of images in the batch # Define the number of images in the batch
coordinates = coordinates.replace("'", '"') coordinates = coordinates.replace("'", '"')
coordinates = json.loads(coordinates) coordinates = json.loads(coordinates)
@ -170,14 +172,15 @@ Grow value is the amount to grow the shape on each frame, creating animated mask
batch_size = len(coordinates) batch_size = len(coordinates)
out = [] out = []
color = "white" color = "white"
if len(size_multiplier) != batch_size:
size_multiplier = size_multiplier * (batch_size // len(size_multiplier)) + size_multiplier[:batch_size % len(size_multiplier)]
for i, coord in enumerate(coordinates): for i, coord in enumerate(coordinates):
image = Image.new("RGB", (frame_width, frame_height), "black") image = Image.new("RGB", (frame_width, frame_height), "black")
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
# Calculate the size for this frame and ensure it's not less than 0 # Calculate the size for this frame and ensure it's not less than 0
current_width = max(0, shape_width + i*grow) current_width = max(0, shape_width + i * size_multiplier[i])
current_height = max(0, shape_height + i*grow) current_height = max(0, shape_height + i * size_multiplier[i])
location_x = coord['x'] location_x = coord['x']
location_y = coord['y'] location_y = coord['y']
@ -575,7 +578,9 @@ bounding boxes.
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
text_color = '#999999'
bg_color = '#353535'
matplotlib.pyplot.rcParams['text.color'] = text_color
fig, ax = matplotlib.pyplot.subplots(figsize=(width/100, height/100), dpi=100) fig, ax = matplotlib.pyplot.subplots(figsize=(width/100, height/100), dpi=100)
cmap = matplotlib.pyplot.get_cmap('rainbow') cmap = matplotlib.pyplot.get_cmap('rainbow')
@ -588,10 +593,8 @@ bounding boxes.
linewidth=1, edgecolor=color, facecolor='none', alpha=0.5) linewidth=1, edgecolor=color, facecolor='none', alpha=0.5)
ax.add_patch(rect) ax.add_patch(rect)
# Draw arrows from one point to another to indicate direction # Check if there is a next coordinate to draw an arrow to
for i in range(len(coordinates) - 1): if i < len(coordinates) - 1:
color_index = i / (len(coordinates) - 1)
color = cmap(color_index)
x1, y1 = coordinates[i] x1, y1 = coordinates[i]
x2, y2 = coordinates[i + 1] x2, y2 = coordinates[i + 1]
ax.annotate("", xy=(x2, y2), xytext=(x1, y1), ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
@ -600,14 +603,14 @@ bounding boxes.
lw=1, lw=1,
color=color, color=color,
mutation_scale=10)) mutation_scale=10))
matplotlib.pyplot.rcParams['text.color'] = '#999999'
fig.patch.set_facecolor('#353535') fig.patch.set_facecolor(bg_color)
ax.set_facecolor('#353535') ax.set_facecolor(bg_color)
ax.grid(color='#999999', linestyle='-', linewidth=0.5) ax.grid(color=text_color, linestyle='-', linewidth=0.5)
ax.set_xlabel('x', color='#999999') ax.set_xlabel('x', color=text_color)
ax.set_ylabel('y', color='#999999') ax.set_ylabel('y', color=text_color)
for text in ax.get_xticklabels() + ax.get_yticklabels(): for text in ax.get_xticklabels() + ax.get_yticklabels():
text.set_color('#999999') text.set_color(text_color)
ax.set_title('Gligen pos for: ' + prompt) ax.set_title('Gligen pos for: ' + prompt)
ax.set_xlabel('X Coordinate') ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate') ax.set_ylabel('Y Coordinate')