Remake ColorToMask node

that code was atrocious
This commit is contained in:
kijai 2024-04-22 19:50:34 +03:00
parent 76c536d156
commit 171b70bfa5

View File

@ -1000,35 +1000,41 @@ Converts chosen RGB value to a mask
"green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
"blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
"threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
"per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
},
}
def clip(self, images, red, green, blue, threshold, invert):
color = np.array([red, green, blue])
images = 255. * images.cpu().numpy()
images = np.clip(images, 0, 255).astype(np.uint8)
images = [Image.fromarray(image) for image in images]
images = [np.array(image) for image in images]
def clip(self, images, red, green, blue, threshold, invert, per_batch):
black = [0, 0, 0]
white = [255, 255, 255]
color = torch.tensor([red, green, blue], dtype=torch.uint8)
black = torch.tensor([0, 0, 0], dtype=torch.uint8)
white = torch.tensor([255, 255, 255], dtype=torch.uint8)
if invert:
black, white = white, black
black, white = white, black
new_images = []
for image in images:
new_image = np.full_like(image, black)
steps = images.shape[0]
pbar = comfy.utils.ProgressBar(steps)
tensors_out = []
for start_idx in range(0, images.shape[0], per_batch):
color_distances = np.linalg.norm(image - color, axis=-1)
complement_indexes = color_distances <= threshold
# Calculate color distances
color_distances = torch.norm(images[start_idx:start_idx+per_batch] * 255 - color, dim=-1)
# Create a mask based on the threshold
mask = color_distances <= threshold
# Apply the mask to create new images
mask_out = torch.where(mask.unsqueeze(-1), white, black).float()
mask_out = mask_out.mean(dim=-1)
new_image[complement_indexes] = white
new_images.append(new_image)
new_images = np.array(new_images).astype(np.float32) / 255.0
new_images = torch.from_numpy(new_images).permute(3, 0, 1, 2)
return new_images
tensors_out.append(mask_out.cpu())
batch_count = mask_out.shape[0]
pbar.update(batch_count)
tensors_out = torch.cat(tensors_out, dim=0)
return tensors_out,
class ConditioningMultiCombine:
@classmethod
@ -5058,7 +5064,7 @@ Each mask is generated with the specified width and height.
mask = torch.ones((height, width), dtype=torch.float32) * value
masks.append(mask)
masks_out = torch.stack(masks, dim=0)
print(masks_out.shape)
return(masks_out,)