Update mask_nodes.py

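Reduced the total time for creating masks for a batch of images by using the batch handling built into the processor: the whole batch is passed to the processor in a single call instead of one call per image:
processor(images=images)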
This commit is contained in:
Mokan Alexander 2024-05-09 19:45:13 +03:00 committed by GitHub
parent 148c805a15
commit 3652e8eee2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -48,7 +48,6 @@ Segments an image or batch of images using CLIPSeg.
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda): def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda):
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
out = []
height, width, _ = images[0].shape height, width, _ = images[0].shape
if use_cuda and torch.cuda.is_available(): if use_cuda and torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
@ -60,35 +59,30 @@ Segments an image or batch of images using CLIPSeg.
model.to(device) model.to(device)
images = images.to(device) images = images.to(device)
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
pbar = ProgressBar(images.shape[0])
autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device) autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext(): with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
for image in images:
image = (image* 255).type(torch.uint8) images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
prompt = text prompt = [text] * len(images)
input_prc = processor(text=prompt, images=image, return_tensors="pt") input_prc = processor(text=prompt, images=images, return_tensors="pt")
# Move the processed input to the device # Move the processed input to the device
for key in input_prc: for key in input_prc:
input_prc[key] = input_prc[key].to(device) input_prc[key] = input_prc[key].to(device)
outputs = model(**input_prc) outputs = model(**input_prc)
tensor = torch.sigmoid(outputs[0]) tensor = torch.sigmoid(outputs.logits)
tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float)) tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min()) tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
tensor = tensor_normalized tensor = tensor_normalized
# Resize the mask # Resize the mask
if len(tensor.shape) == 3: resized_tensor = F.interpolate(tensor.unsqueeze(1), size=(height, width), mode='nearest')
tensor = tensor.unsqueeze(0)
resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest')
# Remove the extra dimensions # Remove the extra dimensions
resized_tensor = resized_tensor[0, 0, :, :] resized_tensor = resized_tensor.squeeze(1)
pbar.update(1)
out.append(resized_tensor)
results = torch.stack(out).cpu().float() results = resized_tensor.cpu().float()
if combine_mask: if combine_mask:
combined_results = torch.max(results, dim=0)[0] combined_results = torch.max(results, dim=0)[0]