tiled encoding
Parent: 69ab797b8c
Commit: a5b06b02ad
@@ -323,6 +323,5 @@ class T2VSynthMochiModel:
            comfy_pbar.update(1)

        self.dit.to(self.offload_device)
        #samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
        logging.info(f"samples shape: {z.shape}")
        return z
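For context on the commented-out call above: unnormalize_latents maps normalized latents back to raw VAE latents by inverting the per-channel normalization with self.vae_mean and self.vae_std. A minimal sketch of that inverse, assuming [B, C, T, H, W] latents and per-channel [C] statistics (the broadcasting shapes are assumptions, not the repo's exact code):

import torch

def unnormalize_latents(z: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Invert z_norm = (z - mean) / std, broadcasting [C] stats over [B, C, T, H, W].
    return z * std[None, :, None, None, None] + mean[None, :, None, None, None]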
@@ -899,6 +899,11 @@ class Encoder(nn.Module):
         assert logvar.shape == means.shape
         assert means.size(1) == self.latent_dim

+        noise = torch.randn(means.shape, device=means.device, dtype=means.dtype, generator=None)
+
+        # Just Gaussian sample with no scaling of variance.
+        return noise * torch.exp(logvar * 0.5) + means
+
         return LatentDistribution(means, logvar)
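The added lines are a standard reparameterization-trick sample from a diagonal Gaussian: since logvar is the log-variance, the standard deviation is exp(0.5 * logvar), so a sample is means + std * noise. The early return also leaves the following return LatentDistribution(means, logvar) unreachable, which matches nodes.py below calling encoder(video) without .sample(). A standalone sketch of the same operation:

import torch

def sample_diagonal_gaussian(means: torch.Tensor, logvar: torch.Tensor, generator=None) -> torch.Tensor:
    # Reparameterization trick: x = mu + sigma * eps, with eps ~ N(0, I).
    noise = torch.randn(means.shape, device=means.device, dtype=means.dtype, generator=generator)
    return means + noise * torch.exp(0.5 * logvar)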
nodes.py: 33 changed lines
@@ -43,7 +43,8 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule


#region ModelLoading
class DownloadAndLoadMochiModel:
    @classmethod
    def INPUT_TYPES(s):
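The tail above builds sigmas from a linear segment plus a quadratic segment, caps them with 1.0, and flips the list into a descending schedule. A hedged sketch of one way to produce that shape, rising linearly to threshold_noise and then quadratically to 1.0 (the repo's exact quadratic coefficients may differ):

def linear_quadratic_schedule_sketch(num_steps, threshold_noise, linear_steps=None):
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    quadratic_steps = num_steps - linear_steps
    # Quadratic ramp from threshold_noise up toward 1.0 over the remaining steps.
    quadratic_sigma_schedule = [
        threshold_noise + (1.0 - threshold_noise) * (i / quadratic_steps) ** 2
        for i in range(1, quadratic_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule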
@@ -358,7 +359,8 @@ class MochiVAEEncoderLoader:
            encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])

        return (encoder,)

#endregion

class MochiTextEncode:
    @classmethod
    def INPUT_TYPES(s):
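The compile call above is driven by a user-supplied torch_compile_args dict. A sketch of what that dict plausibly contains, using only the arguments that appear in the call (the values are illustrative defaults, not the node's):

torch_compile_args = {
    "fullgraph": False,     # allow graph breaks rather than failing compilation
    "mode": "default",      # e.g. "max-autotune" trades compile time for speed
    "backend": "inductor",  # PyTorch's default compilation backend
}
# encoder comes from the loader above:
# encoder = torch.compile(encoder, fullgraph=..., mode=..., dynamic=False, backend=...)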
@@ -412,7 +414,7 @@ class MochiTextEncode:
        }
        return (t5_embeds, clip,)


#region Sampler
class MochiSampler:
    @classmethod
    def INPUT_TYPES(s):
@@ -427,7 +429,6 @@ class MochiSampler:
                "steps": ("INT", {"default": 50, "min": 2}),
                "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                #"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
            },
            "optional": {
                "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of floats"}),
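The optional cfg_schedule input overrides the scalar cfg with one guidance value per step. A sketch of building a linearly decaying schedule to wire into it (make_cfg_schedule is an illustrative helper, not part of the repo; the list length must match steps):

def make_cfg_schedule(steps: int, start: float = 6.0, end: float = 3.0) -> list[float]:
    # One guidance value per sampling step, decaying linearly from start to end.
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]

cfg_schedule = make_cfg_schedule(50)  # pairs with the default "steps": 50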
@@ -489,7 +490,9 @@ class MochiSampler:
        mm.soft_empty_cache()

        return ({"samples": latents},)

#endregion
#region Latents

class MochiDecode:
    @classmethod
    def INPUT_TYPES(s):
@@ -688,12 +691,18 @@ class MochiDecodeSpatialTiling:

        return (frames,)


class MochiImageEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "encoder": ("MOCHIVAE",),
            "images": ("IMAGE", ),
+           "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
+           "num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}),
+           "num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
+           "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixels of overlap between adjacent tiles"}),
+           "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
            },
        }
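Given num_tiles and overlap, each tile spans roughly (size + (num_tiles - 1) * overlap) / num_tiles pixels, with adjacent tiles sharing overlap pixels for blending. A sketch of the 1D tile-boundary arithmetic under that assumption (apply_tiled's actual subdivision, including min_block_size handling, may differ):

def tile_spans(size: int, num_tiles: int, overlap: int) -> list[tuple[int, int]]:
    # Evenly spaced, overlapping (start, end) spans covering [0, size).
    tile = (size + (num_tiles - 1) * overlap) // num_tiles
    stride = tile - overlap
    spans = [(i * stride, min(i * stride + tile, size)) for i in range(num_tiles)]
    spans[-1] = (max(size - tile, 0), size)  # clamp the last tile to the edge
    return spans

print(tile_spans(512, 4, 16))  # [(0, 140), (124, 264), (248, 388), (372, 512)]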
@@ -702,23 +711,25 @@ class MochiImageEncode:
    FUNCTION = "decode"
    CATEGORY = "MochiWrapper"

-   def decode(self, encoder, images):
+   def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        intermediate_device = mm.intermediate_device()

        from .mochi_preview.vae.model import apply_tiled
        B, H, W, C = images.shape

        images = images.unsqueeze(0) * 2 - 1
        images = rearrange(images, "t b h w c -> t c b h w")
        images = images.to(encoder.dtype).to(device)
        print(images.shape)

        encoder.to(device)
        print("images before encoding", images.shape)
        with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
-           video = add_fourier_features(images)
-           latents = encoder(video).sample()
+           video = add_fourier_features(images)
+           if enable_vae_tiling:
+               latents = apply_tiled(encoder, video, num_tiles_w=num_tiles_w, num_tiles_h=num_tiles_h, overlap=overlap, min_block_size=min_block_size)
+           else:
+               latents = encoder(video)
            latents = vae_latents_to_dit_latents(latents)
            print("encoder output", latents.shape)
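The tiled branch encodes overlapping spatial tiles independently; seams are then reduced by feathering each tile's edges across the overlap region and normalizing by the accumulated weights. A minimal sketch of that blend step, assuming linear feathering (apply_tiled in mochi_preview.vae.model is the real implementation; this is only an illustration):

import torch

def blend_tile(canvas, weight, tile, y0, x0, overlap):
    # Accumulate one encoded tile into the canvas with a linear feather on
    # every edge; dividing canvas by weight afterwards normalizes the blend.
    th, tw = tile.shape[-2:]
    wy = torch.ones(th, device=tile.device)
    wx = torch.ones(tw, device=tile.device)
    if overlap > 0:
        ramp = torch.linspace(0.0, 1.0, overlap, device=tile.device)
        wy[:overlap], wy[th - overlap:] = ramp, ramp.flip(0)
        wx[:overlap], wx[tw - overlap:] = ramp, ramp.flip(0)
    w = wy[:, None] * wx[None, :]  # [th, tw] feather mask
    canvas[..., y0:y0 + th, x0:x0 + tw] += tile * w
    weight[..., y0:y0 + th, x0:x0 + tw] += w

# After all tiles are accumulated: latents = canvas / weight.clamp(min=1e-6)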
@@ -784,6 +795,8 @@ class MochiLatentPreview:

        return (latent_images.float().cpu(),)

#endregion
#region NodeMappings
NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
    "MochiSampler": MochiSampler,
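ComfyUI discovers these nodes through NODE_CLASS_MAPPINGS, which maps a unique key to each node class; repos conventionally pair it with a NODE_DISPLAY_NAME_MAPPINGS dict for UI labels. The mapping above is cut off by the page, so the following is a sketch of the pattern only (the display strings are illustrative):

NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
    "MochiSampler": MochiSampler,
    # ... remaining Mochi nodes
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadMochiModel": "(Down)Load Mochi Model",
    "MochiSampler": "Mochi Sampler",
}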