mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-10 11:04:22 +08:00
Batch clip seg improvements
This commit is contained in:
parent
0ef2b86b28
commit
4812eff6e5
@ -36,6 +36,10 @@ class BatchCLIPSeg:
|
||||
"combine_mask": ("BOOLEAN", {"default": False}),
|
||||
"use_cuda": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional":
|
||||
{
|
||||
"blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "KJNodes/masking"
|
||||
@ -46,55 +50,62 @@ class BatchCLIPSeg:
|
||||
Segments an image or batch of images using CLIPSeg.
|
||||
"""
|
||||
|
||||
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda):
|
||||
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
|
||||
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||
height, width, _ = images[0].shape
|
||||
import torchvision.transforms as transforms
|
||||
offload_device = model_management.unet_offload_device()
|
||||
if use_cuda and torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
device = model_management.get_torch_device()
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = model_management.unet_dtype()
|
||||
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
model.to(dtype)
|
||||
model.to(device)
|
||||
images = images.to(device)
|
||||
if not hasattr(self, "model"):
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
|
||||
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
|
||||
self.model.to(dtype).to(device)
|
||||
|
||||
B, H, W, C = images.shape
|
||||
images = images.to(device)
|
||||
|
||||
autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
|
||||
with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
|
||||
|
||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
|
||||
prompt = [text] * len(images)
|
||||
input_prc = processor(text=prompt, images=images, return_tensors="pt")
|
||||
# Move the processed input to the device
|
||||
|
||||
for key in input_prc:
|
||||
input_prc[key] = input_prc[key].to(device)
|
||||
outputs = self.model(**input_prc)
|
||||
|
||||
outputs = model(**input_prc)
|
||||
tensor = torch.sigmoid(outputs.logits)
|
||||
tensor = torch.where(tensor > (threshold / 10), tensor, torch.tensor(0, dtype=torch.float))
|
||||
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
||||
|
||||
tensor = torch.sigmoid(outputs.logits)
|
||||
tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
|
||||
tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
|
||||
tensor = tensor_normalized
|
||||
tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
|
||||
tensor = tensor.squeeze(1)
|
||||
|
||||
# Resize the mask
|
||||
resized_tensor = F.interpolate(tensor.unsqueeze(1), size=(height, width), mode='nearest')
|
||||
|
||||
# Remove the extra dimensions
|
||||
resized_tensor = resized_tensor.squeeze(1)
|
||||
|
||||
results = resized_tensor.cpu().float()
|
||||
self.model.to(offload_device)
|
||||
results = tensor.cpu().float()
|
||||
print(results.min(), results.max())
|
||||
|
||||
if binary_mask:
|
||||
tensor = (tensor > 0).float()
|
||||
if blur_sigma > 0:
|
||||
kernel_size = int(6 * blur_sigma + 1)
|
||||
blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
|
||||
tensor = blur(tensor)
|
||||
|
||||
if combine_mask:
|
||||
combined_results = torch.max(results, dim=0)[0]
|
||||
results = combined_results.unsqueeze(0).repeat(len(images),1,1)
|
||||
tensor = torch.max(tensor, dim=0)[0]
|
||||
tensor = tensor.unsqueeze(0).repeat(len(images),1,1)
|
||||
|
||||
if binary_mask:
|
||||
results = results.round()
|
||||
del outputs
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
del outputs, tensor, tensor_thresholded, tensor_normalized, resized_tensor, images
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results,
|
||||
return tensor,
|
||||
|
||||
class CreateTextMask:
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user