diff --git a/__init__.py b/__init__.py index 09f0957..dfbf668 100644 --- a/__init__.py +++ b/__init__.py @@ -102,6 +102,7 @@ NODE_CLASS_MAPPINGS = { "LoadResAdapterNormalization": LoadResAdapterNormalization, "Superprompt": Superprompt, "GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch, + "GLIGENTextBoxApplyBatchCoords": GLIGENTextBoxApplyBatchCoords, "Intrinsic_lora_sampling": Intrinsic_lora_sampling, } diff --git a/nodes/curve_nodes.py b/nodes/curve_nodes.py index f673081..8cac723 100644 --- a/nodes/curve_nodes.py +++ b/nodes/curve_nodes.py @@ -55,7 +55,8 @@ class SplineEditor: } } - RETURN_TYPES = ("MASK", "STRING", "FLOAT") + RETURN_TYPES = ("MASK", "STRING", "FLOAT", "INT") + RETURN_NAMES = ("mask", "string", "float", "count") FUNCTION = "splinedata" CATEGORY = "KJNodes/experimental" DESCRIPTION = """ @@ -126,7 +127,7 @@ output types: masks_out = torch.stack(mask_tensors) masks_out = masks_out.repeat(repeat_output, 1, 1, 1) masks_out = masks_out.mean(dim=-1) - return (masks_out, str(coordinates), out_floats,) + return (masks_out, str(coordinates), out_floats, len(out_floats)) class CreateShapeMaskOnPath: @@ -485,4 +486,115 @@ Creates a sigmas tensor from list of float values. """ def customsigmas(self, float_list): - return torch.tensor(float_list, dtype=torch.float32), \ No newline at end of file + return torch.tensor(float_list, dtype=torch.float32), + +class GLIGENTextBoxApplyBatchCoords: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning_to": ("CONDITIONING", ), + "latents": ("LATENT", ), + "clip": ("CLIP", ), + "gligen_textbox_model": ("GLIGEN", ), + "coordinates": ("STRING", {"forceInput": True}), + "text": ("STRING", {"multiline": True}), + "width": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 8}), + "height": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 8}), + }, + } + RETURN_TYPES = ("CONDITIONING", "IMAGE", ) + FUNCTION = "append" + CATEGORY = "KJNodes/experimental" + DESCRIPTION = """ +Experimental, does not function yet as ComfyUI base changes are needed +""" + + def append(self, latents, coordinates, conditioning_to, clip, gligen_textbox_model, text, width, height): + coordinates = json.loads(coordinates.replace("'", '"')) + coordinates = [(coord['x'], coord['y']) for coord in coordinates] + + batch_size = sum(tensor.size(0) for tensor in latents.values()) + assert len(coordinates) == batch_size, "The number of coordinates does not match the number of latents" + c = [] + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + + image_height = latents['samples'].shape[-1] * 8 + image_width = latents['samples'].shape[-2] * 8 + plot_image_tensor = self.plot_coordinates_to_tensor(coordinates, image_height, image_width, height) + + for t in conditioning_to: + n = [t[0], t[1].copy()] + + position_params_batch = [[] for _ in range(batch_size)] # Initialize a list of empty lists for each batch item + + for i in range(batch_size): + x_position, y_position = coordinates[i] + position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8) + position_params_batch[i].append(position_param) # Append position_param to the correct sublist + + prev = [] + if "gligen" in n[1]: + prev = n[1]['gligen'][2] + else: + prev = [[] for _ in range(batch_size)] + # Concatenate prev and position_params_batch, ensuring both are lists of lists + # and each sublist corresponds to a batch item + combined_position_params = [prev_item + batch_item for prev_item, batch_item in zip(prev, position_params_batch)] + n[1]['gligen'] = ("position", gligen_textbox_model, combined_position_params) + c.append(n) + + return (c, plot_image_tensor,) + + def plot_coordinates_to_tensor(self, coordinates, height, width, box_size): + import matplotlib + matplotlib.use('Agg') + from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + + # Convert coordinates to separate x and y lists + #x_coords, y_coords = zip(*coordinates) + + fig, ax = matplotlib.pyplot.subplots(figsize=(width/100, height/100), dpi=100) + #ax.scatter(x_coords, y_coords, color='yellow', label='_nolegend_') + + # Draw a box at each coordinate + for x, y in coordinates: + rect = matplotlib.patches.Rectangle((x - box_size/2, y - box_size/2), box_size, box_size, + linewidth=1, edgecolor='green', facecolor='none', alpha=0.5) + ax.add_patch(rect) + + # Draw arrows from one point to another to indicate direction + for i in range(len(coordinates) - 1): + x1, y1 = coordinates[i] + x2, y2 = coordinates[i + 1] + ax.annotate("", xy=(x2, y2), xytext=(x1, y1), + arrowprops=dict(arrowstyle="->", + linestyle="-", + lw=1, + color='orange', + mutation_scale=10)) + matplotlib.pyplot.rcParams['text.color'] = '#999999' + fig.patch.set_facecolor('#353535') + ax.set_facecolor('#353535') + ax.grid(color='#999999', linestyle='-', linewidth=0.5) + ax.set_xlabel('x', color='#999999') + ax.set_ylabel('y', color='#999999') + for text in ax.get_xticklabels() + ax.get_yticklabels(): + text.set_color('#999999') + ax.set_title('Gligen positions') + ax.set_xlabel('X Coordinate') + ax.set_ylabel('Y Coordinate') + ax.legend().remove() + ax.set_xlim(0, width) # Set the x-axis to match the input latent width + ax.set_ylim(height, 0) # Set the y-axis to match the input latent height, with (0,0) at top-left + # Adjust the margins of the subplot + matplotlib.pyplot.subplots_adjust(left=0.08, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.2) + canvas = FigureCanvas(fig) + canvas.draw() + matplotlib.pyplot.close(fig) + + width, height = fig.get_size_inches() * fig.get_dpi() + + image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3) + image_tensor = torch.from_numpy(image_np).float() / 255.0 + image_tensor = image_tensor.unsqueeze(0) + + return image_tensor \ No newline at end of file diff --git a/web/js/spline_editor.js b/web/js/spline_editor.js index 27a5e46..5aa9a0e 100644 --- a/web/js/spline_editor.js +++ b/web/js/spline_editor.js @@ -181,7 +181,7 @@ app.registerExtension({ } }); - this.setSize([550, 900]); + this.setSize([550, 920]); this.resizable = false; this.splineEditor.parentEl = document.createElement("div"); this.splineEditor.parentEl.className = "spline-editor"; @@ -190,7 +190,7 @@ app.registerExtension({ chainCallback(this, "onGraphConfigured", function() { createSplineEditor(this); - this.setSize([550, 900]); + this.setSize([550, 920]); }); }); // onAfterGraphConfigured