Add BatchCLIPSeg

supports batches and uses cuda
This commit is contained in:
kijai 2023-11-09 20:22:32 +02:00
parent 9af6f33160
commit 74b54f41a5

View File

@ -1272,7 +1272,7 @@ class ImageBatchTestPattern:
#based on nodes from mtb https://github.com/melMass/comfy_mtb
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
from torchvision.transforms import Resize
from torchvision.transforms import Resize, CenterCrop
class BatchCropFromMask:
@ -1381,10 +1381,15 @@ class BatchCropFromMask:
cropped_img = img[min_y:max_y, min_x:max_x, :]
# Resize the cropped image to a fixed size
resize_transform = Resize((self.max_bbox_size, self.max_bbox_size))
resized_img = resize_transform(cropped_img.permute(2, 0, 1)).permute(1, 2, 0)
cropped_images.append(resized_img)
new_size = max(cropped_img.shape[0], cropped_img.shape[1])
resize_transform = Resize(new_size)
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
# Perform the center crop to the desired size
crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
cropped_resized_img = crop_transform(resized_img)
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
cropped_out = torch.stack(cropped_images, dim=0)
@ -1491,6 +1496,75 @@ class BatchUncrop:
return (pil2tensor(out_images),)
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
class BatchCLIPSeg:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"images": ("IMAGE",),
"text": ("STRING", {"multiline": False}),
"threshold": ("FLOAT", {"default": 0.4,"min": 0.0, "max": 10.0, "step": 0.01}),
"binary_mask": ("BOOLEAN", {"default": True}),
},
}
CATEGORY = "KJNodes/masking"
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("Mask",)
FUNCTION = "segment_image"
def segment_image(self, images, text, threshold, binary_mask):
out = []
height, width, _ = images[0].shape
print(height)
print(width)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
model.to(device) # Ensure the model is on the correct device
images = images.to(device)
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
for image in images:
image = (image* 255).type(torch.uint8)
prompt = text
input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
# Move the processed input to the device
for key in input_prc:
input_prc[key] = input_prc[key].to(device)
outputs = model(**input_prc)
tensor = torch.sigmoid(outputs[0])
tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
tensor = tensor_normalized
# Add extra dimensions to the mask for batch and channel
tensor = tensor[None, None, :, :]
# Resize the mask
resized_tensor = F.interpolate(tensor, size=(height, width), mode='bilinear', align_corners=False)
# Remove the extra dimensions
resized_tensor = resized_tensor[0, 0, :, :]
out.append(resized_tensor)
results = torch.stack(out).cpu()
if binary_mask:
results = results.round()
return results,
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
@ -1522,6 +1596,7 @@ NODE_CLASS_MAPPINGS = {
"ReplaceImagesInBatch": ReplaceImagesInBatch,
"BatchCropFromMask": BatchCropFromMask,
"BatchUncrop": BatchUncrop,
"BatchCLIPSeg": BatchCLIPSeg,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -1552,4 +1627,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
"BatchCropFromMask": "BatchCropFromMask",
"BatchUncrop": "BatchUncrop",
"BatchCLIPSeg": "BatchCLIPSeg",
}