tiled encoding
commit a5b06b02ad (parent 69ab797b8c)
@@ -323,6 +323,5 @@ class T2VSynthMochiModel:
         comfy_pbar.update(1)
 
         self.dit.to(self.offload_device)
-        #samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         logging.info(f"samples shape: {z.shape}")
         return z
@@ -899,6 +899,11 @@ class Encoder(nn.Module):
         assert logvar.shape == means.shape
         assert means.size(1) == self.latent_dim
 
+        noise = torch.randn(means.shape, device=means.device, dtype=means.dtype, generator=None)
+
+        # Just Gaussian sample with no scaling of variance.
+        return noise * torch.exp(logvar * 0.5) + means
+
         return LatentDistribution(means, logvar)
 
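Note on the Encoder hunk above: the new lines return a plain Gaussian sample directly instead of the LatentDistribution object, so the pre-existing return LatentDistribution(means, logvar) beneath them becomes unreachable. The sampling line is the standard reparameterization trick, sample = mean + std * eps with std = exp(0.5 * logvar). A minimal sketch with hypothetical shapes:

import torch

# Hypothetical latent statistics; the (B, C, T, H, W) shape is illustrative only.
means = torch.zeros(1, 12, 3, 60, 106)
logvar = torch.zeros_like(means)

# Reparameterized Gaussian sample: mean + std * noise, with std = exp(0.5 * logvar).
noise = torch.randn(means.shape, device=means.device, dtype=means.dtype)
sample = noise * torch.exp(logvar * 0.5) + means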
nodes.py (33 changes)
@@ -43,7 +43,8 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
     sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
     sigma_schedule = [1.0 - x for x in sigma_schedule]
     return sigma_schedule
 
+#region ModelLoading
 class DownloadAndLoadMochiModel:
     @classmethod
     def INPUT_TYPES(s):
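For context, the tail of linear_quadratic_schedule shown in this hunk concatenates the linear and quadratic segments, appends 1.0, then flips every value so the sigmas descend from 1.0 to 0.0. A toy illustration with made-up segment values:

# Hypothetical segment values, purely illustrative.
linear_sigma_schedule = [0.0, 0.005, 0.010]
quadratic_sigma_schedule = [0.020, 0.045, 0.080]

sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
print(sigma_schedule)  # approximately [1.0, 0.995, 0.99, 0.98, 0.955, 0.92, 0.0]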
@@ -358,7 +359,8 @@ class MochiVAEEncoderLoader:
             encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
 
         return (encoder,)
+#endregion
 
 class MochiTextEncode:
     @classmethod
     def INPUT_TYPES(s):
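The loader hunk above compiles the encoder with user-supplied settings. For reference, a sketch of what the torch_compile_args dict could contain (keys inferred from the call above; the values are only examples):

# Example settings; any valid torch.compile mode/backend combination works here.
torch_compile_args = {
    "fullgraph": False,     # allow graph breaks
    "mode": "default",      # or e.g. "max-autotune"
    "backend": "inductor",  # stock torch.compile backend
}
encoder = torch.compile(
    encoder,
    fullgraph=torch_compile_args["fullgraph"],
    mode=torch_compile_args["mode"],
    dynamic=False,
    backend=torch_compile_args["backend"],
)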
@@ -412,7 +414,7 @@ class MochiTextEncode:
         }
         return (t5_embeds, clip,)
 
-
+#region Sampler
 class MochiSampler:
     @classmethod
     def INPUT_TYPES(s):
@@ -427,7 +429,6 @@ class MochiSampler:
                 "steps": ("INT", {"default": 50, "min": 2}),
                 "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
-                #"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
             },
             "optional": {
                 "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
@@ -489,7 +490,9 @@ class MochiSampler:
         mm.soft_empty_cache()
 
         return ({"samples": latents},)
+#endregion
+#region Latents
 
 class MochiDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -688,12 +691,18 @@ class MochiDecodeSpatialTiling:
 
         return (frames,)
 
+
 class MochiImageEncode:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "encoder": ("MOCHIVAE",),
             "images": ("IMAGE", ),
+            "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
+            "num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}),
+            "num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
+            "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
+            "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
             },
         }
 
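The new inputs above parameterize the tiled encode. As a rough sketch of how num_tiles_w and overlap could partition a frame (illustrative only; the repo's apply_tiled may subdivide differently):

# Illustrative tile layout along the width axis only.
W, num_tiles_w, overlap = 848, 4, 16
base = W // num_tiles_w  # 212 px per tile before overlap
for i in range(num_tiles_w):
    start = max(i * base - overlap // 2, 0)
    end = min((i + 1) * base + overlap // 2, W)
    print(f"tile {i}: cols {start}:{end}")
# tile 0: cols 0:220, tile 1: cols 204:432, tile 2: cols 416:644, tile 3: cols 628:848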
@@ -702,23 +711,25 @@ class MochiImageEncode:
     FUNCTION = "decode"
     CATEGORY = "MochiWrapper"
 
-    def decode(self, encoder, images):
+    def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
+        from .mochi_preview.vae.model import apply_tiled
         B, H, W, C = images.shape
 
         images = images.unsqueeze(0) * 2 - 1
         images = rearrange(images, "t b h w c -> t c b h w")
         images = images.to(encoder.dtype).to(device)
         print(images.shape)
 
         encoder.to(device)
         print("images before encoding", images.shape)
         with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
             video = add_fourier_features(images)
-            latents = encoder(video).sample()
+            if enable_vae_tiling:
+                latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
+            else:
+                latents = encoder(video)
         latents = vae_latents_to_dit_latents(latents)
         print("encoder output",latents.shape)
 
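In the decode hunk above, enable_vae_tiling routes the encode through apply_tiled, whose body this diff does not show. Conceptually, a tiled encode runs the encoder on overlapping spatial tiles and averages the overlaps so seams are less visible. A hypothetical single-axis sketch of that idea (not the repo's implementation):

import torch

# Width-only tiled encode with flat averaging in the overlaps. Assumes the
# encoder shrinks W by `scale`; all of this is a sketch, not apply_tiled.
def encode_tiled_w(encoder, video, num_tiles_w=4, overlap=16, scale=8):
    W = video.shape[-1]
    base = W // num_tiles_w
    lat_W = W // scale
    acc = den = None
    for i in range(num_tiles_w):
        s = max(i * base - overlap // 2, 0)
        e = min((i + 1) * base + overlap // 2, W)
        z = encoder(video[..., s:e])        # latent tile
        ls = s // scale
        le = min(ls + z.shape[-1], lat_W)
        z = z[..., : le - ls]
        if acc is None:
            acc = torch.zeros((*z.shape[:-1], lat_W), device=z.device, dtype=z.dtype)
            den = torch.zeros_like(acc)
        acc[..., ls:le] += z
        den[..., ls:le] += 1.0
    return acc / den.clamp(min=1.0)         # average where tiles overlap

A feathered (linear-ramp) blend across the overlap, plus a second pass over the height axis, would bring this closer to what the node exposes via num_tiles_h and min_block_size.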
@@ -784,6 +795,8 @@ class MochiLatentPreview:
 
         return (latent_images.float().cpu(),)
 
+#endregion
+#region NodeMappings
 NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
     "MochiSampler": MochiSampler,