Update nodes.py

Kijai 2024-04-11 15:44:38 +03:00
parent 6f53738fe6
commit 7033cf2dfc

nodes.py

@@ -7,6 +7,7 @@ import scipy.ndimage
 import matplotlib.pyplot as plt
 import numpy as np
 from PIL import ImageFilter, Image, ImageDraw, ImageFont
+from contextlib import nullcontext
 import json
 import re
@@ -2378,7 +2379,7 @@ class BatchCLIPSeg:
             {
                 "images": ("IMAGE",),
                 "text": ("STRING", {"multiline": False}),
-                "threshold": ("FLOAT", {"default": 0.15,"min": 0.0, "max": 10.0, "step": 0.01}),
+                "threshold": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.001}),
                 "binary_mask": ("BOOLEAN", {"default": True}),
                 "combine_mask": ("BOOLEAN", {"default": False}),
                 "use_cuda": ("BOOLEAN", {"default": True}),
@@ -2401,11 +2402,15 @@ Segments an image or batch of images using CLIPSeg.
             device = torch.device("cuda")
         else:
             device = torch.device("cpu")
+        dtype = comfy.model_management.unet_dtype()
         model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-        model.to(device) # Ensure the model is on the correct device
+        model.to(dtype)
+        model.to(device)
         images = images.to(device)
         processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
         pbar = comfy.utils.ProgressBar(images.shape[0])
+        autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
+        with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
             for image in images:
                 image = (image* 255).type(torch.uint8)
                 prompt = text
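
The added lines pick the inference dtype from ComfyUI's model management and wrap the loop in torch.autocast only when mixed precision is worthwhile (a non-float32 dtype on a device other than MPS), otherwise substituting a no-op nullcontext. A standalone sketch of that pattern, using plain torch checks in place of the comfy.model_management helpers; the function and variable names are illustrative:

```python
import torch
from contextlib import nullcontext

def run_inference(model, inputs, dtype=torch.float16):
    # Pick the device the same way the node does: CUDA if available, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(dtype).to(device)

    # Autocast only pays off for reduced-precision dtypes on devices that support it;
    # otherwise fall back to a context manager that does nothing.
    use_autocast = dtype != torch.float32 and device.type == "cuda"
    ctx = torch.autocast(device_type=device.type, dtype=dtype) if use_autocast else nullcontext()

    with ctx, torch.no_grad():
        return model(**inputs)
```

Using nullcontext keeps a single code path for the mixed-precision and full-precision cases instead of duplicating the inference loop.
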
@@ -2415,22 +2420,21 @@ Segments an image or batch of images using CLIPSeg.
                     input_prc[key] = input_prc[key].to(device)
                 outputs = model(**input_prc)
                 tensor = torch.sigmoid(outputs[0])
                 tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
                 tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
-            tensor = tensor_normalized.unsqueeze(0).unsqueeze(0)
+                tensor = tensor_normalized
                 # Resize the mask
-            resized_tensor = F.interpolate(tensor.unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False)
+                resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest')
                 # Remove the extra dimensions
                 resized_tensor = resized_tensor[0, 0, :, :]
                 pbar.update(1)
                 out.append(resized_tensor)
-        results = torch.stack(out).cpu()
+        results = torch.stack(out).cpu().float()
         if combine_mask:
             combined_results = torch.max(results, dim=0)[0]
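
Taken together, the per-image post-processing in this hunk is: sigmoid over the CLIPSeg logits, zero out values at or below the threshold, min-max normalize, resize to the input resolution, then stack the batch on the CPU as float32 and optionally max-combine it into a single mask. A rough standalone sketch of that pipeline; the epsilon and the explicit (N, C, H, W) reshaping are assumptions rather than lines from this file:

```python
import torch
import torch.nn.functional as F

def postprocess_masks(logits_list, height, width, threshold=0.1, combine=False):
    masks = []
    for logits in logits_list:                          # each item: 2D CLIPSeg logit map
        prob = torch.sigmoid(logits)
        prob = torch.where(prob > threshold, prob, torch.zeros_like(prob))
        prob = (prob - prob.min()) / (prob.max() - prob.min() + 1e-8)  # min-max normalize
        # F.interpolate expects (N, C, H, W); add the batch/channel dims, then strip them.
        resized = F.interpolate(prob[None, None], size=(height, width), mode="nearest")
        masks.append(resized[0, 0])
    results = torch.stack(masks).cpu().float()          # (B, height, width) mask batch
    if combine:
        results = results.max(dim=0)[0]                 # element-wise max across the batch
    return results
```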