Merge ae863719912e96c5c34d9cf6714edbcce9058f34 into 807538fe6c66bca8c91edbad14414fb4e109cbde

This commit is contained in:
Lightje 2025-12-22 01:37:21 +01:00 committed by GitHub
commit da1185f9a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -976,6 +976,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
@torch.inference_mode() @torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
dims = len(tile) dims = len(tile)
samples = samples.contiguous() # Ensure input tensor is contiguous
if not (isinstance(upscale_amount, (tuple, list))): if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims upscale_amount = [upscale_amount] * dims
@ -1037,7 +1038,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
# handle entire input fitting in a single tile # handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)): if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device) output[b:b+1] = function(s.contiguous()).to(output_device) # Ensure single tile is contiguous
if pbar is not None: if pbar is not None:
pbar.update(1) pbar.update(1)
continue continue
@ -1057,7 +1058,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
s_in = s_in.narrow(d + 2, pos, l) s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_pos(d, pos))) upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device) ps = function(s_in.contiguous()).to(output_device) # Ensure tiled segment is contiguous
mask = torch.ones_like(ps) mask = torch.ones_like(ps)
for d in range(2, dims + 2): for d in range(2, dims + 2):
@ -1083,7 +1084,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
output[b:b+1] = out/out_div output[b:b+1] = out/out_div
return output return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    """Run ``function`` over ``samples`` in 2D tiles via ``tiled_scale_multidim``.

    This is a thin two-dimensional convenience wrapper: it packs the two
    tile sizes into a single tuple and forwards every other argument
    unchanged as a keyword argument.

    Args:
        samples: Input passed straight through to ``tiled_scale_multidim``.
        function: Callable applied to each tile by ``tiled_scale_multidim``.
        tile_x: Tile size passed as the second element of the tile tuple.
        tile_y: Tile size passed as the first element of the tile tuple.
        overlap: Forwarded unchanged.
        upscale_amount: Forwarded unchanged.
        out_channels: Forwarded unchanged.
        output_device: Forwarded unchanged.
        pbar: Optional progress object, forwarded unchanged.

    Returns:
        Whatever ``tiled_scale_multidim`` returns for these arguments.
    """
    # NOTE: the tile tuple is ordered (tile_y, tile_x), i.e. the reverse of
    # this function's own parameter order.
    tile = (tile_y, tile_x)
    return tiled_scale_multidim(
        samples,
        function,
        tile,
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=output_device,
        pbar=pbar,
    )