mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-28 07:37:05 +08:00
Add BatchCLIPSeg
supports batches and uses cuda
This commit is contained in:
parent
9af6f33160
commit
74b54f41a5
84
nodes.py
84
nodes.py
@ -1272,7 +1272,7 @@ class ImageBatchTestPattern:
|
|||||||
#based on nodes from mtb https://github.com/melMass/comfy_mtb
|
#based on nodes from mtb https://github.com/melMass/comfy_mtb
|
||||||
|
|
||||||
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
|
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
|
||||||
from torchvision.transforms import Resize
|
from torchvision.transforms import Resize, CenterCrop
|
||||||
|
|
||||||
class BatchCropFromMask:
|
class BatchCropFromMask:
|
||||||
|
|
||||||
@ -1381,10 +1381,15 @@ class BatchCropFromMask:
|
|||||||
cropped_img = img[min_y:max_y, min_x:max_x, :]
|
cropped_img = img[min_y:max_y, min_x:max_x, :]
|
||||||
|
|
||||||
# Resize the cropped image to a fixed size
|
# Resize the cropped image to a fixed size
|
||||||
resize_transform = Resize((self.max_bbox_size, self.max_bbox_size))
|
new_size = max(cropped_img.shape[0], cropped_img.shape[1])
|
||||||
resized_img = resize_transform(cropped_img.permute(2, 0, 1)).permute(1, 2, 0)
|
resize_transform = Resize(new_size)
|
||||||
|
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
|
||||||
|
|
||||||
cropped_images.append(resized_img)
|
# Perform the center crop to the desired size
|
||||||
|
crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
|
||||||
|
cropped_resized_img = crop_transform(resized_img)
|
||||||
|
|
||||||
|
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
|
||||||
|
|
||||||
cropped_out = torch.stack(cropped_images, dim=0)
|
cropped_out = torch.stack(cropped_images, dim=0)
|
||||||
|
|
||||||
@ -1491,6 +1496,75 @@ class BatchUncrop:
|
|||||||
|
|
||||||
return (pil2tensor(out_images),)
|
return (pil2tensor(out_images),)
|
||||||
|
|
||||||
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||||
|
|
||||||
|
class BatchCLIPSeg:
    """Generate one text-prompted CLIPSeg mask per image of a batch.

    Every image of the incoming batch is segmented with the
    ``CIDAS/clipseg-rd64-refined`` checkpoint using the same text prompt, and
    the resulting masks are resized to the resolution of the input images.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "images": ("IMAGE",),
                        "text": ("STRING", {"multiline": False}),
                        "threshold": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 10.0, "step": 0.01}),
                        "binary_mask": ("BOOLEAN", {"default": True}),
                    },
                }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("Mask",)

    FUNCTION = "segment_image"

    @staticmethod
    def _logits_to_mask(logits, threshold, height, width):
        """Turn the raw logits of ONE image into a (height, width) mask.

        Args:
            logits: Model output for a single image. Only the last two
                dimensions are treated as spatial; any leading singleton
                dimensions are dropped.
            threshold: Probabilities at or below this value are zeroed.
            height: Output mask height.
            width: Output mask width.

        Returns:
            A 2-D float tensor of shape (height, width).
        """
        probs = torch.sigmoid(logits)

        # zeros_like keeps the fill value on the same device and dtype as the
        # probabilities, instead of mixing in a CPU scalar tensor.
        kept = torch.where(probs > threshold, probs, torch.zeros_like(probs))

        # Stretch the surviving values to [0, 1]. When the map is constant
        # (typically: nothing passed the threshold) the span is 0 and dividing
        # would fill the mask with NaN, so the values are left untouched.
        low = kept.min()
        span = kept.max() - low
        if span > 0:
            kept = (kept - low) / span

        # interpolate() needs (batch, channel, H, W). reshape() copes with the
        # logits arriving either as (H, W) or with a leading batch dimension
        # of 1 - NOTE(review): which one is returned may depend on the
        # installed transformers version.
        kept = kept.reshape(1, 1, *kept.shape[-2:])
        resized = F.interpolate(kept, size=(height, width), mode='bilinear', align_corners=False)
        return resized[0, 0]

    def segment_image(self, images, text, threshold, binary_mask):
        """Segment every image of ``images`` using the prompt ``text``.

        Args:
            images: 4-D image batch; dimensions 1 and 2 are used as the
                output height and width.
            text: Prompt describing what to segment.
            threshold: Probabilities at or below this value are zeroed before
                the mask is normalised.
            binary_mask: When true, the final masks are rounded to 0/1.

        Returns:
            A one-element tuple holding the stacked masks as a CPU tensor of
            shape (batch, height, width).
        """
        height, width = images.shape[1], images.shape[2]

        # Nothing to segment: skip loading the model and return an empty
        # batch instead of failing on images[0] / torch.stack([]).
        if images.shape[0] == 0:
            return (torch.zeros((0, height, width)),)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
        model.to(device)  # Ensure the model is on the correct device
        processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")

        out = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for image in images:
                # The processor works on host data, so the image is handed
                # over as a CPU uint8 tensor; only its output goes to the GPU.
                pixels = (image * 255).clamp(0, 255).to(torch.uint8).cpu()
                inputs = processor(text=text, images=pixels, padding="max_length", return_tensors="pt")
                inputs = {key: value.to(device) for key, value in inputs.items()}

                outputs = model(**inputs)
                out.append(self._logits_to_mask(outputs[0], threshold, height, width))

        results = torch.stack(out).cpu()
        if binary_mask:
            results = results.round()

        return (results,)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
@ -1522,6 +1596,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ReplaceImagesInBatch": ReplaceImagesInBatch,
|
"ReplaceImagesInBatch": ReplaceImagesInBatch,
|
||||||
"BatchCropFromMask": BatchCropFromMask,
|
"BatchCropFromMask": BatchCropFromMask,
|
||||||
"BatchUncrop": BatchUncrop,
|
"BatchUncrop": BatchUncrop,
|
||||||
|
"BatchCLIPSeg": BatchCLIPSeg,
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"INTConstant": "INT Constant",
|
"INTConstant": "INT Constant",
|
||||||
@ -1552,4 +1627,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
|
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
|
||||||
"BatchCropFromMask": "BatchCropFromMask",
|
"BatchCropFromMask": "BatchCropFromMask",
|
||||||
"BatchUncrop": "BatchUncrop",
|
"BatchUncrop": "BatchUncrop",
|
||||||
|
"BatchCLIPSeg": "BatchCLIPSeg",
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user