tile encode fix

This commit is contained in:
kijai 2024-10-25 03:16:33 +03:00
parent 9e488568b2
commit 7fe4716f2d

View File

@@ -79,11 +79,15 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
                     tile = self.encoder(tile)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
-                    time.append(tile)
-                    self._clear_fake_context_parallel_cache()
+                    time.append(tile[0])
+                    try:
+                        self._clear_fake_context_parallel_cache()
+                    except:
+                        pass
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
         result_rows = []
@@ -130,7 +134,10 @@ def _encode(
             if self.quant_conv is not None:
                 z_intermediate = self.quant_conv(z_intermediate)
             h.append(z_intermediate)
-            self._clear_fake_context_parallel_cache()
+            try:
+                self._clear_fake_context_parallel_cache()
+            except:
+                pass
         h = torch.cat(h, dim=2)
         return h