Update image_nodes.py

This commit is contained in:
kijai 2024-11-06 12:59:11 +02:00
parent ba33d02198
commit 3b065864fa

View File

@ -2474,6 +2474,8 @@ class ImageCropByMaskAndResize:
return (int(x0), int(y0), int(w), int(h))
def crop(self, image, mask, base_resolution, padding=0, min_crop_resolution=128, max_crop_resolution=512):
print("mask shape: ",mask.shape)
print("image shape: ",image.shape)
image_list = []
mask_list = []
bbox_list = []
@ -2481,8 +2483,6 @@ class ImageCropByMaskAndResize:
x0, y0, w, h = self.crop_by_mask(mask[i], padding, min_crop_resolution, max_crop_resolution)
cropped_image = image[i][y0:y0+h, x0:x0+w, :]
cropped_mask = mask[i][y0:y0+h, x0:x0+w]
cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1) # Move C to the second position (B, C, H, W)
aspect_ratio = w / h
if aspect_ratio > 1:
@ -2495,12 +2495,16 @@ class ImageCropByMaskAndResize:
# Ensure dimensions are divisible by 8
target_width = (target_width + 7) // 8 * 8
target_height = (target_height + 7) // 8 * 8
cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1) # Move C to the second position (B, C, H, W)
cropped_image = common_upscale(cropped_image, target_width, target_height, "lanczos", "disabled")
cropped_image = cropped_image.movedim(1, -1).squeeze(0)
print("cropped_image shape: ",cropped_image.shape)
print("cropped_mask shape: ",cropped_mask.shape)
cropped_mask = cropped_mask.unsqueeze(0).unsqueeze(0)
cropped_mask = F.interpolate(cropped_mask, size=(target_height, target_width), mode='bilinear')
cropped_mask = common_upscale(cropped_mask, target_width, target_height, 'bilinear', "disabled")
cropped_mask = cropped_mask.squeeze(0).squeeze(0)
image_list.append(cropped_image)