tile encode fix

This commit is contained in:
kijai 2024-10-25 03:16:33 +03:00
parent 9e488568b2
commit 7fe4716f2d

View File

@ -79,11 +79,15 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
i: i + self.tile_sample_min_height,
j: j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
time.append(tile[0])
try:
self._clear_fake_context_parallel_cache()
except:
pass
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
@ -130,7 +134,10 @@ def _encode(
if self.quant_conv is not None:
z_intermediate = self.quant_conv(z_intermediate)
h.append(z_intermediate)
self._clear_fake_context_parallel_cache()
try:
self._clear_fake_context_parallel_cache()
except:
pass
h = torch.cat(h, dim=2)
return h