diff --git a/nodes/curve_nodes.py b/nodes/curve_nodes.py index b1568a8..c48e672 100644 --- a/nodes/curve_nodes.py +++ b/nodes/curve_nodes.py @@ -554,7 +554,7 @@ bounding boxes. for i in range(batch_size): x_position, y_position = coordinates[i] - position_param = (cond_pooled, int((height // 8) * size_multiplier[i]), int((width // 8) * size_multiplier[i]), y_position // 8, x_position // 8) + position_param = (cond_pooled, int((height // 8) * size_multiplier[i]), int((width // 8) * size_multiplier[i]), (y_position - height / 2) // 8, (x_position - width / 2) // 8) position_params_batch[i].append(position_param) # Append position_param to the correct sublist prev = [] @@ -570,11 +570,11 @@ bounding boxes. image_height = latents['samples'].shape[-2] * 8 image_width = latents['samples'].shape[-1] * 8 - plot_image_tensor = self.plot_coordinates_to_tensor(coordinates, image_height, image_width, height, size_multiplier, text) + plot_image_tensor = self.plot_coordinates_to_tensor(coordinates, image_height, image_width, height, width, size_multiplier, text) return (c, plot_image_tensor,) - def plot_coordinates_to_tensor(self, coordinates, height, width, bbox_height, size_multiplier, prompt): + def plot_coordinates_to_tensor(self, coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt): import matplotlib matplotlib.use('Agg') from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas @@ -606,8 +606,9 @@ bounding boxes. for i, ((x, y), size) in enumerate(zip(coordinates, size_multiplier)): color_index = i / (len(coordinates) - 1) color = cmap(color_index) - box_size = bbox_height * size - rect = matplotlib.patches.Rectangle((x - box_size/2, y - box_size/2), box_size, box_size, + draw_height = bbox_height * size + draw_width = bbox_width * size + rect = matplotlib.patches.Rectangle((x - draw_width/2, y - draw_height/2), draw_width, draw_height, linewidth=1, edgecolor=color, facecolor='none', alpha=0.5) ax.add_patch(rect) @@ -620,7 +621,7 @@ bounding boxes. linestyle="-", lw=1, color=color, - mutation_scale=10)) + mutation_scale=20)) canvas.draw() image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3).copy() image_tensor = torch.from_numpy(image_np).float() / 255.0