diff --git a/nodes.py b/nodes.py
index 33c540e..0a69b96 100644
--- a/nodes.py
+++ b/nodes.py
@@ -7,6 +7,7 @@ import scipy.ndimage
 import matplotlib.pyplot as plt
 import numpy as np
 from PIL import ImageFilter, Image, ImageDraw, ImageFont
+from contextlib import nullcontext
 import json
 import re
 
@@ -2378,7 +2379,7 @@ class BatchCLIPSeg:
             {
                 "images": ("IMAGE",),
                 "text": ("STRING", {"multiline": False}),
-                "threshold": ("FLOAT", {"default": 0.15,"min": 0.0, "max": 10.0, "step": 0.01}),
+                "threshold": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.001}),
                 "binary_mask": ("BOOLEAN", {"default": True}),
                 "combine_mask": ("BOOLEAN", {"default": False}),
                 "use_cuda": ("BOOLEAN", {"default": True}),
@@ -2401,36 +2402,39 @@ Segments an image or batch of images using CLIPSeg.
             device = torch.device("cuda")
         else:
             device = torch.device("cpu")
+        dtype = comfy.model_management.unet_dtype()
         model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-        model.to(device) # Ensure the model is on the correct device
+        model.to(dtype)
+        model.to(device)
         images = images.to(device)
         processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
         pbar = comfy.utils.ProgressBar(images.shape[0])
-        for image in images:
-            image = (image* 255).type(torch.uint8)
-            prompt = text
-            input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
-            # Move the processed input to the device
-            for key in input_prc:
-                input_prc[key] = input_prc[key].to(device)
-
-            outputs = model(**input_prc)
-            tensor = torch.sigmoid(outputs[0])
-
-            tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
-            tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
+        autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
+        with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
+            for image in images:
+                image = (image* 255).type(torch.uint8)
+                prompt = text
+                input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
+                # Move the processed input to the device
+                for key in input_prc:
+                    input_prc[key] = input_prc[key].to(device)
+
+                outputs = model(**input_prc)
+
+                tensor = torch.sigmoid(outputs[0])
+                tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
+                tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
+                tensor = tensor_normalized.unsqueeze(0).unsqueeze(0)
 
-            tensor = tensor_normalized
+                # Resize the mask
+                resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest')
 
-            # Resize the mask
-            resized_tensor = F.interpolate(tensor.unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False)
-
-            # Remove the extra dimensions
-            resized_tensor = resized_tensor[0, 0, :, :]
-            pbar.update(1)
-            out.append(resized_tensor)
+                # Remove the extra dimensions
+                resized_tensor = resized_tensor[0, 0, :, :]
+                pbar.update(1)
+                out.append(resized_tensor)
 
-        results = torch.stack(out).cpu()
+        results = torch.stack(out).cpu().float()
         if combine_mask:
             combined_results = torch.max(results, dim=0)[0]
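
Note on the autocast change: the gating logic in the diff relies on ComfyUI's comfy.model_management helpers (unet_dtype, is_device_mps, get_autocast_device). Below is a minimal, self-contained sketch of the same conditional-autocast pattern in plain torch; the dtype choice and the stand-in Linear model are illustrative assumptions, not part of the diff.

# A minimal sketch of the conditional-autocast pattern the diff introduces,
# using plain torch in place of ComfyUI's comfy.model_management helpers;
# the fp16-on-CUDA dtype choice and the stand-in model are assumptions.
from contextlib import nullcontext

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Stand-in for comfy.model_management.unet_dtype(): fp16 on CUDA, fp32 on CPU.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Autocast only pays off for reduced-precision dtypes, and the diff skips it
# on MPS; otherwise fall back to a no-op context manager so the inference
# code stays identical on every device.
autocast_condition = (dtype != torch.float32) and device.type != "mps"
ctx = torch.autocast(device.type, dtype=dtype) if autocast_condition else nullcontext()

model = torch.nn.Linear(16, 4).to(dtype).to(device)  # hypothetical stand-in model

with ctx:
    x = torch.randn(2, 16, device=device)
    mask = torch.sigmoid(model(x.to(dtype)))

# Cast back to fp32 on CPU, mirroring the diff's `.cpu().float()`, since
# downstream mask consumers generally expect fp32.
result = mask.cpu().float()

The nullcontext() fallback is why the diff adds a single with statement and re-indents the loop instead of duplicating the inference path per device, and the trailing .float() restores fp32 output regardless of which dtype ran inside the context.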