diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 8b1d9f0..4019dd0 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -46,9 +46,8 @@ class BatchCLIPSeg: Segments an image or batch of images using CLIPSeg. """ - def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda): + def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda): from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation - out = [] height, width, _ = images[0].shape if use_cuda and torch.cuda.is_available(): device = torch.device("cuda") @@ -60,43 +59,38 @@ Segments an image or batch of images using CLIPSeg. model.to(device) images = images.to(device) processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") - pbar = ProgressBar(images.shape[0]) autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device) with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext(): - for image in images: - image = (image* 255).type(torch.uint8) - prompt = text - input_prc = processor(text=prompt, images=image, return_tensors="pt") - # Move the processed input to the device - for key in input_prc: - input_prc[key] = input_prc[key].to(device) - - outputs = model(**input_prc) - - tensor = torch.sigmoid(outputs[0]) - tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float)) - tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min()) - tensor = tensor_normalized - # Resize the mask - if len(tensor.shape) == 3: - tensor = tensor.unsqueeze(0) - resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest') + images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ] + prompt = [text] * len(images) + input_prc = processor(text=prompt, images=images, return_tensors="pt") + # Move the processed input to the device + for key in input_prc: + input_prc[key] = input_prc[key].to(device) + + outputs = model(**input_prc) + + tensor = torch.sigmoid(outputs.logits) + tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float)) + tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min()) + tensor = tensor_normalized + + # Resize the mask + resized_tensor = F.interpolate(tensor.unsqueeze(1), size=(height, width), mode='nearest') + + # Remove the extra dimensions + resized_tensor = resized_tensor.squeeze(1) + + results = resized_tensor.cpu().float() - # Remove the extra dimensions - resized_tensor = resized_tensor[0, 0, :, :] - pbar.update(1) - out.append(resized_tensor) - - results = torch.stack(out).cpu().float() - if combine_mask: combined_results = torch.max(results, dim=0)[0] results = combined_results.unsqueeze(0).repeat(len(images),1,1) if binary_mask: results = results.round() - + return results, class CreateTextMask: @@ -1163,4 +1157,4 @@ Sets new min and max values for the mask. # Clamp the values to ensure they are within [0.0, 1.0] scaled_mask = torch.clamp(scaled_mask, min=0.0, max=1.0) - return (scaled_mask, ) \ No newline at end of file + return (scaled_mask, )