Update nodes.py

This commit is contained in:
kijai 2023-11-20 10:55:37 +02:00
parent b0938d7908
commit 4a2ae659c1

View File

@ -1992,33 +1992,36 @@ class OffsetMask:
if shift_y != 0:
mask = torch.roll(mask, shifts=shift_y, dims=1)
else:
if incremental:
for i in range(batch_size):
for i in range(batch_size):
if incremental:
temp_x = min(x * (i+1), width-1)
temp_y = min(y * (i+1), height-1)
if temp_x > 0:
if padding_mode == 'empty':
mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, :, :-temp_x], (0, temp_x), mode=padding_mode)
else:
temp_x = min(x, width-1)
temp_y = min(y, height-1)
if temp_x > 0:
if padding_mode == 'empty':
mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, :, :-temp_x], (0, temp_x), mode=padding_mode)
elif temp_x < 0:
if padding_mode == 'empty':
mask[i] = torch.cat([mask[i, :, :temp_x], torch.zeros((height, -temp_x))], dim=1)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, :, -temp_x:], (temp_x, 0), mode=padding_mode)
elif temp_x < 0:
if padding_mode == 'empty':
mask[i] = torch.cat([mask[i, :, -temp_x:], torch.zeros((height, -temp_x))], dim=1)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, :, -temp_x:], (temp_x, 0), mode=padding_mode)
if temp_y > 0:
if padding_mode == 'empty':
mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, :-temp_y, :], (0, temp_y), mode=padding_mode)
elif temp_y < 0:
if padding_mode == 'empty':
mask[i] = torch.cat([mask[i, -temp_y:, :], torch.zeros((-temp_y, width))], dim=0)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, -temp_y:, :], (temp_y, 0), mode=padding_mode)
if temp_y > 0:
if padding_mode == 'empty':
mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, :-temp_y, :], (0, temp_y), mode=padding_mode)
elif temp_y < 0:
if padding_mode == 'empty':
mask[i] = torch.cat([mask[i, :temp_y, :], torch.zeros((-temp_y, width))], dim=0)
elif padding_mode in ['replicate', 'reflect']:
mask[i] = pad(mask[i, -temp_y:, :], (temp_y, 0), mode=padding_mode)
return mask,
class WidgetToString: