From a3b04de7004cc19dee9364bd71e62bab05475810 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:46:46 -0700 Subject: [PATCH] Hunyuan refiner vae now works with tiled. (#9836) --- comfy/ldm/hunyuan_video/vae_refiner.py | 1 - comfy/sd.py | 21 +++++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index e3fff9bbe..c6f742710 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -185,7 +185,6 @@ class Encoder(nn.Module): self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() def forward(self, x): - x = x.unsqueeze(2) x = self.conv_in(x) for stage in self.down: diff --git a/comfy/sd.py b/comfy/sd.py index 02ddc7239..f8f1a89e8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -412,9 +412,12 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.downscale_ratio = 16 - self.upscale_ratio = 16 + ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + self.latent_channels = 64 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) self.latent_dim = 3 self.not_video = True self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -684,8 +687,11 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) - if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if self.latent_dim == 3 and pixel_samples.ndim < 5: + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -719,7 +725,10 @@ class VAE: dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) if dims == 3: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)