diff --git a/nodes.py b/nodes.py index f13a17d..3870e11 100644 --- a/nodes.py +++ b/nodes.py @@ -1931,6 +1931,7 @@ class ResizeMask: return(outputs, outputs.shape[2], outputs.shape[1],) +from torch.nn.functional import pad class OffsetMask: @classmethod def INPUT_TYPES(s): @@ -1943,6 +1944,15 @@ class OffsetMask: "duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }), "roll": ("BOOLEAN", { "default": False }), "incremental": ("BOOLEAN", { "default": False }), + "padding_mode": ( + [ + 'empty', + 'border', + 'reflection', + + ], { + "default": 'empty' + }), } } @@ -1951,7 +1961,7 @@ class OffsetMask: FUNCTION = "offset" CATEGORY = "KJNodes/masking" - def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1): + def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1, padding_mode="empty"): # Create duplicates of the mask batch mask = mask.repeat(duplication_factor, 1, 1) @@ -1987,25 +1997,28 @@ class OffsetMask: temp_x = min(x * (i+1), width-1) temp_y = min(y * (i+1), height-1) if temp_x > 0: - mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1) - elif temp_x < 0: - mask[i] = torch.cat([mask[i, :, -temp_x:], torch.zeros((height, -temp_x))], dim=1) - if temp_y > 0: - mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0) - elif temp_y < 0: - mask[i] = torch.cat([mask[i, -temp_y:, :], torch.zeros((-temp_y, width))], dim=0) - else: - temp_x = min(x, width-1) - temp_y = min(y, height-1) - if temp_x > 0: - mask = torch.cat([torch.zeros((batch_size, height, temp_x)), mask[:, :, :-temp_x]], dim=2) - elif temp_x < 0: - mask = torch.cat([mask[:, :, -temp_x:], torch.zeros((batch_size, height, -temp_x))], dim=2) - if temp_y > 0: - mask = torch.cat([torch.zeros((batch_size, temp_y, width)), mask[:, :-temp_y, :]], dim=1) - elif temp_y < 0: - mask = torch.cat([mask[:, -temp_y:, :], torch.zeros((batch_size, -temp_y, width))], dim=1) + if padding_mode == 'empty': + mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1) + elif padding_mode in ['replicate', 'reflect']: + mask[i] = pad(mask[i, :, :-temp_x], (0, temp_x), mode=padding_mode) + elif temp_x < 0: + if padding_mode == 'empty': + mask[i] = torch.cat([mask[i, :, -temp_x:], torch.zeros((height, -temp_x))], dim=1) + elif padding_mode in ['replicate', 'reflect']: + mask[i] = pad(mask[i, :, -temp_x:], (temp_x, 0), mode=padding_mode) + + if temp_y > 0: + if padding_mode == 'empty': + mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0) + elif padding_mode in ['replicate', 'reflect']: + mask[i] = pad(mask[i, :-temp_y, :], (0, temp_y), mode=padding_mode) + + elif temp_y < 0: + if padding_mode == 'empty': + mask[i] = torch.cat([mask[i, -temp_y:, :], torch.zeros((-temp_y, width))], dim=0) + elif padding_mode in ['replicate', 'reflect']: + mask[i] = pad(mask[i, -temp_y:, :], (temp_y, 0), mode=padding_mode) return mask, class WidgetToString: