mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-01 19:15:17 +08:00
Add BatchCLIPSeg
supports batches and uses cuda
This commit is contained in:
parent
9af6f33160
commit
74b54f41a5
86
nodes.py
86
nodes.py
@ -1272,7 +1272,7 @@ class ImageBatchTestPattern:
|
||||
#based on nodes from mtb https://github.com/melMass/comfy_mtb
|
||||
|
||||
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
|
||||
from torchvision.transforms import Resize
|
||||
from torchvision.transforms import Resize, CenterCrop
|
||||
|
||||
class BatchCropFromMask:
|
||||
|
||||
@ -1381,10 +1381,15 @@ class BatchCropFromMask:
|
||||
cropped_img = img[min_y:max_y, min_x:max_x, :]
|
||||
|
||||
# Resize the cropped image to a fixed size
|
||||
resize_transform = Resize((self.max_bbox_size, self.max_bbox_size))
|
||||
resized_img = resize_transform(cropped_img.permute(2, 0, 1)).permute(1, 2, 0)
|
||||
|
||||
cropped_images.append(resized_img)
|
||||
new_size = max(cropped_img.shape[0], cropped_img.shape[1])
|
||||
resize_transform = Resize(new_size)
|
||||
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
|
||||
|
||||
# Perform the center crop to the desired size
|
||||
crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
|
||||
cropped_resized_img = crop_transform(resized_img)
|
||||
|
||||
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
|
||||
|
||||
cropped_out = torch.stack(cropped_images, dim=0)
|
||||
|
||||
@ -1491,6 +1496,75 @@ class BatchUncrop:
|
||||
|
||||
return (pil2tensor(out_images),)
|
||||
|
||||
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||
|
||||
class BatchCLIPSeg:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
||||
return {"required":
|
||||
{
|
||||
"images": ("IMAGE",),
|
||||
"text": ("STRING", {"multiline": False}),
|
||||
"threshold": ("FLOAT", {"default": 0.4,"min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"binary_mask": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
CATEGORY = "KJNodes/masking"
|
||||
RETURN_TYPES = ("MASK",)
|
||||
RETURN_NAMES = ("Mask",)
|
||||
|
||||
FUNCTION = "segment_image"
|
||||
|
||||
def segment_image(self, images, text, threshold, binary_mask):
|
||||
|
||||
out = []
|
||||
height, width, _ = images[0].shape
|
||||
print(height)
|
||||
print(width)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
model.to(device) # Ensure the model is on the correct device
|
||||
images = images.to(device)
|
||||
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
|
||||
for image in images:
|
||||
image = (image* 255).type(torch.uint8)
|
||||
prompt = text
|
||||
input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
|
||||
# Move the processed input to the device
|
||||
for key in input_prc:
|
||||
input_prc[key] = input_prc[key].to(device)
|
||||
|
||||
outputs = model(**input_prc)
|
||||
tensor = torch.sigmoid(outputs[0])
|
||||
|
||||
tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
|
||||
tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
|
||||
|
||||
tensor = tensor_normalized
|
||||
|
||||
# Add extra dimensions to the mask for batch and channel
|
||||
tensor = tensor[None, None, :, :]
|
||||
|
||||
# Resize the mask
|
||||
resized_tensor = F.interpolate(tensor, size=(height, width), mode='bilinear', align_corners=False)
|
||||
|
||||
# Remove the extra dimensions
|
||||
resized_tensor = resized_tensor[0, 0, :, :]
|
||||
|
||||
out.append(resized_tensor)
|
||||
|
||||
results = torch.stack(out).cpu()
|
||||
if binary_mask:
|
||||
results = results.round()
|
||||
|
||||
return results,
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"INTConstant": INTConstant,
|
||||
@ -1522,6 +1596,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ReplaceImagesInBatch": ReplaceImagesInBatch,
|
||||
"BatchCropFromMask": BatchCropFromMask,
|
||||
"BatchUncrop": BatchUncrop,
|
||||
"BatchCLIPSeg": BatchCLIPSeg,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"INTConstant": "INT Constant",
|
||||
@ -1552,4 +1627,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
|
||||
"BatchCropFromMask": "BatchCropFromMask",
|
||||
"BatchUncrop": "BatchUncrop",
|
||||
"BatchCLIPSeg": "BatchCLIPSeg",
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user