Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2025-12-09 21:04:41 +08:00)
Update nodes.py

commit 7033cf2dfc (parent 6f53738fe6)

Changed file: nodes.py (52 changed lines)
@@ -7,6 +7,7 @@ import scipy.ndimage
 import matplotlib.pyplot as plt
 import numpy as np
 from PIL import ImageFilter, Image, ImageDraw, ImageFont
+from contextlib import nullcontext
 import json
 import re
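The only import added here is nullcontext, which the autocast wrapper introduced further down uses as a no-op fallback when mixed precision is disabled. A minimal sketch of that pattern, outside the node (the use_amp flag and the toy forward call are illustrative, not part of the commit):

    import torch
    from torch import nn
    from contextlib import nullcontext

    def forward_maybe_amp(model: nn.Module, x: torch.Tensor, use_amp: bool) -> torch.Tensor:
        # Enter a real autocast region only when requested; otherwise
        # nullcontext() is a context manager that does nothing, so the
        # body of the `with` block stays identical in both cases.
        ctx = torch.autocast("cuda", dtype=torch.float16) if use_amp else nullcontext()
        with ctx:
            return model(x)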
@@ -2378,7 +2379,7 @@ class BatchCLIPSeg:
                 {
                 "images": ("IMAGE",),
                 "text": ("STRING", {"multiline": False}),
-                "threshold": ("FLOAT", {"default": 0.15,"min": 0.0, "max": 10.0, "step": 0.01}),
+                "threshold": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.001}),
                 "binary_mask": ("BOOLEAN", {"default": True}),
                 "combine_mask": ("BOOLEAN", {"default": False}),
                 "use_cuda": ("BOOLEAN", {"default": True}),
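This hunk lowers the threshold widget's default from 0.15 to 0.1 and tightens its step from 0.01 to 0.001, allowing finer tuning of the mask cutoff. For orientation, such entries live inside a ComfyUI node's INPUT_TYPES declaration; a plausible sketch of the surrounding classmethod (the "required" wrapper and return shape follow the usual ComfyUI convention and are assumed, only the fields shown in the diff are confirmed):

    class BatchCLIPSeg:
        @classmethod
        def INPUT_TYPES(s):
            # Each entry maps a widget name to (TYPE, options dict).
            return {"required":
                    {
                    "images": ("IMAGE",),
                    "text": ("STRING", {"multiline": False}),
                    "threshold": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.001}),
                    "binary_mask": ("BOOLEAN", {"default": True}),
                    "combine_mask": ("BOOLEAN", {"default": False}),
                    "use_cuda": ("BOOLEAN", {"default": True}),
                    },
                   }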
@@ -2401,36 +2402,39 @@ Segments an image or batch of images using CLIPSeg.
             device = torch.device("cuda")
         else:
             device = torch.device("cpu")
+        dtype = comfy.model_management.unet_dtype()
         model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-        model.to(device) # Ensure the model is on the correct device
+        model.to(dtype)
+        model.to(device)
         images = images.to(device)
         processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
         pbar = comfy.utils.ProgressBar(images.shape[0])
-        for image in images:
-            image = (image* 255).type(torch.uint8)
-            prompt = text
-            input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
-            # Move the processed input to the device
-            for key in input_prc:
-                input_prc[key] = input_prc[key].to(device)
-
-            outputs = model(**input_prc)
-            tensor = torch.sigmoid(outputs[0])
-            tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
-            tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
-            tensor = tensor_normalized
-
-            # Resize the mask
-            resized_tensor = F.interpolate(tensor.unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False)
-
-            # Remove the extra dimensions
-            resized_tensor = resized_tensor[0, 0, :, :]
-            pbar.update(1)
-            out.append(resized_tensor)
-
-        results = torch.stack(out).cpu()
+        autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
+        with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
+            for image in images:
+                image = (image* 255).type(torch.uint8)
+                prompt = text
+                input_prc = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
+                # Move the processed input to the device
+                for key in input_prc:
+                    input_prc[key] = input_prc[key].to(device)
+
+                outputs = model(**input_prc)
+
+                tensor = torch.sigmoid(outputs[0])
+                tensor_thresholded = torch.where(tensor > threshold, tensor, torch.tensor(0, dtype=torch.float))
+                tensor_normalized = (tensor_thresholded - tensor_thresholded.min()) / (tensor_thresholded.max() - tensor_thresholded.min())
+                tensor = tensor_normalized.unsqueeze(0).unsqueeze(0)
+
+                # Resize the mask
+                resized_tensor = F.interpolate(tensor, size=(height, width), mode='nearest')
+
+                # Remove the extra dimensions
+                resized_tensor = resized_tensor[0, 0, :, :]
+                pbar.update(1)
+                out.append(resized_tensor)
+
+        results = torch.stack(out).cpu().float()
 
         if combine_mask:
             combined_results = torch.max(results, dim=0)[0]
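The core change wraps the per-image loop in torch.autocast whenever the UNet dtype is not float32 (and the device is not MPS), casts the model to that dtype up front, and defers the cast back to float32 to the stacked results. The mask post-processing itself is unchanged apart from the interpolation mode, which switches from bilinear to nearest-neighbor: threshold the sigmoid output, min-max normalize, then resize. A self-contained sketch of those steps on a dummy logits tensor (the shapes and the 0.1 threshold are illustrative, not taken from the commit):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(352, 352)          # stand-in for outputs[0] from CLIPSeg
    threshold, height, width = 0.1, 512, 512

    probs = torch.sigmoid(logits)
    # Zero out everything at or below the threshold, keep values above it.
    thresholded = torch.where(probs > threshold, probs, torch.tensor(0.0))
    # Min-max normalize the surviving values into [0, 1]
    # (like the node, this assumes max() > min()).
    normalized = (thresholded - thresholded.min()) / (thresholded.max() - thresholded.min())
    # F.interpolate expects NCHW input, hence the two unsqueezes;
    # nearest-neighbor keeps the mask edges hard instead of feathering them.
    mask = F.interpolate(normalized.unsqueeze(0).unsqueeze(0), size=(height, width), mode='nearest')[0, 0]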