diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py
index ae6541d..aebd49d 100644
--- a/nodes/image_nodes.py
+++ b/nodes/image_nodes.py
@@ -2476,6 +2476,7 @@ class ImageResizeKJv2:
"optional" : {
"mask": ("MASK",),
"device": (["cpu", "gpu"],),
+ "per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
@@ -2494,7 +2495,7 @@ Keep proportions keeps the aspect ratio of the image, by
highest dimension.
"""
- def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None):
+ def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=0):
B, H, W, C = image.shape
if device == "gpu":
@@ -2510,6 +2511,10 @@ highest dimension.
height = H
pillarbox_blur = keep_proportion == "pillarbox_blur"
+
+ # Initialize padding variables
+ pad_left = pad_right = pad_top = pad_bottom = 0
+
if keep_proportion == "resize" or keep_proportion.startswith("pad") or pillarbox_blur:
# If one of the dimensions is zero, calculate it to maintain the aspect ratio
if width == 0 and height != 0:
@@ -2528,7 +2533,6 @@ highest dimension.
new_width = width
new_height = height
- pad_left = pad_right = pad_top = pad_bottom = 0
if keep_proportion.startswith("pad") or pillarbox_blur:
# Calculate padding based on position
if crop_position == "center":
@@ -2564,71 +2568,126 @@ highest dimension.
width = width - (width % divisible_by)
height = height - (height % divisible_by)
- out_image = image.clone().to(device)
- if mask is not None:
- out_mask = mask.clone().to(device)
+ # Preflight estimate (log-only when batching is active)
+ if per_batch != 0 and B > per_batch:
+ try:
+ bytes_per_elem = image.element_size() # typically 4 for float32
+ est_total_bytes = B * height * width * C * bytes_per_elem
+ est_mb = est_total_bytes / (1024 * 1024)
+ msg = f"
| Resize v2 | estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B} |
"
+ if unique_id and PromptServer is not None:
+ try:
+ PromptServer.instance.send_progress_text(msg, unique_id)
+ except:
+ pass
+ else:
+ print(f"[ImageResizeKJv2] estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}")
+ except:
+ pass
+
+ def _process_subbatch(in_image, in_mask, pad_left, pad_right, pad_top, pad_bottom):
+ # Avoid unnecessary clones; only move if needed
+ out_image = in_image if in_image.device == device else in_image.to(device)
+ out_mask = None if in_mask is None else (in_mask if in_mask.device == device else in_mask.to(device))
+
+ # Crop logic
+ if keep_proportion == "crop":
+ old_height = out_image.shape[-3]
+ old_width = out_image.shape[-2]
+ old_aspect = old_width / old_height
+ new_aspect = width / height
+ if old_aspect > new_aspect:
+ crop_w = round(old_height * new_aspect)
+ crop_h = old_height
+ else:
+ crop_w = old_width
+ crop_h = round(old_width / new_aspect)
+ if crop_position == "center":
+ x = (old_width - crop_w) // 2
+ y = (old_height - crop_h) // 2
+ elif crop_position == "top":
+ x = (old_width - crop_w) // 2
+ y = 0
+ elif crop_position == "bottom":
+ x = (old_width - crop_w) // 2
+ y = old_height - crop_h
+ elif crop_position == "left":
+ x = 0
+ y = (old_height - crop_h) // 2
+ elif crop_position == "right":
+ x = old_width - crop_w
+ y = (old_height - crop_h) // 2
+ out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
+ if out_mask is not None:
+ out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
+
+ out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
+ if out_mask is not None:
+ if upscale_method == "lanczos":
+ out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
+ else:
+ out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
+
+ # Pad logic
+ if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
+ padded_width = width + pad_left + pad_right
+ padded_height = height + pad_top + pad_bottom
+ if divisible_by > 1:
+ width_remainder = padded_width % divisible_by
+ height_remainder = padded_height % divisible_by
+ if width_remainder > 0:
+ extra_width = divisible_by - width_remainder
+ pad_right += extra_width
+ if height_remainder > 0:
+ extra_height = divisible_by - height_remainder
+ pad_bottom += extra_height
+
+ pad_mode = (
+ "pillarbox_blur" if pillarbox_blur else
+ "edge" if keep_proportion == "pad_edge" else
+ "edge_pixel" if keep_proportion == "pad_edge_pixel" else
+ "color"
+ )
+ out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
+
+ return out_image, out_mask
+
+ # If batching disabled (per_batch==0) or batch fits, process whole batch
+ if per_batch == 0 or B <= per_batch:
+ out_image, out_mask = _process_subbatch(image, mask, pad_left, pad_right, pad_top, pad_bottom)
else:
- out_mask = None
-
- # Crop logic
- if keep_proportion == "crop":
- old_width = W
- old_height = H
- old_aspect = old_width / old_height
- new_aspect = width / height
- if old_aspect > new_aspect:
- crop_w = round(old_height * new_aspect)
- crop_h = old_height
+ chunks = []
+ mask_chunks = [] if mask is not None else None
+ total_batches = (B + per_batch - 1) // per_batch
+ current_batch = 0
+ for start_idx in range(0, B, per_batch):
+ current_batch += 1
+ end_idx = min(start_idx + per_batch, B)
+ sub_img = image[start_idx:end_idx]
+ sub_mask = mask[start_idx:end_idx] if mask is not None else None
+ sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask, pad_left, pad_right, pad_top, pad_bottom)
+ chunks.append(sub_out_img.cpu())
+ if mask is not None:
+ mask_chunks.append(sub_out_mask.cpu() if sub_out_mask is not None else None)
+ # Per-batch progress update
+ if unique_id and PromptServer is not None:
+ try:
+ PromptServer.instance.send_progress_text(
+ f"| Resize v2 | batch {current_batch}/{total_batches} · images {end_idx}/{B} |
",
+ unique_id
+ )
+ except:
+ pass
+ else:
+ try:
+ print(f"[ImageResizeKJv2] batch {current_batch}/{total_batches} · images {end_idx}/{B}")
+ except:
+ pass
+ out_image = torch.cat(chunks, dim=0)
+ if mask is not None and any(m is not None for m in mask_chunks):
+ out_mask = torch.cat([m for m in mask_chunks if m is not None], dim=0)
else:
- crop_w = old_width
- crop_h = round(old_width / new_aspect)
- if crop_position == "center":
- x = (old_width - crop_w) // 2
- y = (old_height - crop_h) // 2
- elif crop_position == "top":
- x = (old_width - crop_w) // 2
- y = 0
- elif crop_position == "bottom":
- x = (old_width - crop_w) // 2
- y = old_height - crop_h
- elif crop_position == "left":
- x = 0
- y = (old_height - crop_h) // 2
- elif crop_position == "right":
- x = old_width - crop_w
- y = (old_height - crop_h) // 2
- out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
- if mask is not None:
- out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
-
- out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
- if mask is not None:
- if upscale_method == "lanczos":
- out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
- else:
- out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
-
- # Pad logic
- if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
- padded_width = width + pad_left + pad_right
- padded_height = height + pad_top + pad_bottom
- if divisible_by > 1:
- width_remainder = padded_width % divisible_by
- height_remainder = padded_height % divisible_by
- if width_remainder > 0:
- extra_width = divisible_by - width_remainder
- pad_right += extra_width
- if height_remainder > 0:
- extra_height = divisible_by - height_remainder
- pad_bottom += extra_height
-
- pad_mode = (
- "pillarbox_blur" if pillarbox_blur else
- "edge" if keep_proportion == "pad_edge" else
- "edge_pixel" if keep_proportion == "pad_edge_pixel" else
- "color"
- )
- out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
+ out_mask = None
# Progress UI
if unique_id and PromptServer is not None: