diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index ae6541d..aebd49d 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2476,6 +2476,7 @@ class ImageResizeKJv2: "optional" : { "mask": ("MASK",), "device": (["cpu", "gpu"],), + "per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}), }, "hidden": { "unique_id": "UNIQUE_ID", @@ -2494,7 +2495,7 @@ Keep proportions keeps the aspect ratio of the image, by highest dimension. """ - def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None): + def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=0): B, H, W, C = image.shape if device == "gpu": @@ -2510,6 +2511,10 @@ highest dimension. height = H pillarbox_blur = keep_proportion == "pillarbox_blur" + + # Initialize padding variables + pad_left = pad_right = pad_top = pad_bottom = 0 + if keep_proportion == "resize" or keep_proportion.startswith("pad") or pillarbox_blur: # If one of the dimensions is zero, calculate it to maintain the aspect ratio if width == 0 and height != 0: @@ -2528,7 +2533,6 @@ highest dimension. new_width = width new_height = height - pad_left = pad_right = pad_top = pad_bottom = 0 if keep_proportion.startswith("pad") or pillarbox_blur: # Calculate padding based on position if crop_position == "center": @@ -2564,71 +2568,126 @@ highest dimension. width = width - (width % divisible_by) height = height - (height % divisible_by) - out_image = image.clone().to(device) - if mask is not None: - out_mask = mask.clone().to(device) + # Preflight estimate (log-only when batching is active) + if per_batch != 0 and B > per_batch: + try: + bytes_per_elem = image.element_size() # typically 4 for float32 + est_total_bytes = B * height * width * C * bytes_per_elem + est_mb = est_total_bytes / (1024 * 1024) + msg = f"Resize v2estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}" + if unique_id and PromptServer is not None: + try: + PromptServer.instance.send_progress_text(msg, unique_id) + except: + pass + else: + print(f"[ImageResizeKJv2] estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}") + except: + pass + + def _process_subbatch(in_image, in_mask, pad_left, pad_right, pad_top, pad_bottom): + # Avoid unnecessary clones; only move if needed + out_image = in_image if in_image.device == device else in_image.to(device) + out_mask = None if in_mask is None else (in_mask if in_mask.device == device else in_mask.to(device)) + + # Crop logic + if keep_proportion == "crop": + old_height = out_image.shape[-3] + old_width = out_image.shape[-2] + old_aspect = old_width / old_height + new_aspect = width / height + if old_aspect > new_aspect: + crop_w = round(old_height * new_aspect) + crop_h = old_height + else: + crop_w = old_width + crop_h = round(old_width / new_aspect) + if crop_position == "center": + x = (old_width - crop_w) // 2 + y = (old_height - crop_h) // 2 + elif crop_position == "top": + x = (old_width - crop_w) // 2 + y = 0 + elif crop_position == "bottom": + x = (old_width - crop_w) // 2 + y = old_height - crop_h + elif crop_position == "left": + x = 0 + y = (old_height - crop_h) // 2 + elif crop_position == "right": + x = old_width - crop_w + y = (old_height - crop_h) // 2 + out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h) + if out_mask is not None: + out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h) + + out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1) + if out_mask is not None: + if upscale_method == "lanczos": + out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0] + else: + out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1) + + # Pad logic + if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0): + padded_width = width + pad_left + pad_right + padded_height = height + pad_top + pad_bottom + if divisible_by > 1: + width_remainder = padded_width % divisible_by + height_remainder = padded_height % divisible_by + if width_remainder > 0: + extra_width = divisible_by - width_remainder + pad_right += extra_width + if height_remainder > 0: + extra_height = divisible_by - height_remainder + pad_bottom += extra_height + + pad_mode = ( + "pillarbox_blur" if pillarbox_blur else + "edge" if keep_proportion == "pad_edge" else + "edge_pixel" if keep_proportion == "pad_edge_pixel" else + "color" + ) + out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask) + + return out_image, out_mask + + # If batching disabled (per_batch==0) or batch fits, process whole batch + if per_batch == 0 or B <= per_batch: + out_image, out_mask = _process_subbatch(image, mask, pad_left, pad_right, pad_top, pad_bottom) else: - out_mask = None - - # Crop logic - if keep_proportion == "crop": - old_width = W - old_height = H - old_aspect = old_width / old_height - new_aspect = width / height - if old_aspect > new_aspect: - crop_w = round(old_height * new_aspect) - crop_h = old_height + chunks = [] + mask_chunks = [] if mask is not None else None + total_batches = (B + per_batch - 1) // per_batch + current_batch = 0 + for start_idx in range(0, B, per_batch): + current_batch += 1 + end_idx = min(start_idx + per_batch, B) + sub_img = image[start_idx:end_idx] + sub_mask = mask[start_idx:end_idx] if mask is not None else None + sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask, pad_left, pad_right, pad_top, pad_bottom) + chunks.append(sub_out_img.cpu()) + if mask is not None: + mask_chunks.append(sub_out_mask.cpu() if sub_out_mask is not None else None) + # Per-batch progress update + if unique_id and PromptServer is not None: + try: + PromptServer.instance.send_progress_text( + f"Resize v2batch {current_batch}/{total_batches} · images {end_idx}/{B}", + unique_id + ) + except: + pass + else: + try: + print(f"[ImageResizeKJv2] batch {current_batch}/{total_batches} · images {end_idx}/{B}") + except: + pass + out_image = torch.cat(chunks, dim=0) + if mask is not None and any(m is not None for m in mask_chunks): + out_mask = torch.cat([m for m in mask_chunks if m is not None], dim=0) else: - crop_w = old_width - crop_h = round(old_width / new_aspect) - if crop_position == "center": - x = (old_width - crop_w) // 2 - y = (old_height - crop_h) // 2 - elif crop_position == "top": - x = (old_width - crop_w) // 2 - y = 0 - elif crop_position == "bottom": - x = (old_width - crop_w) // 2 - y = old_height - crop_h - elif crop_position == "left": - x = 0 - y = (old_height - crop_h) // 2 - elif crop_position == "right": - x = old_width - crop_w - y = (old_height - crop_h) // 2 - out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h) - if mask is not None: - out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h) - - out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1) - if mask is not None: - if upscale_method == "lanczos": - out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0] - else: - out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1) - - # Pad logic - if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0): - padded_width = width + pad_left + pad_right - padded_height = height + pad_top + pad_bottom - if divisible_by > 1: - width_remainder = padded_width % divisible_by - height_remainder = padded_height % divisible_by - if width_remainder > 0: - extra_width = divisible_by - width_remainder - pad_right += extra_width - if height_remainder > 0: - extra_height = divisible_by - height_remainder - pad_bottom += extra_height - - pad_mode = ( - "pillarbox_blur" if pillarbox_blur else - "edge" if keep_proportion == "pad_edge" else - "edge_pixel" if keep_proportion == "pad_edge_pixel" else - "color" - ) - out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask) + out_mask = None # Progress UI if unique_id and PromptServer is not None: