Update nodes.py

Kijai 2024-04-11 15:44:38 +03:00
parent 6f53738fe6
commit 7033cf2dfc


@@ -7,6 +7,7 @@ import scipy.ndimage
 import matplotlib.pyplot as plt
 import numpy as np
 from PIL import ImageFilter, Image, ImageDraw, ImageFont
+from contextlib import nullcontext
 import json
 import re
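The new nullcontext import exists to support the conditional autocast block added further down. A minimal sketch of that pattern outside the node (the device/dtype choices and the name wants_autocast are illustrative, not from the commit): nullcontext() is a no-op stand-in, so the same with-block runs identically whether or not autocast is wanted.

import torch
from contextlib import nullcontext

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16  # illustrative reduced-precision dtype
wants_autocast = (dtype != torch.float32) and device.type == "cuda"

# nullcontext() makes the context manager optional without duplicating the body
with torch.autocast(device.type, dtype=dtype) if wants_autocast else nullcontext():
    x = torch.randn(1, 3, 8, 8, device=device)
    y = (x * 2.0).sum()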
@ -2378,7 +2379,7 @@ class BatchCLIPSeg:
{ {
"images": ("IMAGE",), "images": ("IMAGE",),
"text": ("STRING", {"multiline": False}), "text": ("STRING", {"multiline": False}),
"threshold": ("FLOAT", {"default": 0.15,"min": 0.0, "max": 10.0, "step": 0.01}), "threshold": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.001}),
"binary_mask": ("BOOLEAN", {"default": True}), "binary_mask": ("BOOLEAN", {"default": True}),
"combine_mask": ("BOOLEAN", {"default": False}), "combine_mask": ("BOOLEAN", {"default": False}),
"use_cuda": ("BOOLEAN", {"default": True}), "use_cuda": ("BOOLEAN", {"default": True}),
@@ -2401,36 +2402,39 @@ Segments an image or batch of images using CLIPSeg.
             device = torch.device("cuda")
         else:
             device = torch.device("cpu")
+        dtype = comfy.model_management.unet_dtype()
         model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-        model.to(device) # Ensure the model is on the correct device
+        model.to(dtype)
+        model.to(device)
         images = images.to(device)
         processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
         pbar = comfy.utils.ProgressBar(images.shape[0])
-        for image in images:
-            image = (image* 255).type(torch.uint8)
-            prompt = text
-            input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
-            # Move the processed input to the device
-            for key in input_prc:
-                input_prc[key] = input_prc[key].to(device)
-
-            outputs = model(**input_prc)
-            tensor = torch.sigmoid(outputs[0])
-
-            tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
-            tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
-
-            tensor = tensor_normalized
-
-            # Resize the mask
-            resized_tensor = F.interpolate(tensor.unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False)
-
-            # Remove the extra dimensions
-            resized_tensor = resized_tensor[0, 0, :, :]
-            pbar.update(1)
-            out.append(resized_tensor)
-        results = torch.stack(out).cpu()
+        autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
+        with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
+            for image in images:
+                image = (image* 255).type(torch.uint8)
+                prompt = text
+                input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
+                # Move the processed input to the device
+                for key in input_prc:
+                    input_prc[key] = input_prc[key].to(device)
+
+                outputs = model(**input_prc)
+
+                tensor = torch.sigmoid(outputs[0])
+                tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
+                tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
+                tensor = tensor_normalized.unsqueeze(0).unsqueeze(0)
+                # Resize the mask
+                resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest')
+                # Remove the extra dimensions
+                resized_tensor = resized_tensor[0, 0, :, :]
+                pbar.update(1)
+                out.append(resized_tensor)
+        results = torch.stack(out).cpu().float()
         if combine_mask:
             combined_results = torch.max(results, dim=0)[0]
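The trailing .float() matters here: with the loop now potentially running under autocast in a reduced-precision dtype, the stacked masks are cast back to float32 before downstream use. The switch from bilinear to nearest interpolation presumably keeps the thresholded mask edges crisp rather than smeared. For reference, a sketch of the round trip plus the combine_mask reduction visible at the end of the hunk (batch size and mask shape are illustrative):

import torch

results = torch.rand(4, 512, 512).half()  # four per-image masks in reduced precision
results = results.cpu().float()           # mirror the .cpu().float() cast in the hunk
combined = torch.max(results, dim=0)[0]   # per-pixel max collapses the batch to one mask
assert combined.shape == (512, 512)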