Update image_nodes.py

This commit is contained in:
kijai 2024-11-06 12:59:11 +02:00
parent ba33d02198
commit 3b065864fa

View File

@ -2474,6 +2474,8 @@ class ImageCropByMaskAndResize:
return (int(x0), int(y0), int(w), int(h))
def crop(self, image, mask, base_resolution, padding=0, min_crop_resolution=128, max_crop_resolution=512):
print("mask shape: ",mask.shape)
print("image shape: ",image.shape)
image_list = []
mask_list = []
bbox_list = []
@ -2482,8 +2484,6 @@ class ImageCropByMaskAndResize:
cropped_image = image[i][y0:y0+h, x0:x0+w, :]
cropped_mask = mask[i][y0:y0+h, x0:x0+w]
cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1) # Move C to the second position (B, C, H, W)
aspect_ratio = w / h
if aspect_ratio > 1:
target_width = base_resolution
@ -2496,11 +2496,15 @@ class ImageCropByMaskAndResize:
target_width = (target_width + 7) // 8 * 8
target_height = (target_height + 7) // 8 * 8
cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1) # Move C to the second position (B, C, H, W)
cropped_image = common_upscale(cropped_image, target_width, target_height, "lanczos", "disabled")
cropped_image = cropped_image.movedim(1, -1).squeeze(0)
print("cropped_image shape: ",cropped_image.shape)
print("cropped_mask shape: ",cropped_mask.shape)
cropped_mask = cropped_mask.unsqueeze(0).unsqueeze(0)
cropped_mask = F.interpolate(cropped_mask, size=(target_height, target_width), mode='bilinear')
cropped_mask = common_upscale(cropped_mask, target_width, target_height, 'bilinear', "disabled")
cropped_mask = cropped_mask.squeeze(0).squeeze(0)
image_list.append(cropped_image)