diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index b89f418..cd858c4 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -3209,9 +3209,11 @@ class ImagePadKJ: if mask is not None: - out_masks = torch.zeros((BM, padded_height, padded_width), dtype=mask.dtype, device=mask.device) - for m in range(BM): - out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = mask[m] + out_masks = torch.nn.functional.pad( + mask, + (pad_left, pad_right, pad_top, pad_bottom), + mode='replicate' + ) else: out_masks = torch.ones((B, padded_height, padded_width), dtype=image.dtype, device=image.device) for m in range(B):