class GLIGENTextBoxApplyBatch:
    """Attach one GLIGEN text box per batch item to a conditioning.

    The box starts at (``x``, ``y``) for the first batch item and is shifted
    by (``x_increment``, ``y_increment``) for every following item, so the
    box can move across the batch.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input sockets and widgets exposed by this node."""
        return {"required": {"conditioning_to": ("CONDITIONING", ),
                             "latents": ("LATENT", ),
                             "clip": ("CLIP", ),
                             "gligen_textbox_model": ("GLIGEN", ),
                             "text": ("STRING", {"multiline": True}),
                             "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
                             "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             "x_increment": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}),
                             "y_increment": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "conditioning/gligen"

    def append(self, latents, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y, x_increment, y_increment):
        """Add a per-batch-item GLIGEN box to every conditioning entry.

        Args:
            latents: LATENT dict; only the batch dimension of its
                ``"samples"`` entry is used.
            conditioning_to: Iterable of ``[cond, options_dict]`` pairs.
            clip: Object providing ``tokenize`` and ``encode_from_tokens``.
            gligen_textbox_model: Stored unchanged in the ``gligen`` entry.
            text: Prompt encoded for the box.
            width, height: Box size; integer-divided by 8 before storing.
            x, y: Box position for the first batch item.
            x_increment, y_increment: Offset added per batch index.

        Returns:
            A one-element tuple holding the new conditioning list. Each
            options dict is a shallow copy of the input one with
            ``"gligen"`` set to ``("position", model, params)``, where
            ``params`` holds one list of position tuples per batch item.
        """
        # Read the batch size from the latent tensor only. Summing over every
        # value of the dict over-counts when extra tensors (e.g. a noise mask)
        # are present and fails on non-tensor entries such as index lists.
        batch_size = latents["samples"].shape[0]

        _, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)

        # The box parameters do not depend on the conditioning entry, so
        # build them once. Sizes and positions are divided by 8 (presumably
        # pixel -> latent units; confirm against the GLIGEN consumer).
        box_height = height // 8
        box_width = width // 8
        position_params_batch = [
            [(cond_pooled, box_height, box_width, (y + i * y_increment) // 8, (x + i * x_increment) // 8)]
            for i in range(batch_size)
        ]

        c = []
        for t in conditioning_to:
            # Shallow-copy the options so the caller's dict is not mutated.
            n = [t[0], t[1].copy()]

            if "gligen" in n[1]:
                # NOTE(review): existing params must already be a list of
                # lists (one per batch item); zip() below silently truncates
                # to the shorter of the two sequences.
                prev = n[1]['gligen'][2]
            else:
                prev = [[] for _ in range(batch_size)]

            # Concatenation creates fresh lists, so entries never share state.
            combined_position_params = [prev_item + batch_item for prev_item, batch_item in zip(prev, position_params_batch)]
            n[1]['gligen'] = ("position", gligen_textbox_model, combined_position_params)
            c.append(n)

        return (c, )