tiled encoding

This commit is contained in:
kijai 2024-11-01 18:03:47 +02:00
parent 69ab797b8c
commit a5b06b02ad
3 changed files with 28 additions and 11 deletions

View File

@ -323,6 +323,5 @@ class T2VSynthMochiModel:
comfy_pbar.update(1)
self.dit.to(self.offload_device)
#samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
logging.info(f"samples shape: {z.shape}")
return z

View File

@ -899,6 +899,11 @@ class Encoder(nn.Module):
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
noise = torch.randn(means.shape, device=means.device, dtype=means.dtype, generator=None)
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(logvar * 0.5) + means
return LatentDistribution(means, logvar)

View File

@ -43,7 +43,8 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
#region ModelLoading
class DownloadAndLoadMochiModel:
@classmethod
def INPUT_TYPES(s):
@ -358,7 +359,8 @@ class MochiVAEEncoderLoader:
encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
return (encoder,)
#endregion
class MochiTextEncode:
@classmethod
def INPUT_TYPES(s):
@ -412,7 +414,7 @@ class MochiTextEncode:
}
return (t5_embeds, clip,)
#region Sampler
class MochiSampler:
@classmethod
def INPUT_TYPES(s):
@ -427,7 +429,6 @@ class MochiSampler:
"steps": ("INT", {"default": 50, "min": 2}),
"cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
#"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
},
"optional": {
"cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
@ -489,7 +490,9 @@ class MochiSampler:
mm.soft_empty_cache()
return ({"samples": latents},)
#endregion
#region Latents
class MochiDecode:
@classmethod
def INPUT_TYPES(s):
@ -688,12 +691,18 @@ class MochiDecodeSpatialTiling:
return (frames,)
class MochiImageEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"encoder": ("MOCHIVAE",),
"images": ("IMAGE", ),
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
"num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}),
"num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
"overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
"min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
},
}
@ -702,23 +711,25 @@ class MochiImageEncode:
FUNCTION = "decode"
CATEGORY = "MochiWrapper"
def decode(self, encoder, images):
def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
from .mochi_preview.vae.model import apply_tiled
B, H, W, C = images.shape
images = images.unsqueeze(0) * 2 - 1
images = rearrange(images, "t b h w c -> t c b h w")
images = images.to(encoder.dtype).to(device)
print(images.shape)
encoder.to(device)
print("images before encoding", images.shape)
with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
video = add_fourier_features(images)
latents = encoder(video).sample()
video = add_fourier_features(images)
if enable_vae_tiling:
latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
else:
latents = encoder(video)
latents = vae_latents_to_dit_latents(latents)
print("encoder output",latents.shape)
@ -784,6 +795,8 @@ class MochiLatentPreview:
return (latent_images.float().cpu(),)
#endregion
#region NodeMappings
NODE_CLASS_MAPPINGS = {
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
"MochiSampler": MochiSampler,