diff --git a/mz_enable_vae_encode_tiling.py b/mz_enable_vae_encode_tiling.py index 90b1d7d..a038bec 100644 --- a/mz_enable_vae_encode_tiling.py +++ b/mz_enable_vae_encode_tiling.py @@ -79,11 +79,15 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: i: i + self.tile_sample_min_height, j: j + self.tile_sample_min_width, ] + tile = self.encoder(tile) if self.quant_conv is not None: tile = self.quant_conv(tile) - time.append(tile) - self._clear_fake_context_parallel_cache() + time.append(tile[0]) + try: + self._clear_fake_context_parallel_cache() + except: + pass row.append(torch.cat(time, dim=2)) rows.append(row) result_rows = [] @@ -130,7 +134,10 @@ def _encode( if self.quant_conv is not None: z_intermediate = self.quant_conv(z_intermediate) h.append(z_intermediate) - self._clear_fake_context_parallel_cache() + try: + self._clear_fake_context_parallel_cache() + except: + pass h = torch.cat(h, dim=2) return h