Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
Synced 2025-12-15 15:54:38 +08:00
Update mask_nodes.py
Reduced the total time for creating masks in a batch by using the batch handling built into the processor: processor(images=images) now receives the whole image list in a single call instead of one image per loop iteration.
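For reference, a minimal standalone sketch of the batched call this commit switches to is shown below. The CIDAS/clipseg-rd64-refined checkpoint matches the one used in the node; the dummy PIL images, the "a cat" prompt, and running without autocast are illustrative only and not part of the commit.

    import torch
    from PIL import Image
    from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

    # Illustrative inputs: a list of PIL images and the same prompt repeated per image.
    pil_images = [Image.new("RGB", (512, 512)) for _ in range(4)]
    prompts = ["a cat"] * len(pil_images)

    # One processor call prepares the whole batch (input_ids, attention_mask, pixel_values),
    # replacing one processor call per image.
    inputs = processor(text=prompts, images=pil_images, return_tensors="pt")

    # One forward pass produces one low-resolution logit map per image.
    with torch.no_grad():
        outputs = model(**inputs)

    masks = torch.sigmoid(outputs.logits)  # roughly (batch, H, W); the node then thresholds, normalizes and resizes
    print(masks.shape)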
This commit is contained in:
parent 148c805a15
commit 3652e8eee2
@@ -46,9 +46,8 @@ class BatchCLIPSeg:
 Segments an image or batch of images using CLIPSeg.
 """
 
     def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda):
         from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
-        out = []
         height, width, _ = images[0].shape
         if use_cuda and torch.cuda.is_available():
             device = torch.device("cuda")
@@ -60,43 +59,38 @@ Segments an image or batch of images using CLIPSeg.
         model.to(device)
         images = images.to(device)
         processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
-        pbar = ProgressBar(images.shape[0])
         autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
         with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
-            for image in images:
-                image = (image* 255).type(torch.uint8)
-                prompt = text
-                input_prc = processor(text=prompt, images=image, return_tensors="pt")
-                # Move the processed input to the device
-                for key in input_prc:
-                    input_prc[key] = input_prc[key].to(device)
-
-                outputs = model(**input_prc)
-
-                tensor = torch.sigmoid(outputs[0])
-                tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
-                tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
-                tensor = tensor_normalized
-
-                # Resize the mask
-                if len(tensor.shape) == 3:
-                    tensor = tensor.unsqueeze(0)
-                resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest')
-
-                # Remove the extra dimensions
-                resized_tensor = resized_tensor[0, 0, :, :]
-                pbar.update(1)
-                out.append(resized_tensor)
-
-        results = torch.stack(out).cpu().float()
+            images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
+            prompt = [text] * len(images)
+            input_prc = processor(text=prompt, images=images, return_tensors="pt")
+            # Move the processed input to the device
+            for key in input_prc:
+                input_prc[key] = input_prc[key].to(device)
+
+            outputs = model(**input_prc)
+
+            tensor = torch.sigmoid(outputs.logits)
+            tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
+            tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
+            tensor = tensor_normalized
+
+            # Resize the mask
+            resized_tensor = F.interpolate(tensor.unsqueeze(1), size=(height, width), mode='nearest')
+
+            # Remove the extra dimensions
+            resized_tensor = resized_tensor.squeeze(1)
+
+            results = resized_tensor.cpu().float()
+
 
         if combine_mask:
             combined_results = torch.max(results, dim=0)[0]
             results = combined_results.unsqueeze(0).repeat(len(images),1,1)
 
         if binary_mask:
             results = results.round()
 
         return results,
 
 class CreateTextMask:
@@ -1163,4 +1157,4 @@ Sets new min and max values for the mask.
         # Clamp the values to ensure they are within [0.0, 1.0]
         scaled_mask = torch.clamp(scaled_mask, min=0.0, max=1.0)
 
         return (scaled_mask, )
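For context, a direct call to the updated BatchCLIPSeg node outside of a ComfyUI graph could look roughly like the sketch below. The import path, the dummy image tensor, and the parameter values are illustrative only; the class is defined in mask_nodes.py, expects ComfyUI's IMAGE layout (batch, height, width, channel) with float values in 0..1, and still relies on ComfyUI runtime utilities such as model_management being importable.

    import torch
    from mask_nodes import BatchCLIPSeg  # illustrative import; the class lives in mask_nodes.py

    # Dummy batch of four 512x512 RGB images in ComfyUI's IMAGE layout: (B, H, W, C), float 0..1.
    images = torch.rand(4, 512, 512, 3)

    node = BatchCLIPSeg()
    masks, = node.segment_image(
        images=images,
        text="a cat",
        threshold=0.5,
        binary_mask=False,
        combine_mask=False,
        use_cuda=True,
    )
    print(masks.shape)  # expected: one mask per input image, resized back to (4, 512, 512)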