new nodes

kijai 2024-10-15 22:57:52 +03:00
parent 74c335b6ef
commit 3df9f978f7
2 changed files with 159 additions and 1 deletion

@@ -56,6 +56,8 @@ NODE_CONFIG = {
"ImageConcanate": {"class": ImageConcanate, "name": "Image Concatenate"},
"ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"},
"ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"},
"ImageCropByMaskAndResize": {"class": ImageCropByMaskAndResize, "name": "Image Crop By Mask And Resize"},
"ImageUncropByMask": {"class": ImageUncropByMask, "name": "Image Uncrop By Mask"},
"ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"},
"ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"},
"ImageGridComposite3x3": {"class": ImageGridComposite3x3, "name": "Image Grid Composite 3x3"},

@@ -2326,3 +2326,159 @@ class FastPreview:
"ui": {"bg_image": [img_base64]},
"result": ()
}
class ImageCropByMaskAndResize:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE", ),
"mask": ("MASK", ),
"base_resolution": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
"padding": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
},
}
RETURN_TYPES = ("IMAGE", "MASK", "BBOX", )
RETURN_NAMES = ("images", "masks", "bbox",)
FUNCTION = "crop"
CATEGORY = "KJNodes/image"
    def crop_by_mask(self, mask, padding=0):
        # Coordinates of all fully masked pixels
        iy, ix = (mask == 1).nonzero(as_tuple=True)
        h0, w0 = mask.shape

        if iy.numel() == 0:
            # Empty mask: fall back to a zero-sized box at the image center
            x_c = w0 / 2.0
            y_c = h0 / 2.0
            width = 0
            height = 0
        else:
            # Tight bounding box around the masked area (inclusive bounds)
            x_min = ix.min().item()
            x_max = ix.max().item()
            y_min = iy.min().item()
            y_max = iy.max().item()
            width = x_max - x_min + 1
            height = y_max - y_min + 1
            if width > w0 or height > h0:
                raise Exception("Masked area out of bounds")
            x_c = (x_min + x_max) / 2.0
            y_c = (y_min + y_max) / 2.0

        width_half = width / 2.0
        height_half = height / 2.0

        # Grow the box by the padding and clamp it to the image borders,
        # keeping at least a 1-pixel extent so downstream resizing is safe
        if w0 <= width:
            x0 = 0
            w = w0
        else:
            x0 = max(0, x_c - width_half - padding)
            w = min(max(width + 2 * padding, 1), w0)
            if x0 + w > w0:
                x0 = w0 - w
        if h0 <= height:
            y0 = 0
            h = h0
        else:
            y0 = max(0, y_c - height_half - padding)
            h = min(max(height + 2 * padding, 1), h0)
            if y0 + h > h0:
                y0 = h0 - h

        return (int(x0), int(y0), int(w), int(h))
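
    # Worked example (illustrative, not part of the commit): for a 512x512
    # mask that is 1 in rows 100..199 and columns 150..349, the inclusive
    # bounds give height 100 and width 200; with padding=16 the padded box
    # is 132x232 and crop_by_mask returns (x0, y0, w, h) = (133, 83, 232, 132).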

    def crop(self, image, mask, base_resolution, padding=0):
        image_list = []
        mask_list = []
        bbox_list = []

        for i in range(image.shape[0]):
            # Crop image and mask to the padded bounding box of the mask
            x0, y0, w, h = self.crop_by_mask(mask[i], padding)
            cropped_image = image[i][y0:y0+h, x0:x0+w, :]
            cropped_mask = mask[i][y0:y0+h, x0:x0+w]

            cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1)  # (H, W, C) -> (B, C, H, W)

            # Scale the crop so its longer side matches base_resolution,
            # preserving the aspect ratio
            aspect_ratio = w / h
            if aspect_ratio > 1:
                target_width = base_resolution
                target_height = int(base_resolution / aspect_ratio)
            else:
                target_height = base_resolution
                target_width = int(base_resolution * aspect_ratio)

            cropped_image = F.interpolate(cropped_image, size=(target_height, target_width), mode='bilinear', align_corners=False)
            cropped_image = cropped_image.movedim(1, -1).squeeze(0)

            # Nearest-neighbor resizing keeps the mask hard-edged
            cropped_mask = cropped_mask.unsqueeze(0).unsqueeze(0)
            cropped_mask = F.interpolate(cropped_mask, size=(target_height, target_width), mode='nearest')
            cropped_mask = cropped_mask.squeeze(0).squeeze(0)

            image_list.append(cropped_image)
            mask_list.append(cropped_mask)
            # Bounding box in destination coordinates, for ImageUncropByMask
            bbox_list.append((x0, y0, x0 + w, y0 + h))

        return (torch.stack(image_list), torch.stack(mask_list), bbox_list)
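
The crop entry point is plain Python, so it can be sanity-checked outside ComfyUI. A minimal sketch, assuming the module-level imports above (torch, F as torch.nn.functional, MAX_RESOLUTION) are in scope, and using ComfyUI's tensor conventions of IMAGE = (B, H, W, C) and MASK = (B, H, W):

import torch

node = ImageCropByMaskAndResize()
image = torch.rand(1, 512, 512, 3)   # IMAGE: (B, H, W, C)
mask = torch.zeros(1, 512, 512)      # MASK: (B, H, W)
mask[0, 100:200, 150:350] = 1.0      # 100x200 masked region

images, masks, bbox = node.crop(image, mask, base_resolution=512, padding=16)
print(images.shape)  # longer side scaled to 512, aspect ratio preserved
print(bbox[0])       # (x0, y0, x1, y1) in destination coordinates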

class ImageUncropByMask:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "destination": ("IMAGE",),
                "source": ("IMAGE",),
                "mask": ("MASK",),
                "bbox": ("BBOX",),
            },
        }

    CATEGORY = "KJNodes/image"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "uncrop"

    def uncrop(self, destination, source, mask, bbox=None):
        output_list = []
        B, H, W, C = destination.shape

        for i in range(source.shape[0]):
            x0, y0, x1, y1 = bbox[i]
            bbox_height = y1 - y0
            bbox_width = x1 - x0

            # Resize the source image back to the bounding box dimensions
            resized_source = F.interpolate(source[i].unsqueeze(0).movedim(-1, 1), size=(bbox_height, bbox_width), mode='bilinear', align_corners=False)
            resized_source = resized_source.movedim(1, -1).squeeze(0)

            # Resize the mask to the same dimensions
            resized_mask = F.interpolate(mask[i].unsqueeze(0).unsqueeze(0), size=(bbox_height, bbox_width), mode='nearest')
            resized_mask = resized_mask.squeeze(0).squeeze(0)

            # Pad the resized source and mask out to the destination dimensions
            pad_left = x0
            pad_right = W - x1
            pad_top = y0
            pad_bottom = H - y1
            padded_source = F.pad(resized_source, pad=(0, 0, pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
            padded_mask = F.pad(resized_mask, pad=(pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)

            # Broadcast the single-channel mask across the image channels
            padded_mask = padded_mask.unsqueeze(2).expand(-1, -1, C)

            # Composite the padded source over the destination using the mask
            result = destination[i] * (1.0 - padded_mask) + padded_source * padded_mask
            output_list.append(result)

        return (torch.stack(output_list),)
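
Paired together, the two nodes form a crop-process-uncrop round trip: crop the masked region at a working resolution, run some edit on it, then composite it back into the untouched original. A hypothetical sketch, under the same assumptions as above (the flip stands in for an inpainting or detailer step):

node_crop = ImageCropByMaskAndResize()
node_uncrop = ImageUncropByMask()

image = torch.rand(1, 512, 512, 3)
mask = torch.zeros(1, 512, 512)
mask[0, 128:256, 128:384] = 1.0

cropped, cropped_masks, bbox = node_crop.crop(image, mask, base_resolution=512, padding=8)
edited = cropped.flip(2)  # stand-in edit: mirror the crop horizontally
(result,) = node_uncrop.uncrop(image, edited, cropped_masks, bbox)
assert result.shape == image.shape  # composited back at the original size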