Some fixes/adjustments

This commit is contained in:
kijai 2025-10-06 15:07:08 +03:00
parent 629a2bf423
commit ead0cce23b

View File

@ -2476,7 +2476,7 @@ class ImageResizeKJv2:
"optional" : {
"mask": ("MASK",),
"device": (["cpu", "gpu"],),
"per_batch": ("INT", { "default": 16, "min": 0, "max": 4096, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
"per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
@ -2495,7 +2495,7 @@ Keep proportions keeps the aspect ratio of the image, by
highest dimension.
"""
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=16):
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=0):
B, H, W, C = image.shape
if device == "gpu":
@ -2511,6 +2511,10 @@ highest dimension.
height = H
pillarbox_blur = keep_proportion == "pillarbox_blur"
# Initialize padding variables
pad_left = pad_right = pad_top = pad_bottom = 0
if keep_proportion == "resize" or keep_proportion.startswith("pad") or pillarbox_blur:
# If one of the dimensions is zero, calculate it to maintain the aspect ratio
if width == 0 and height != 0:
@ -2529,7 +2533,6 @@ highest dimension.
new_width = width
new_height = height
pad_left = pad_right = pad_top = pad_bottom = 0
if keep_proportion.startswith("pad") or pillarbox_blur:
# Calculate padding based on position
if crop_position == "center":
@ -2566,7 +2569,7 @@ highest dimension.
height = height - (height % divisible_by)
# Preflight estimate (log-only when batching is active)
if per_batch and B > per_batch:
if per_batch != 0 and B > per_batch:
try:
bytes_per_elem = image.element_size() # typically 4 for float32
est_total_bytes = B * height * width * C * bytes_per_elem
@ -2582,7 +2585,7 @@ highest dimension.
except:
pass
def _process_subbatch(in_image, in_mask):
def _process_subbatch(in_image, in_mask, pad_left, pad_right, pad_top, pad_bottom):
# Avoid unnecessary clones; only move if needed
out_image = in_image if in_image.device == device else in_image.to(device)
out_mask = None if in_mask is None else (in_mask if in_mask.device == device else in_mask.to(device))
@ -2650,8 +2653,8 @@ highest dimension.
return out_image, out_mask
# If batching disabled (per_batch==0) or batch fits, process whole batch
if per_batch is None or per_batch == 0 or B <= per_batch:
out_image, out_mask = _process_subbatch(image, mask)
if per_batch == 0 or B <= per_batch:
out_image, out_mask = _process_subbatch(image, mask, pad_left, pad_right, pad_top, pad_bottom)
else:
chunks = []
mask_chunks = [] if mask is not None else None
@ -2662,7 +2665,7 @@ highest dimension.
end_idx = min(start_idx + per_batch, B)
sub_img = image[start_idx:end_idx]
sub_mask = mask[start_idx:end_idx] if mask is not None else None
sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask)
sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask, pad_left, pad_right, pad_top, pad_bottom)
chunks.append(sub_out_img.cpu())
if mask is not None:
mask_chunks.append(sub_out_mask.cpu() if sub_out_mask is not None else None)