From 7fe4716f2decbbe3fec8803152f014e1f6513fbc Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 25 Oct 2024 03:16:33 +0300 Subject: [PATCH] tile encode fix --- mz_enable_vae_encode_tiling.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mz_enable_vae_encode_tiling.py b/mz_enable_vae_encode_tiling.py index 90b1d7d..a038bec 100644 --- a/mz_enable_vae_encode_tiling.py +++ b/mz_enable_vae_encode_tiling.py @@ -79,11 +79,15 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: i: i + self.tile_sample_min_height, j: j + self.tile_sample_min_width, ] + tile = self.encoder(tile) if self.quant_conv is not None: tile = self.quant_conv(tile) - time.append(tile) - self._clear_fake_context_parallel_cache() + time.append(tile[0]) + try: + self._clear_fake_context_parallel_cache() + except: + pass row.append(torch.cat(time, dim=2)) rows.append(row) result_rows = [] @@ -130,7 +134,10 @@ def _encode( if self.quant_conv is not None: z_intermediate = self.quant_conv(z_intermediate) h.append(z_intermediate) - self._clear_fake_context_parallel_cache() + try: + self._clear_fake_context_parallel_cache() + except: + pass h = torch.cat(h, dim=2) return h