Fix GLIGEN bbox origin

This commit is contained in:
kijai 2024-05-02 23:54:12 +03:00
parent 90f639c6f5
commit 9fa6a26689

View File

@@ -554,7 +554,7 @@ bounding boxes.
for i in range(batch_size):
x_position, y_position = coordinates[i]
position_param = (cond_pooled, int((height // 8) * size_multiplier[i]), int((width // 8) * size_multiplier[i]), y_position // 8, x_position // 8)
position_param = (cond_pooled, int((height // 8) * size_multiplier[i]), int((width // 8) * size_multiplier[i]), (y_position - height / 2) // 8, (x_position - width / 2) // 8)
position_params_batch[i].append(position_param) # Append position_param to the correct sublist
prev = []
@@ -570,11 +570,11 @@ bounding boxes.
image_height = latents['samples'].shape[-2] * 8
image_width = latents['samples'].shape[-1] * 8
plot_image_tensor = self.plot_coordinates_to_tensor(coordinates, image_height, image_width, height, size_multiplier, text)
plot_image_tensor = self.plot_coordinates_to_tensor(coordinates, image_height, image_width, height, width, size_multiplier, text)
return (c, plot_image_tensor,)
def plot_coordinates_to_tensor(self, coordinates, height, width, bbox_height, size_multiplier, prompt):
def plot_coordinates_to_tensor(self, coordinates, height, width, bbox_height, bbox_width, size_multiplier, prompt):
import matplotlib
matplotlib.use('Agg')
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
@@ -606,8 +606,9 @@ bounding boxes.
for i, ((x, y), size) in enumerate(zip(coordinates, size_multiplier)):
color_index = i / (len(coordinates) - 1)
color = cmap(color_index)
box_size = bbox_height * size
rect = matplotlib.patches.Rectangle((x - box_size/2, y - box_size/2), box_size, box_size,
draw_height = bbox_height * size
draw_width = bbox_width * size
rect = matplotlib.patches.Rectangle((x - draw_width/2, y - draw_height/2), draw_width, draw_height,
linewidth=1, edgecolor=color, facecolor='none', alpha=0.5)
ax.add_patch(rect)
@@ -620,7 +621,7 @@ bounding boxes.
linestyle="-",
lw=1,
color=color,
mutation_scale=10))
mutation_scale=20))
canvas.draw()
image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3).copy()
image_tensor = torch.from_numpy(image_np).float() / 255.0