Merge branch 'pr/402'

This commit is contained in:
kijai 2025-10-06 15:07:36 +03:00
commit 3aff68488f

View File

@ -2476,6 +2476,7 @@ class ImageResizeKJv2:
"optional" : {
"mask": ("MASK",),
"device": (["cpu", "gpu"],),
"per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
@ -2494,7 +2495,7 @@ Keep proportions keeps the aspect ratio of the image, by
highest dimension.
"""
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None):
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=0):
B, H, W, C = image.shape
if device == "gpu":
@ -2510,6 +2511,10 @@ highest dimension.
height = H
pillarbox_blur = keep_proportion == "pillarbox_blur"
# Initialize padding variables
pad_left = pad_right = pad_top = pad_bottom = 0
if keep_proportion == "resize" or keep_proportion.startswith("pad") or pillarbox_blur:
# If one of the dimensions is zero, calculate it to maintain the aspect ratio
if width == 0 and height != 0:
@ -2528,7 +2533,6 @@ highest dimension.
new_width = width
new_height = height
pad_left = pad_right = pad_top = pad_bottom = 0
if keep_proportion.startswith("pad") or pillarbox_blur:
# Calculate padding based on position
if crop_position == "center":
@ -2564,71 +2568,126 @@ highest dimension.
width = width - (width % divisible_by)
height = height - (height % divisible_by)
out_image = image.clone().to(device)
if mask is not None:
out_mask = mask.clone().to(device)
# Preflight estimate (log-only when batching is active)
if per_batch != 0 and B > per_batch:
try:
bytes_per_elem = image.element_size() # typically 4 for float32
est_total_bytes = B * height * width * C * bytes_per_elem
est_mb = est_total_bytes / (1024 * 1024)
msg = f"<tr><td>Resize v2</td><td>estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}</td></tr>"
if unique_id and PromptServer is not None:
try:
PromptServer.instance.send_progress_text(msg, unique_id)
except:
pass
else:
print(f"[ImageResizeKJv2] estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}")
except:
pass
def _process_subbatch(in_image, in_mask, pad_left, pad_right, pad_top, pad_bottom):
# Avoid unnecessary clones; only move if needed
out_image = in_image if in_image.device == device else in_image.to(device)
out_mask = None if in_mask is None else (in_mask if in_mask.device == device else in_mask.to(device))
# Crop logic
if keep_proportion == "crop":
old_height = out_image.shape[-3]
old_width = out_image.shape[-2]
old_aspect = old_width / old_height
new_aspect = width / height
if old_aspect > new_aspect:
crop_w = round(old_height * new_aspect)
crop_h = old_height
else:
crop_w = old_width
crop_h = round(old_width / new_aspect)
if crop_position == "center":
x = (old_width - crop_w) // 2
y = (old_height - crop_h) // 2
elif crop_position == "top":
x = (old_width - crop_w) // 2
y = 0
elif crop_position == "bottom":
x = (old_width - crop_w) // 2
y = old_height - crop_h
elif crop_position == "left":
x = 0
y = (old_height - crop_h) // 2
elif crop_position == "right":
x = old_width - crop_w
y = (old_height - crop_h) // 2
out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
if out_mask is not None:
out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
if out_mask is not None:
if upscale_method == "lanczos":
out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
else:
out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
# Pad logic
if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
padded_width = width + pad_left + pad_right
padded_height = height + pad_top + pad_bottom
if divisible_by > 1:
width_remainder = padded_width % divisible_by
height_remainder = padded_height % divisible_by
if width_remainder > 0:
extra_width = divisible_by - width_remainder
pad_right += extra_width
if height_remainder > 0:
extra_height = divisible_by - height_remainder
pad_bottom += extra_height
pad_mode = (
"pillarbox_blur" if pillarbox_blur else
"edge" if keep_proportion == "pad_edge" else
"edge_pixel" if keep_proportion == "pad_edge_pixel" else
"color"
)
out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
return out_image, out_mask
# If batching disabled (per_batch==0) or batch fits, process whole batch
if per_batch == 0 or B <= per_batch:
out_image, out_mask = _process_subbatch(image, mask, pad_left, pad_right, pad_top, pad_bottom)
else:
out_mask = None
# Crop logic
if keep_proportion == "crop":
old_width = W
old_height = H
old_aspect = old_width / old_height
new_aspect = width / height
if old_aspect > new_aspect:
crop_w = round(old_height * new_aspect)
crop_h = old_height
chunks = []
mask_chunks = [] if mask is not None else None
total_batches = (B + per_batch - 1) // per_batch
current_batch = 0
for start_idx in range(0, B, per_batch):
current_batch += 1
end_idx = min(start_idx + per_batch, B)
sub_img = image[start_idx:end_idx]
sub_mask = mask[start_idx:end_idx] if mask is not None else None
sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask, pad_left, pad_right, pad_top, pad_bottom)
chunks.append(sub_out_img.cpu())
if mask is not None:
mask_chunks.append(sub_out_mask.cpu() if sub_out_mask is not None else None)
# Per-batch progress update
if unique_id and PromptServer is not None:
try:
PromptServer.instance.send_progress_text(
f"<tr><td>Resize v2</td><td>batch {current_batch}/{total_batches} · images {end_idx}/{B}</td></tr>",
unique_id
)
except:
pass
else:
try:
print(f"[ImageResizeKJv2] batch {current_batch}/{total_batches} · images {end_idx}/{B}")
except:
pass
out_image = torch.cat(chunks, dim=0)
if mask is not None and any(m is not None for m in mask_chunks):
out_mask = torch.cat([m for m in mask_chunks if m is not None], dim=0)
else:
crop_w = old_width
crop_h = round(old_width / new_aspect)
if crop_position == "center":
x = (old_width - crop_w) // 2
y = (old_height - crop_h) // 2
elif crop_position == "top":
x = (old_width - crop_w) // 2
y = 0
elif crop_position == "bottom":
x = (old_width - crop_w) // 2
y = old_height - crop_h
elif crop_position == "left":
x = 0
y = (old_height - crop_h) // 2
elif crop_position == "right":
x = old_width - crop_w
y = (old_height - crop_h) // 2
out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
if mask is not None:
out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
if mask is not None:
if upscale_method == "lanczos":
out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
else:
out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
# Pad logic
if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
padded_width = width + pad_left + pad_right
padded_height = height + pad_top + pad_bottom
if divisible_by > 1:
width_remainder = padded_width % divisible_by
height_remainder = padded_height % divisible_by
if width_remainder > 0:
extra_width = divisible_by - width_remainder
pad_right += extra_width
if height_remainder > 0:
extra_height = divisible_by - height_remainder
pad_bottom += extra_height
pad_mode = (
"pillarbox_blur" if pillarbox_blur else
"edge" if keep_proportion == "pad_edge" else
"edge_pixel" if keep_proportion == "pad_edge_pixel" else
"color"
)
out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
out_mask = None
# Progress UI
if unique_id and PromptServer is not None: