mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Merge branch 'pr/402'
This commit is contained in:
commit
3aff68488f
@ -2476,6 +2476,7 @@ class ImageResizeKJv2:
|
||||
"optional" : {
|
||||
"mask": ("MASK",),
|
||||
"device": (["cpu", "gpu"],),
|
||||
"per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
@ -2494,7 +2495,7 @@ Keep proportions keeps the aspect ratio of the image, by
|
||||
highest dimension.
|
||||
"""
|
||||
|
||||
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None):
|
||||
def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=0):
|
||||
B, H, W, C = image.shape
|
||||
|
||||
if device == "gpu":
|
||||
@ -2510,6 +2511,10 @@ highest dimension.
|
||||
height = H
|
||||
|
||||
pillarbox_blur = keep_proportion == "pillarbox_blur"
|
||||
|
||||
# Initialize padding variables
|
||||
pad_left = pad_right = pad_top = pad_bottom = 0
|
||||
|
||||
if keep_proportion == "resize" or keep_proportion.startswith("pad") or pillarbox_blur:
|
||||
# If one of the dimensions is zero, calculate it to maintain the aspect ratio
|
||||
if width == 0 and height != 0:
|
||||
@ -2528,7 +2533,6 @@ highest dimension.
|
||||
new_width = width
|
||||
new_height = height
|
||||
|
||||
pad_left = pad_right = pad_top = pad_bottom = 0
|
||||
if keep_proportion.startswith("pad") or pillarbox_blur:
|
||||
# Calculate padding based on position
|
||||
if crop_position == "center":
|
||||
@ -2564,71 +2568,126 @@ highest dimension.
|
||||
width = width - (width % divisible_by)
|
||||
height = height - (height % divisible_by)
|
||||
|
||||
out_image = image.clone().to(device)
|
||||
if mask is not None:
|
||||
out_mask = mask.clone().to(device)
|
||||
# Preflight estimate (log-only when batching is active)
|
||||
if per_batch != 0 and B > per_batch:
|
||||
try:
|
||||
bytes_per_elem = image.element_size() # typically 4 for float32
|
||||
est_total_bytes = B * height * width * C * bytes_per_elem
|
||||
est_mb = est_total_bytes / (1024 * 1024)
|
||||
msg = f"<tr><td>Resize v2</td><td>estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}</td></tr>"
|
||||
if unique_id and PromptServer is not None:
|
||||
try:
|
||||
PromptServer.instance.send_progress_text(msg, unique_id)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
print(f"[ImageResizeKJv2] estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}")
|
||||
except:
|
||||
pass
|
||||
|
||||
def _process_subbatch(in_image, in_mask, pad_left, pad_right, pad_top, pad_bottom):
|
||||
# Avoid unnecessary clones; only move if needed
|
||||
out_image = in_image if in_image.device == device else in_image.to(device)
|
||||
out_mask = None if in_mask is None else (in_mask if in_mask.device == device else in_mask.to(device))
|
||||
|
||||
# Crop logic
|
||||
if keep_proportion == "crop":
|
||||
old_height = out_image.shape[-3]
|
||||
old_width = out_image.shape[-2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
if old_aspect > new_aspect:
|
||||
crop_w = round(old_height * new_aspect)
|
||||
crop_h = old_height
|
||||
else:
|
||||
crop_w = old_width
|
||||
crop_h = round(old_width / new_aspect)
|
||||
if crop_position == "center":
|
||||
x = (old_width - crop_w) // 2
|
||||
y = (old_height - crop_h) // 2
|
||||
elif crop_position == "top":
|
||||
x = (old_width - crop_w) // 2
|
||||
y = 0
|
||||
elif crop_position == "bottom":
|
||||
x = (old_width - crop_w) // 2
|
||||
y = old_height - crop_h
|
||||
elif crop_position == "left":
|
||||
x = 0
|
||||
y = (old_height - crop_h) // 2
|
||||
elif crop_position == "right":
|
||||
x = old_width - crop_w
|
||||
y = (old_height - crop_h) // 2
|
||||
out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
|
||||
if out_mask is not None:
|
||||
out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
|
||||
|
||||
out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
|
||||
if out_mask is not None:
|
||||
if upscale_method == "lanczos":
|
||||
out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
|
||||
else:
|
||||
out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
|
||||
|
||||
# Pad logic
|
||||
if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
|
||||
padded_width = width + pad_left + pad_right
|
||||
padded_height = height + pad_top + pad_bottom
|
||||
if divisible_by > 1:
|
||||
width_remainder = padded_width % divisible_by
|
||||
height_remainder = padded_height % divisible_by
|
||||
if width_remainder > 0:
|
||||
extra_width = divisible_by - width_remainder
|
||||
pad_right += extra_width
|
||||
if height_remainder > 0:
|
||||
extra_height = divisible_by - height_remainder
|
||||
pad_bottom += extra_height
|
||||
|
||||
pad_mode = (
|
||||
"pillarbox_blur" if pillarbox_blur else
|
||||
"edge" if keep_proportion == "pad_edge" else
|
||||
"edge_pixel" if keep_proportion == "pad_edge_pixel" else
|
||||
"color"
|
||||
)
|
||||
out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
|
||||
|
||||
return out_image, out_mask
|
||||
|
||||
# If batching disabled (per_batch==0) or batch fits, process whole batch
|
||||
if per_batch == 0 or B <= per_batch:
|
||||
out_image, out_mask = _process_subbatch(image, mask, pad_left, pad_right, pad_top, pad_bottom)
|
||||
else:
|
||||
out_mask = None
|
||||
|
||||
# Crop logic
|
||||
if keep_proportion == "crop":
|
||||
old_width = W
|
||||
old_height = H
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
if old_aspect > new_aspect:
|
||||
crop_w = round(old_height * new_aspect)
|
||||
crop_h = old_height
|
||||
chunks = []
|
||||
mask_chunks = [] if mask is not None else None
|
||||
total_batches = (B + per_batch - 1) // per_batch
|
||||
current_batch = 0
|
||||
for start_idx in range(0, B, per_batch):
|
||||
current_batch += 1
|
||||
end_idx = min(start_idx + per_batch, B)
|
||||
sub_img = image[start_idx:end_idx]
|
||||
sub_mask = mask[start_idx:end_idx] if mask is not None else None
|
||||
sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask, pad_left, pad_right, pad_top, pad_bottom)
|
||||
chunks.append(sub_out_img.cpu())
|
||||
if mask is not None:
|
||||
mask_chunks.append(sub_out_mask.cpu() if sub_out_mask is not None else None)
|
||||
# Per-batch progress update
|
||||
if unique_id and PromptServer is not None:
|
||||
try:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"<tr><td>Resize v2</td><td>batch {current_batch}/{total_batches} · images {end_idx}/{B}</td></tr>",
|
||||
unique_id
|
||||
)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
print(f"[ImageResizeKJv2] batch {current_batch}/{total_batches} · images {end_idx}/{B}")
|
||||
except:
|
||||
pass
|
||||
out_image = torch.cat(chunks, dim=0)
|
||||
if mask is not None and any(m is not None for m in mask_chunks):
|
||||
out_mask = torch.cat([m for m in mask_chunks if m is not None], dim=0)
|
||||
else:
|
||||
crop_w = old_width
|
||||
crop_h = round(old_width / new_aspect)
|
||||
if crop_position == "center":
|
||||
x = (old_width - crop_w) // 2
|
||||
y = (old_height - crop_h) // 2
|
||||
elif crop_position == "top":
|
||||
x = (old_width - crop_w) // 2
|
||||
y = 0
|
||||
elif crop_position == "bottom":
|
||||
x = (old_width - crop_w) // 2
|
||||
y = old_height - crop_h
|
||||
elif crop_position == "left":
|
||||
x = 0
|
||||
y = (old_height - crop_h) // 2
|
||||
elif crop_position == "right":
|
||||
x = old_width - crop_w
|
||||
y = (old_height - crop_h) // 2
|
||||
out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
|
||||
if mask is not None:
|
||||
out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)
|
||||
|
||||
out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
|
||||
if mask is not None:
|
||||
if upscale_method == "lanczos":
|
||||
out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
|
||||
else:
|
||||
out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)
|
||||
|
||||
# Pad logic
|
||||
if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
|
||||
padded_width = width + pad_left + pad_right
|
||||
padded_height = height + pad_top + pad_bottom
|
||||
if divisible_by > 1:
|
||||
width_remainder = padded_width % divisible_by
|
||||
height_remainder = padded_height % divisible_by
|
||||
if width_remainder > 0:
|
||||
extra_width = divisible_by - width_remainder
|
||||
pad_right += extra_width
|
||||
if height_remainder > 0:
|
||||
extra_height = divisible_by - height_remainder
|
||||
pad_bottom += extra_height
|
||||
|
||||
pad_mode = (
|
||||
"pillarbox_blur" if pillarbox_blur else
|
||||
"edge" if keep_proportion == "pad_edge" else
|
||||
"edge_pixel" if keep_proportion == "pad_edge_pixel" else
|
||||
"color"
|
||||
)
|
||||
out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)
|
||||
out_mask = None
|
||||
|
||||
# Progress UI
|
||||
if unique_id and PromptServer is not None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user