diff --git a/nodes.py b/nodes.py
index 50a68fd..db6af0b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1272,7 +1272,7 @@ class ImageBatchTestPattern:
 
 #based on nodes from mtb https://github.com/melMass/comfy_mtb
 from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
-from torchvision.transforms import Resize
+from torchvision.transforms import Resize, CenterCrop
 
 class BatchCropFromMask:
 
@@ -1381,10 +1381,15 @@ class BatchCropFromMask:
             cropped_img = img[min_y:max_y, min_x:max_x, :]
 
             # Resize the cropped image to a fixed size
-            resize_transform = Resize((self.max_bbox_size, self.max_bbox_size))
-            resized_img = resize_transform(cropped_img.permute(2, 0, 1)).permute(1, 2, 0)
-
-            cropped_images.append(resized_img)
+            new_size = max(cropped_img.shape[0], cropped_img.shape[1])
+            resize_transform = Resize(new_size)
+            resized_img = resize_transform(cropped_img.permute(2, 0, 1))
+
+            # Perform the center crop to the desired size
+            crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
+            cropped_resized_img = crop_transform(resized_img)
+
+            cropped_images.append(cropped_resized_img.permute(1, 2, 0))
 
         cropped_out = torch.stack(cropped_images, dim=0)
 
@@ -1491,6 +1496,75 @@ class BatchUncrop:
 
         return (pil2tensor(out_images),)
 
+from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+
+class BatchCLIPSeg:
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+
+        return {"required":
+                {
+                    "images": ("IMAGE",),
+                    "text": ("STRING", {"multiline": False}),
+                    "threshold": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 10.0, "step": 0.01}),
+                    "binary_mask": ("BOOLEAN", {"default": True}),
+                },
+            }
+
+    CATEGORY = "KJNodes/masking"
+    RETURN_TYPES = ("MASK",)
+    RETURN_NAMES = ("Mask",)
+
+    FUNCTION = "segment_image"
+
+    def segment_image(self, images, text, threshold, binary_mask):
+
+        out = []
+        height, width, _ = images[0].shape
+        print(height)
+        print(width)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+        model.to(device)  # Ensure the model is on the correct device
+        images = images.to(device)
+        processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+        for image in images:
+            image = (image * 255).type(torch.uint8)
+            prompt = text
+            input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
+            # Move the processed input to the device
+            for key in input_prc:
+                input_prc[key] = input_prc[key].to(device)
+
+            outputs = model(**input_prc)
+            tensor = torch.sigmoid(outputs[0])
+
+            tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
+            tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
+
+            tensor = tensor_normalized
+
+            # Add extra dimensions to the mask for batch and channel
+            tensor = tensor[None, None, :, :]
+
+            # Resize the mask
+            resized_tensor = F.interpolate(tensor, size=(height, width), mode='bilinear', align_corners=False)
+
+            # Remove the extra dimensions
+            resized_tensor = resized_tensor[0, 0, :, :]
+
+            out.append(resized_tensor)
+
+        results = torch.stack(out).cpu()
+        if binary_mask:
+            results = results.round()
+
+        return results,
 
 NODE_CLASS_MAPPINGS = {
     "INTConstant": INTConstant,
@@ -1522,6 +1596,7 @@ NODE_CLASS_MAPPINGS = {
     "ReplaceImagesInBatch": ReplaceImagesInBatch,
     "BatchCropFromMask": BatchCropFromMask,
     "BatchUncrop": BatchUncrop,
+    "BatchCLIPSeg": BatchCLIPSeg,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
@@ -1552,4 +1627,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ReplaceImagesInBatch": "ReplaceImagesInBatch",
     "BatchCropFromMask": "BatchCropFromMask",
     "BatchUncrop": "BatchUncrop",
+    "BatchCLIPSeg": "BatchCLIPSeg",
 }
\ No newline at end of file
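
A note on the `Resize`/`CenterCrop` change, with a minimal sketch below (the 300x200 crop and the `max_bbox_size` of 400 are made-up values): passing an int to `Resize` matches the *smaller* edge to that length and preserves the aspect ratio, which avoids the distortion of the old `Resize((s, s))` call, and `CenterCrop` zero-pads any edge shorter than the target before cropping.

```python
# Minimal sketch of the new resize-then-center-crop path; shapes are assumed.
import torch
from torchvision.transforms import Resize, CenterCrop

max_bbox_size = 400                    # assumed largest bbox edge in the batch
cropped_img = torch.rand(300, 200, 3)  # H x W x C, as sliced from the source image

# An int argument makes Resize match the *smaller* edge, keeping aspect ratio
new_size = max(cropped_img.shape[0], cropped_img.shape[1])      # 300
resized = Resize(new_size)(cropped_img.permute(2, 0, 1))        # C x 450 x 300

# CenterCrop zero-pads edges shorter than the target before cropping
padded = CenterCrop((max_bbox_size, max_bbox_size))(resized)    # C x 400 x 400
print(padded.permute(1, 2, 0).shape)   # torch.Size([400, 400, 3]), back to H x W x C
```

One consequence worth noting: crops whose larger edge is still below `max_bbox_size` come back with black borders from the zero padding rather than being stretched to fill the square.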
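For reference, the CLIPSeg inference path inside `segment_image` can be exercised standalone roughly as below. The model id, the `padding="max_length"` call, the thresholding, and the `F.interpolate` upscale mirror the node; the input path is hypothetical, and the `1e-8` epsilon is an extra guard against a zero denominator when the thresholded map is constant (the node divides without it).

```python
# Standalone sketch of the BatchCLIPSeg inference path, assuming a PIL input image.
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

image = Image.open("example.png").convert("RGB")  # hypothetical input
inputs = processor(text="a cat", images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():                  # inference only, no gradients needed
    logits = model(**inputs).logits    # low-resolution (352x352) heatmap for one prompt

mask = torch.sigmoid(logits).squeeze()                          # drop any batch dim
mask = torch.where(mask > 0.4, mask, torch.zeros_like(mask))    # threshold, as in the node
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)   # normalize to [0, 1]

# Upscale the mask back to the source resolution, mirroring the node's F.interpolate call;
# PIL's image.size is (W, H), so reverse it to get the (H, W) that interpolate expects
mask = F.interpolate(mask[None, None], size=image.size[::-1],
                     mode="bilinear", align_corners=False)[0, 0]
```

Using `torch.zeros_like` instead of the node's bare `torch.tensor(0, dtype=torch.float)` also keeps the fill value on the same device and dtype as the logits, sidestepping device-promotion concerns once the model runs on CUDA.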