Add choice of device for imageresize

This commit is contained in:
kijai 2025-05-20 16:50:55 +03:00
parent aca7916352
commit 44565e9bff

View File

@ -809,12 +809,19 @@ with repeats 2 becomes batch of 10 images: 0, 0, 1, 1, 2, 2, 3, 3, 4, 4
} }
def repeat(self, images, repeats, mask=None): def repeat(self, images, repeats, mask=None):
original_count = images.shape[0]
total_count = original_count * repeats
repeated_images = torch.repeat_interleave(images, repeats=repeats, dim=0) repeated_images = torch.repeat_interleave(images, repeats=repeats, dim=0)
if mask is not None: if mask is not None:
mask = torch.repeat_interleave(mask, repeats=repeats, dim=0) mask = torch.repeat_interleave(mask, repeats=repeats, dim=0)
else: else:
mask = torch.zeros_like(repeated_images[:, 0:1, :, :]) mask = torch.zeros((total_count, images.shape[1], images.shape[2]),
device=images.device, dtype=images.dtype)
for i in range(original_count):
mask[i * repeats] = 1.0
print("mask shape", mask.shape)
return (repeated_images, mask) return (repeated_images, mask)
class ImageUpscaleWithModelBatched: class ImageUpscaleWithModelBatched:
@ -2381,6 +2388,9 @@ class ImageResizeKJv2:
"crop_position": (["center", "top", "bottom", "left", "right"], { "default": "center" }), "crop_position": (["center", "top", "bottom", "left", "right"], { "default": "center" }),
"divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }), "divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
}, },
"optional" : {
"device": (["cpu", "gpu"],),
}
} }
RETURN_TYPES = ("IMAGE", "INT", "INT",) RETURN_TYPES = ("IMAGE", "INT", "INT",)
@ -2395,9 +2405,16 @@ Keep proportions keeps the aspect ratio of the image, by
highest dimension. highest dimension.
""" """
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position): def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, device="cpu"):
B, H, W, C = image.shape B, H, W, C = image.shape
if device == "gpu":
if upscale_method == "lanczos":
raise Exception("Lanczos is not supported on the GPU")
device = model_management.get_torch_device()
else:
device = torch.device("cpu")
if width == 0: if width == 0:
width = W width = W
if height == 0: if height == 0:
@ -2430,7 +2447,7 @@ highest dimension.
width = width - (width % divisible_by) width = width - (width % divisible_by)
height = height - (height % divisible_by) height = height - (height % divisible_by)
out_image = image.clone() out_image = image.clone().to(device)
if keep_proportion == "crop": if keep_proportion == "crop":
old_width = W old_width = W
@ -2483,7 +2500,7 @@ highest dimension.
out_image, _ = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, "edge" if keep_proportion == "pad_edge" else "color") out_image, _ = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, "edge" if keep_proportion == "pad_edge" else "color")
return(out_image, out_image.shape[2], out_image.shape[1],) return(out_image.cpu(), out_image.shape[2], out_image.shape[1],)
import pathlib import pathlib
class LoadAndResizeImage: class LoadAndResizeImage: