From 58c4c755f7caafb6a335eb410f1e60345618a1aa Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 31 Jan 2024 22:32:58 +0200 Subject: [PATCH] Update nodes.py --- nodes.py | 122 +++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 106 insertions(+), 16 deletions(-) diff --git a/nodes.py b/nodes.py index 0141122..e3e5e35 100644 --- a/nodes.py +++ b/nodes.py @@ -842,6 +842,25 @@ class ConditioningMultiCombine: cond = cond_combine_node.combine(new_cond, cond)[0] return (cond, inputcount,) +class CondPassThrough: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + }, + + } + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING",) + RETURN_NAMES = ("positive", "negative") + FUNCTION = "passthrough" + CATEGORY = "KJNodes/misc" + + def passthrough(self, positive, negative): + return (positive, negative,) + def append_helper(t, mask, c, set_area_to_bounds, strength): n = [t[0], t[1].copy()] _, h, w = mask.shape @@ -3313,7 +3332,74 @@ class ImageTransformByNormalizedAmplitude: transformed_batch = torch.stack(transformed_images) return (transformed_batch,) - + +def parse_coordinates(coordinates_str): + coordinates = {} + pattern = r'(\d+):\((\d+),(\d+)\)' + matches = re.findall(pattern, coordinates_str) + for match in matches: + index, x, y = map(int, match) + coordinates[index] = (x, y) + return coordinates + +def interpolate_coordinates(coordinates_dict, batch_size): + sorted_coords = sorted(coordinates_dict.items()) + interpolated = {} + + for i, ((index1, (x1, y1)), (index2, (x2, y2))) in enumerate(zip(sorted_coords, sorted_coords[1:])): + distance = index2 - index1 + x_step = (x2 - x1) / distance + y_step = (y2 - y1) / distance + + for j in range(distance): + interpolated_x = round(x1 + j * x_step) + interpolated_y = round(y1 + j * y_step) + interpolated[index1 + j] = (interpolated_x, interpolated_y) + interpolated[sorted_coords[-1][0]] = sorted_coords[-1][1] + + # Ensure we have coordinates for all indices in the batch + last_index, last_coords = sorted_coords[-1] + for i in range(last_index + 1, batch_size): + interpolated[i] = last_coords + + return interpolated + +def plot_to_tensor(coordinates_dict, interpolated_dict, height, width, box_size): + from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + import matplotlib.patches as patches + + original_x, original_y = zip(*coordinates_dict.values()) + interpolated_x, interpolated_y = zip(*interpolated_dict.values()) + + fig, ax = plt.subplots(figsize=(width/100, height/100), dpi=100) + ax.scatter(original_x, original_y, color='blue', label='Original Points') + ax.scatter(interpolated_x, interpolated_y, color='red', alpha=0.5, label='Interpolated Points') + ax.plot(interpolated_x, interpolated_y, color='grey', linestyle='--', linewidth=0.5) + # Draw a box at each interpolated coordinate + for x, y in interpolated_dict.values(): + rect = patches.Rectangle((x - box_size/2, y - box_size/2), box_size, box_size, + linewidth=1, edgecolor='green', facecolor='none') + ax.add_patch(rect) + ax.set_title('Interpolated Coordinates') + ax.set_xlabel('X Coordinate') + ax.set_ylabel('Y Coordinate') + ax.legend() + ax.set_xlim(0, width) # Set the x-axis to match the input latent width + ax.set_ylim(height, 0) # Set the y-axis to match the input latent height, with (0,0) at top-left + + canvas = FigureCanvas(fig) + canvas.draw() + + width, height = fig.get_size_inches() * fig.get_dpi() + image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3) + + image_tensor = torch.from_numpy(image_np).float() / 255.0 + image_tensor = image_tensor.unsqueeze(0) + + plt.close(fig) + + return image_tensor + class GLIGENTextBoxApplyBatch: @classmethod def INPUT_TYPES(s): @@ -3324,33 +3410,34 @@ class GLIGENTextBoxApplyBatch: "text": ("STRING", {"multiline": True}), "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "x_increment": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}), - "y_increment": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}), + "coordinates": ("STRING", {"multiline": True}), }} - RETURN_TYPES = ("CONDITIONING",) + RETURN_TYPES = ("CONDITIONING", "IMAGE",) FUNCTION = "append" CATEGORY = "conditioning/gligen" - def append(self, latents, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y, x_increment, y_increment): + + + def append(self, latents, conditioning_to, clip, gligen_textbox_model, text, width, height, coordinates): + + coordinates_dict = parse_coordinates(coordinates) batch_size = sum(tensor.size(0) for tensor in latents.values()) - print(batch_size) c = [] cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) - + + # Interpolate coordinates for the entire batch + interpolated_coords = interpolate_coordinates(coordinates_dict, batch_size) + plot_image_tensor = plot_to_tensor(coordinates_dict, interpolated_coords, 512, 512, height) for t in conditioning_to: n = [t[0], t[1].copy()] position_params_batch = [[] for _ in range(batch_size)] # Initialize a list of empty lists for each batch item for i in range(batch_size): - x_position = x + i * x_increment - y_position = y + i * y_increment + x_position, y_position = interpolated_coords[i] position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8) position_params_batch[i].append(position_param) # Append position_param to the correct sublist - prev = [] if "gligen" in n[1]: prev = n[1]['gligen'][2] @@ -3362,8 +3449,9 @@ class GLIGENTextBoxApplyBatch: n[1]['gligen'] = ("position", gligen_textbox_model, combined_position_params) c.append(n) - return (c, ) - + return (c, plot_image_tensor,) + + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -3426,7 +3514,8 @@ NODE_CLASS_MAPPINGS = { "ImageTransformByNormalizedAmplitude": ImageTransformByNormalizedAmplitude, "GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed, "StringConstant": StringConstant, - "GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch + "GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch, + "CondPassThrough": CondPassThrough } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -3489,5 +3578,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageTransformByNormalizedAmplitude": "ImageTransformByNormalizedAmplitude", "GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed", "StringConstant": "StringConstant", - "GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch" + "GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch", + "CondPassThrough": "CondPassThrough" } \ No newline at end of file