Compare commits
No commits in common. "main" and "fastercache" have entirely different histories.
main...fastercache
@@ -31,28 +31,17 @@ class LatentPreviewer:
 class Latent2RGBPreviewer(LatentPreviewer):
     def __init__(self):
         #latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], [0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]]
         #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
-        latent_rgb_factors =[
-            [-0.0069, -0.0045, 0.0018],
-            [ 0.0154, -0.0692, -0.0274],
-            [ 0.0333, 0.0019, 0.0206],
-            [-0.1390, 0.0628, 0.1678],
-            [-0.0725, 0.0134, -0.1898],
-            [ 0.0074, -0.0270, -0.0209],
-            [-0.0176, -0.0277, -0.0221],
-            [ 0.5294, 0.5204, 0.3852],
-            [-0.0326, -0.0446, -0.0143],
-            [-0.0659, 0.0153, -0.0153],
-            [ 0.0185, -0.0217, 0.0014],
-            [-0.0396, -0.0495, -0.0281]
-        ]
+        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]

         self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
-        self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
+        self.latent_rgb_factors_bias = None
+        # if latent_rgb_factors_bias is not None:
+        #     self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

     def decode_latent_to_preview(self, x0):
         self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
         if self.latent_rgb_factors_bias is not None:
-            self.latent_rgb_factors_bias = torch.tensor(self.latent_rgb_factors_bias, device="cpu").to(dtype=x0.dtype, device=x0.device)
+            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

         latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
                                                   bias=self.latent_rgb_factors_bias)
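For orientation: both sides of this hunk build the preview the same way, projecting the 12-channel Mochi latent to RGB with a per-channel linear map. A minimal standalone sketch of that projection (shapes and factor values here are illustrative placeholders, not the tuned constants above):

```python
import torch

# 12 latent channels -> 3 RGB channels via a plain matrix multiply.
latent_rgb_factors = torch.randn(12, 3)   # placeholder for the tuned (C, 3) table
bias = torch.zeros(3)                     # optional per-channel RGB bias

x0 = torch.randn(1, 12, 60, 106)          # (B, C, H, W) latent frame
rgb = torch.nn.functional.linear(
    x0[0].permute(1, 2, 0),               # (H, W, C): linear maps the last dim
    latent_rgb_factors.transpose(0, 1),   # weight must be (out=3, in=C)
    bias=bias,
)                                         # -> (H, W, 3) preview image
```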
@@ -165,7 +165,6 @@ class AsymmetricAttention(nn.Module):
             raise ImportError("Flash RMSNorm not available.")
         elif rms_norm_func == "apex":
             from apex.normalization import FusedRMSNorm as ApexRMSNorm
-            @torch.compiler.disable()
             class RMSNorm(ApexRMSNorm):
                 pass
         else:
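The surrounding code selects an RMSNorm backend by name and falls back when a fused kernel is unavailable. A hedged sketch of the same fallback idea (structure only; the repo dispatches on `rms_norm_func` rather than try/except):

```python
import torch

try:
    # fused CUDA kernel, if apex is installed
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    class RMSNorm(torch.nn.Module):
        """Pure-PyTorch fallback: scale by the reciprocal RMS of the last dim."""
        def __init__(self, hidden_size, eps=1e-5):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x):
            var = x.float().pow(2).mean(-1, keepdim=True)
            return (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * self.weight
```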
@@ -238,6 +237,7 @@ class AsymmetricAttention(nn.Module):
+                skip_reshape=True
             )
         return out

     def run_attention(
         self,
         q,
@@ -140,7 +140,7 @@ class PatchEmbed(nn.Module):
         x = self.norm(x)
         return x

-@torch.compiler.disable()
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5, device=None):
         super().__init__()
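Only the `__init__` signature of this RMSNorm is visible in the hunk; main additionally shields it from Dynamo with `@torch.compiler.disable()`. A sketch of what a complete definition plausibly looks like given that signature (the forward body is an assumption, not the repo's exact code):

```python
import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, device=device))

    @torch.compiler.disable()  # keep Dynamo from tracing the norm, as on main
    def forward(self, x):
        # compute in fp32 for stability, cast back to the input dtype
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * self.weight
```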
@@ -34,7 +34,6 @@ except:

 import torch
 import torch.utils.data
-import torch._dynamo

 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
@@ -162,8 +161,6 @@ class T2VSynthMochiModel:

         #torch.compile
         if compile_args is not None:
-            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
-            log.info(f"Set dynamo cache size limit to {torch._dynamo.config.cache_size_limit}")
             if compile_args["compile_dit"]:
                 for i, block in enumerate(model.blocks):
                     model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
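Note the compile strategy: each transformer block is compiled separately rather than wrapping the whole DiT, which keeps graphs small and limits recompilation to a single block. A sketch of the pattern (assumes the model exposes a `.blocks` ModuleList, as the hunk implies):

```python
import torch

def compile_dit_blocks(model: torch.nn.Module, compile_args: dict) -> None:
    # main also raises the dynamo cache limit first; fastercache omits this
    if "dynamo_cache_size_limit" in compile_args:
        torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
    if compile_args["compile_dit"]:
        for i, block in enumerate(model.blocks):
            model.blocks[i] = torch.compile(
                block,
                fullgraph=compile_args["fullgraph"],
                dynamic=compile_args["dynamic"],
                backend=compile_args["backend"],
            )
```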
@@ -311,7 +308,7 @@ class T2VSynthMochiModel:

         comfy_pbar = ProgressBar(sample_steps)

-        if (hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul) or self.device.type == "mps":
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
             autocast_dtype = torch.float16
         else:
             autocast_dtype = torch.bfloat16
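The change above is main widening the fp16 path to Apple MPS, presumably because bf16 autocast is poorly supported there. The dtype selection in isolation:

```python
import torch

def pick_autocast_dtype(dit: torch.nn.Module, device: torch.device) -> torch.dtype:
    # fp16 when the DiT was patched for cuBLAS half-precision matmuls,
    # or (main branch only) whenever running on Apple MPS; bf16 otherwise.
    if getattr(dit, "cublas_half_matmul", False) or device.type == "mps":
        return torch.float16
    return torch.bfloat16
```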
nodes.py (78 changed lines)
@@ -119,10 +119,6 @@ class DownloadAndLoadMochiModel:
         mm.soft_empty_cache()

         dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
-        if "fp8" in precision:
-            vae_dtype = torch.bfloat16
-        else:
-            vae_dtype = dtype

         # Transformer model
         model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
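main derives a separate VAE dtype here because fp8 weights are only used for the transformer; the VAE still runs in bf16 (fastercache simply hardcodes bf16 further down). The mapping as a standalone sketch:

```python
import torch

def resolve_dtypes(precision: str) -> tuple[torch.dtype, torch.dtype]:
    dtype = {
        "fp8_e4m3fn": torch.float8_e4m3fn,
        "fp8_e4m3fn_fast": torch.float8_e4m3fn,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[precision]
    # fp8 applies to the transformer only; run the VAE in bf16 instead
    vae_dtype = torch.bfloat16 if "fp8" in precision else dtype
    return dtype, vae_dtype
```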
@@ -178,15 +174,14 @@ class DownloadAndLoadMochiModel:
             nonlinearity="silu",
             output_nonlinearity="silu",
             causal=True,
-            dtype=vae_dtype,
         )
         vae_sd = load_torch_file(vae_path)
         if is_accelerate_available:
             for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=vae_dtype, device=offload_device, value=vae_sd[key])
+                set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
         else:
             vae.load_state_dict(vae_sd, strict=True)
-            vae.eval().to(vae_dtype).to("cpu")
+            vae.eval().to(torch.bfloat16).to("cpu")
         del vae_sd

         return (model, vae,)
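Both branches prefer accelerate's `set_module_tensor_to_device` when available, assigning tensors key by key instead of `load_state_dict`; each weight is cast and placed directly, avoiding a second full-precision copy of the model. A generic sketch (requires accelerate; module and state-dict arguments are illustrative):

```python
import torch
from accelerate.utils import set_module_tensor_to_device

def load_casted(module: torch.nn.Module, sd: dict,
                dtype: torch.dtype, device) -> None:
    for key in sd:
        # cast + place each tensor as it is assigned into the module
        set_module_tensor_to_device(module, key, dtype=dtype,
                                    device=device, value=sd[key])
```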
@@ -247,7 +242,6 @@ class MochiTorchCompileSettings:
                 "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
                 "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
                 "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
-                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
             },
         }
     RETURN_TYPES = ("MOCHICOMPILEARGS",)
@@ -256,7 +250,7 @@ class MochiTorchCompileSettings:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"

-    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic, dynamo_cache_size_limit):
+    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic):

         compile_args = {
             "backend": backend,
@@ -265,7 +259,6 @@ class MochiTorchCompileSettings:
             "compile_dit": compile_dit,
             "compile_final_layer": compile_final_layer,
             "dynamic": dynamic,
-            "dynamo_cache_size_limit": dynamo_cache_size_limit,
         }

         return (compile_args, )
@@ -315,16 +308,6 @@ class MochiVAELoader:
             dtype=dtype,
         )
         vae_sd = load_torch_file(vae_path)
-
-        #support loading from combined VAE
-        if vae_sd.get("decoder.blocks.0.0.bias") is not None:
-            new_vae_sd = {}
-            for k, v in vae_sd.items():
-                if k.startswith("decoder."):
-                    new_k = k[len("decoder."):]
-                    new_vae_sd[new_k] = v
-            vae_sd = new_vae_sd
-
         if is_accelerate_available:
             for name, param in vae.named_parameters():
                 set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
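main's combined-VAE support probes for a decoder key and, if found, strips the "decoder." prefix so a full checkpoint can feed the decoder-only loader (the next hunk does the same for "encoder."). Generalized:

```python
def split_out_submodel(sd: dict, prefix: str) -> dict:
    """Return the sub-state-dict under `prefix`, or `sd` unchanged if the
    checkpoint is already standalone. Mirrors the removed hunk above."""
    if not any(k.startswith(prefix) for k in sd):
        return sd
    return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}

# usage: vae_sd = split_out_submodel(vae_sd, "decoder.")
```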
@@ -393,16 +376,6 @@ class MochiVAEEncoderLoader:
         )

         encoder_sd = load_torch_file(vae_path)
-
-        #support loading from combined VAE
-        if encoder_sd.get("encoder.layers.0.bias") is not None:
-            new_vae_sd = {}
-            for k, v in encoder_sd.items():
-                if k.startswith("encoder."):
-                    new_k = k[len("encoder."):]
-                    new_vae_sd[new_k] = v
-            encoder_sd = new_vae_sd
-
         if is_accelerate_available:
             for name, param in encoder.named_parameters():
                 set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
@@ -598,9 +571,6 @@ class MochiDecode:
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
-            "optional": {
-                "unnormalize": ("BOOLEAN", {"default": False, "tooltip": "Unnormalize the latents before decoding"}),
-            },
         }

     RETURN_TYPES = ("IMAGE",)
@@ -609,13 +579,12 @@ class MochiDecode:
     CATEGORY = "MochiWrapper"

     def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
-               tile_overlap_factor_width, auto_tile_size, frame_batch_size, unnormalize=False):
+               tile_overlap_factor_width, auto_tile_size, frame_batch_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        if unnormalize:
-            samples = dit_latents_to_vae_latents(samples)
+        samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape
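On main, unnormalization is a toggle; fastercache always maps DiT latents back into VAE space before decoding. The repo's `dit_latents_to_vae_latents` is not shown in this diff; a plausible shape for such a transform, assuming the usual per-channel statistics (the constants here are placeholders, not the repo's values):

```python
import torch

def dit_latents_to_vae_latents_sketch(z: torch.Tensor,
                                      mean: torch.Tensor,
                                      std: torch.Tensor) -> torch.Tensor:
    # z: (B, C, T, H, W). Undo the per-channel normalization the DiT was
    # trained with, so the VAE sees latents at its own scale.
    return z * std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1)
```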
@@ -730,9 +699,6 @@ class MochiDecodeSpatialTiling:
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
                 "per_batch": ("INT", {"default": 6, "min": 1, "max": 256, "step": 1, "tooltip": "Number of samples per batch, in latent space (6 frames in 1 latent)"}),
             },
-            "optional": {
-                "unnormalize": ("BOOLEAN", {"default": True, "tooltip": "Unnormalize the latents before decoding"}),
-            },
         }

     RETURN_TYPES = ("IMAGE",)
@@ -741,13 +707,12 @@ class MochiDecodeSpatialTiling:
     CATEGORY = "MochiWrapper"

     def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
-               min_block_size, per_batch, unnormalize=True):
+               min_block_size, per_batch):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        if unnormalize:
-            samples = dit_latents_to_vae_latents(samples)
+        samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape
@@ -803,17 +768,14 @@ class MochiImageEncode:
                 "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
             },
-            "optional": {
-                "normalize": ("BOOLEAN", {"default": True, "tooltip": "Normalize the images before encoding"}),
-            },
         }

     RETURN_TYPES = ("LATENT",)
     RETURN_NAMES = ("samples",)
-    FUNCTION = "encode"
+    FUNCTION = "decode"
     CATEGORY = "MochiWrapper"

-    def encode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size, normalize=True):
+    def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
@@ -838,8 +800,7 @@ class MochiImageEncode:
             latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
         else:
             latents = encoder(video)
-        if normalize:
-            latents = vae_latents_to_dit_latents(latents)
+        latents = vae_latents_to_dit_latents(latents)
         print("encoder output",latents.shape)

         return ({"samples": latents},)
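`apply_tiled` (repo-specific, not shown in this diff) encodes the video in spatial tiles to bound VRAM use. A deliberately rough sketch of the idea, without the overlap blending the real helper performs:

```python
import torch

def apply_tiled_sketch(fn, video: torch.Tensor,
                       num_tiles_w: int, num_tiles_h: int) -> torch.Tensor:
    # video: (B, C, T, H, W). Split H and W into a grid, run fn per tile,
    # then stitch the results back together. No overlap handling here,
    # unlike the repo's apply_tiled.
    rows = []
    for row in torch.chunk(video, num_tiles_h, dim=3):
        cols = [fn(tile) for tile in torch.chunk(row, num_tiles_w, dim=4)]
        rows.append(torch.cat(cols, dim=4))
    return torch.cat(rows, dim=3)
```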
@@ -870,21 +831,8 @@ class MochiLatentPreview:
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()

-        #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
-        latent_rgb_factors =[
-            [-0.0069, -0.0045, 0.0018],
-            [ 0.0154, -0.0692, -0.0274],
-            [ 0.0333, 0.0019, 0.0206],
-            [-0.1390, 0.0628, 0.1678],
-            [-0.0725, 0.0134, -0.1898],
-            [ 0.0074, -0.0270, -0.0209],
-            [-0.0176, -0.0277, -0.0221],
-            [ 0.5294, 0.5204, 0.3852],
-            [-0.0326, -0.0446, -0.0143],
-            [-0.0659, 0.0153, -0.0153],
-            [ 0.0185, -0.0217, 0.0014],
-            [-0.0396, -0.0495, -0.0281]
-        ]
+        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]

         # import random
         # random.seed(seed)
         # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
@@ -892,7 +840,7 @@ class MochiLatentPreview:
         # print(latent_rgb_factors)


-        latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
+        latent_rgb_factors_bias = [0,0,0]

         latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
         latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)