Add pillarbox_blur mode for proportional image resizing with blurred background.

This commit is contained in:
Christopher Anderson 2025-08-10 22:47:27 +10:00
parent 87d0cf42db
commit 268063e317

View File

@ -2436,7 +2436,7 @@ class ImageResizeKJv2:
"width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
"height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
"upscale_method": (s.upscale_methods,),
"keep_proportion": (["stretch", "resize", "pad", "pad_edge", "crop"], { "default": False }),
"keep_proportion": (["stretch", "resize", "pad", "pad_edge", "crop", "pillarbox_blur"], { "default": False }),
"pad_color": ("STRING", { "default": "0, 0, 0", "tooltip": "Color to use for padding."}),
"crop_position": (["center", "top", "bottom", "left", "right"], { "default": "center" }),
"divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
@ -2476,7 +2476,140 @@ highest dimension.
width = W
if height == 0:
height = H
# Pillarbox blur path: build blurred background that fills the target canvas,
# overlay resized foreground centered on top.
if keep_proportion == "pillarbox_blur":
# Adjust to divisibility first, since we are producing the final canvas here
if divisible_by > 1:
width = width - (width % divisible_by)
height = height - (height % divisible_by)
image = image.to(device)
if mask is not None:
mask = mask.to(device)
# Compute foreground fit size (keep aspect, fit inside target)
ratio_fit = min(width / W, height / H) if (W > 0 and H > 0) else 1.0
fg_w = max(1, int(round(W * ratio_fit)))
fg_h = max(1, int(round(H * ratio_fit)))
# Foreground placement (center)
pad_left = (width - fg_w) // 2
pad_right = width - fg_w - pad_left
pad_top = (height - fg_h) // 2
pad_bottom = height - fg_h - pad_top
# Prepare output tensors
out_image = torch.zeros((B, height, width, C), dtype=image.dtype, device=device)
out_mask = None
# Helper: separable gaussian blur for NCHW
def _gaussian_blur_nchw(img_nchw, sigma_px):
if sigma_px <= 0:
return img_nchw
radius = max(1, int(3.0 * float(sigma_px)))
k = 2 * radius + 1
x = torch.arange(-radius, radius + 1, device=img_nchw.device, dtype=img_nchw.dtype)
k1 = torch.exp(-(x * x) / (2.0 * float(sigma_px) * float(sigma_px)))
k1 = k1 / k1.sum()
kx = k1.view(1, 1, 1, k)
ky = k1.view(1, 1, k, 1)
c = img_nchw.shape[1]
kx = kx.repeat(c, 1, 1, 1)
ky = ky.repeat(c, 1, 1, 1)
img_nchw = F.conv2d(img_nchw, kx, padding=(0, radius), groups=c)
img_nchw = F.conv2d(img_nchw, ky, padding=(radius, 0), groups=c)
return img_nchw
# Build per-batch frames
for b in range(B):
# Background: scale-to-fill (keep aspect), then center-crop to (height, width)
scale_fill = max(width / float(W), height / float(H)) if (W > 0 and H > 0) else 1.0
bg_w = max(1, int(round(W * scale_fill)))
bg_h = max(1, int(round(H * scale_fill)))
src_b = image[b].movedim(-1, 0).unsqueeze(0) # 1,C,H,W
bg = common_upscale(src_b, bg_w, bg_h, upscale_method, crop="disabled")
# Center crop to canvas
y0 = max(0, (bg_h - height) // 2)
x0 = max(0, (bg_w - width) // 2)
y1 = min(bg_h, y0 + height)
x1 = min(bg_w, x0 + width)
bg = bg[:, :, y0:y1, x0:x1]
# If rounding made it a hair off, pad to exact size
if bg.shape[2] != height or bg.shape[3] != width:
pad_h = height - bg.shape[2]
pad_w = width - bg.shape[3]
pad_top_fix = max(0, pad_h // 2)
pad_bottom_fix = max(0, pad_h - pad_top_fix)
pad_left_fix = max(0, pad_w // 2)
pad_right_fix = max(0, pad_w - pad_left_fix)
bg = F.pad(bg, (pad_left_fix, pad_right_fix, pad_top_fix, pad_bottom_fix), mode="replicate")
# Blur strength scales with output size
sigma = max(1.0, 0.006 * float(min(height, width)))
bg = _gaussian_blur_nchw(bg, sigma_px=sigma)
# 20% saturation reduction (Rec.709 luma)
if C >= 3:
r, g, bch = bg[:, 0:1], bg[:, 1:2], bg[:, 2:3]
luma = 0.2126 * r + 0.7152 * g + 0.0722 * bch
gray = torch.cat([luma, luma, luma], dim=1)
desat = 0.20
rgb = torch.cat([r, g, bch], dim=1)
rgb = rgb * (1.0 - desat) + gray * desat
bg[:, 0:3, :, :] = rgb
# Dim to keep attention on foreground
dim = 0.35
bg = torch.clamp(bg * dim, 0.0, 1.0)
# Write background to canvas
out_image[b] = bg.squeeze(0).movedim(0, -1)
# Resize foreground to fit size and composite at center
fg = common_upscale(image.movedim(-1, 1), fg_w, fg_h, upscale_method, crop="disabled").movedim(1, -1)
out_image[:, pad_top:pad_top+fg_h, pad_left:pad_left+fg_w, :] = fg
# Mask handling
if mask is not None:
# Transform mask like the foreground
if upscale_method == "lanczos":
# Use the same path as elsewhere in this node for lanczos
fg_mask = common_upscale(mask.unsqueeze(1).repeat(1, 3, 1, 1), fg_w, fg_h, upscale_method, crop="disabled").movedim(1, -1)[:, :, :, 0]
else:
fg_mask = common_upscale(mask.unsqueeze(1), fg_w, fg_h, upscale_method, crop="disabled").squeeze(1)
out_mask = torch.ones((B, height, width), dtype=image.dtype, device=device)
out_mask[:, pad_top:pad_top+fg_h, pad_left:pad_left+fg_w] = fg_mask
else:
out_mask = torch.ones((B, height, width), dtype=image.dtype, device=device)
out_mask[:, pad_top:pad_top+fg_h, pad_left:pad_left+fg_w] = 0.0
# Progress UI (kept consistent with existing code)
if unique_id and PromptServer is not None:
try:
num_elements = out_image.numel()
element_size = out_image.element_size()
memory_size_mb = (num_elements * element_size) / (1024 * 1024)
PromptServer.instance.send_progress_text(
f"<tr><td>Output: </td><td><b>{out_image.shape[0]}</b> x <b>{out_image.shape[2]}</b> x <b>{out_image.shape[1]} | {memory_size_mb:.2f}MB</b></td></tr>",
unique_id
)
except:
pass
return (
out_image.cpu(),
out_image.shape[2],
out_image.shape[1],
out_mask.cpu() if out_mask is not None else torch.zeros(64, 64, device=torch.device("cpu"), dtype=torch.float32),
)
# Existing logic for other modes
if keep_proportion == "resize" or keep_proportion.startswith("pad"):
# If one of the dimensions is zero, calculate it to maintain the aspect ratio
if width == 0 and height != 0:
@ -2615,7 +2748,7 @@ highest dimension.
pass
return(out_image.cpu(), out_image.shape[2], out_image.shape[1], out_mask.cpu() if out_mask is not None else torch.zeros(64,64, device=torch.device("cpu"), dtype=torch.float32))
import pathlib
class LoadAndResizeImage:
_color_channels = ["alpha", "red", "green", "blue"]