Compare commits
No commits in common. "main" and "fastercache" have entirely different histories.
main...fastercache
@@ -31,28 +31,17 @@ class LatentPreviewer:
 class Latent2RGBPreviewer(LatentPreviewer):
     def __init__(self):
         #latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], [0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]]
-        #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
-        latent_rgb_factors =[
-            [-0.0069, -0.0045, 0.0018],
-            [ 0.0154, -0.0692, -0.0274],
-            [ 0.0333, 0.0019, 0.0206],
-            [-0.1390, 0.0628, 0.1678],
-            [-0.0725, 0.0134, -0.1898],
-            [ 0.0074, -0.0270, -0.0209],
-            [-0.0176, -0.0277, -0.0221],
-            [ 0.5294, 0.5204, 0.3852],
-            [-0.0326, -0.0446, -0.0143],
-            [-0.0659, 0.0153, -0.0153],
-            [ 0.0185, -0.0217, 0.0014],
-            [-0.0396, -0.0495, -0.0281]
-        ]
+        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
         self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
-        self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
+        self.latent_rgb_factors_bias = None
+        # if latent_rgb_factors_bias is not None:
+        #     self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

     def decode_latent_to_preview(self, x0):
         self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
         if self.latent_rgb_factors_bias is not None:
-            self.latent_rgb_factors_bias = torch.tensor(self.latent_rgb_factors_bias, device="cpu").to(dtype=x0.dtype, device=x0.device)
+            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

         latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
                                                   bias=self.latent_rgb_factors_bias)
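The preview path above is a single linear projection from the latent channels to RGB. A minimal, self-contained sketch of that projection follows; the random input tensor, channel count, and factor values here are placeholders, not the repository's:

    # Sketch of the latent-to-RGB preview projection used in the hunk above.
    # Shapes and values are illustrative; the real factors come from the
    # latent_rgb_factors table in the diff.
    import torch

    num_latent_channels = 12                         # the factor table above has 12 rows
    factors = torch.randn(num_latent_channels, 3)    # placeholder for latent_rgb_factors
    bias = torch.zeros(3)                            # placeholder for latent_rgb_factors_bias

    x0 = torch.randn(1, num_latent_channels, 60, 106)  # (batch, C, H, W) latent

    # Move channels last, then project C -> 3 with one linear map:
    # (H, W, C) @ (C, 3) + bias  ->  (H, W, 3) RGB preview.
    rgb = torch.nn.functional.linear(x0[0].permute(1, 2, 0), factors.transpose(0, 1), bias=bias)
    print(rgb.shape)  # torch.Size([60, 106, 3])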
@@ -165,7 +165,6 @@ class AsymmetricAttention(nn.Module):
         raise ImportError("Flash RMSNorm not available.")
     elif rms_norm_func == "apex":
         from apex.normalization import FusedRMSNorm as ApexRMSNorm
-        @torch.compiler.disable()
         class RMSNorm(ApexRMSNorm):
             pass
     else:
@@ -238,6 +237,7 @@ class AsymmetricAttention(nn.Module):
             skip_reshape=True
         )
         return out

     def run_attention(
         self,
         q,
@@ -140,7 +140,7 @@ class PatchEmbed(nn.Module):
         x = self.norm(x)
         return x

-@torch.compiler.disable()
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5, device=None):
         super().__init__()
@@ -34,7 +34,6 @@ except:

 import torch
 import torch.utils.data
-import torch._dynamo

 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
@@ -162,8 +161,6 @@ class T2VSynthMochiModel:

         #torch.compile
         if compile_args is not None:
-            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
-            log.info(f"Set dynamo cache size limit to {torch._dynamo.config.cache_size_limit}")
             if compile_args["compile_dit"]:
                 for i, block in enumerate(model.blocks):
                     model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
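For reference, the per-block compilation pattern that this hunk trims (main also set the dynamo cache limit first; fastercache drops that step) looks roughly like the following. The toy ModuleList and the concrete compile_args values are illustrative only:

    # Per-block torch.compile, mirroring the dict built in MochiTorchCompileSettings.
    import torch
    import torch._dynamo

    compile_args = {"backend": "inductor", "fullgraph": False, "dynamic": False,
                    "dynamo_cache_size_limit": 64}  # last key exists only on main

    blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])  # toy stand-in

    # main raises the dynamo recompile cache limit before compiling each block
    torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]

    for i, block in enumerate(blocks):
        blocks[i] = torch.compile(block,
                                  fullgraph=compile_args["fullgraph"],
                                  dynamic=compile_args["dynamic"],
                                  backend=compile_args["backend"])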
@@ -311,7 +308,7 @@ class T2VSynthMochiModel:

         comfy_pbar = ProgressBar(sample_steps)

-        if (hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul) or self.device.type == "mps":
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
             autocast_dtype = torch.float16
         else:
             autocast_dtype = torch.bfloat16
nodes.py (72 changes)
@@ -119,10 +119,6 @@ class DownloadAndLoadMochiModel:
         mm.soft_empty_cache()

         dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
-        if "fp8" in precision:
-            vae_dtype = torch.bfloat16
-        else:
-            vae_dtype = dtype

         # Transformer model
         model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
@@ -178,15 +174,14 @@ class DownloadAndLoadMochiModel:
             nonlinearity="silu",
             output_nonlinearity="silu",
             causal=True,
-            dtype=vae_dtype,
         )
         vae_sd = load_torch_file(vae_path)
         if is_accelerate_available:
             for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=vae_dtype, device=offload_device, value=vae_sd[key])
+                set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
         else:
             vae.load_state_dict(vae_sd, strict=True)
-        vae.eval().to(vae_dtype).to("cpu")
+        vae.eval().to(torch.bfloat16).to("cpu")
         del vae_sd

         return (model, vae,)
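The accelerate path in this hunk materializes each checkpoint tensor straight onto the target dtype and device instead of calling load_state_dict. A small sketch under the assumption that the module was built with empty (meta) weights; the Linear module and fake state dict are stand-ins, not the repository's VAE:

    # Loading a state dict tensor-by-tensor with accelerate's helper.
    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():            # parameters allocated on the "meta" device
        vae = torch.nn.Linear(4, 4)

    vae_sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}  # stand-in state dict

    for key in vae_sd:
        # each tensor is cast and placed directly, avoiding a full-precision copy
        set_module_tensor_to_device(vae, key, dtype=torch.bfloat16,
                                    device="cpu", value=vae_sd[key])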
@@ -247,7 +242,6 @@ class MochiTorchCompileSettings:
                 "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
                 "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
                 "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
-                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
             },
         }
     RETURN_TYPES = ("MOCHICOMPILEARGS",)
@@ -256,7 +250,7 @@ class MochiTorchCompileSettings:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"

-    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic, dynamo_cache_size_limit):
+    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic):

         compile_args = {
             "backend": backend,
@@ -265,7 +259,6 @@ class MochiTorchCompileSettings:
             "compile_dit": compile_dit,
             "compile_final_layer": compile_final_layer,
             "dynamic": dynamic,
-            "dynamo_cache_size_limit": dynamo_cache_size_limit,
         }

         return (compile_args, )
@@ -315,16 +308,6 @@ class MochiVAELoader:
             dtype=dtype,
         )
         vae_sd = load_torch_file(vae_path)

-        #support loading from combined VAE
-        if vae_sd.get("decoder.blocks.0.0.bias") is not None:
-            new_vae_sd = {}
-            for k, v in vae_sd.items():
-                if k.startswith("decoder."):
-                    new_k = k[len("decoder."):]
-                    new_vae_sd[new_k] = v
-            vae_sd = new_vae_sd

         if is_accelerate_available:
             for name, param in vae.named_parameters():
                 set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
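The block removed here remapped keys when loading from a combined VAE checkpoint: decoder entries stored under a "decoder." prefix were renamed to match the standalone decoder module. The same logic, condensed; the dictionary entries below are illustrative, not real weights:

    # Strip the "decoder." prefix so combined-checkpoint keys match the decoder module.
    vae_sd = {
        "decoder.blocks.0.0.bias": 0.0,
        "decoder.blocks.0.0.weight": 0.0,
        "encoder.layers.0.bias": 0.0,    # non-decoder keys are dropped
    }

    if vae_sd.get("decoder.blocks.0.0.bias") is not None:
        vae_sd = {k[len("decoder."):]: v for k, v in vae_sd.items() if k.startswith("decoder.")}

    print(list(vae_sd))  # ['blocks.0.0.bias', 'blocks.0.0.weight']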
@@ -393,16 +376,6 @@ class MochiVAEEncoderLoader:
         )

         encoder_sd = load_torch_file(vae_path)

-        #support loading from combined VAE
-        if encoder_sd.get("encoder.layers.0.bias") is not None:
-            new_vae_sd = {}
-            for k, v in encoder_sd.items():
-                if k.startswith("encoder."):
-                    new_k = k[len("encoder."):]
-                    new_vae_sd[new_k] = v
-            encoder_sd = new_vae_sd

         if is_accelerate_available:
             for name, param in encoder.named_parameters():
                 set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
@@ -598,9 +571,6 @@ class MochiDecode:
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
-            "optional": {
-                "unnormalize": ("BOOLEAN", {"default": False, "tooltip": "Unnormalize the latents before decoding"}),
-            },
         }

     RETURN_TYPES = ("IMAGE",)
@@ -609,12 +579,11 @@ class MochiDecode:
     CATEGORY = "MochiWrapper"

     def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
-               tile_overlap_factor_width, auto_tile_size, frame_batch_size, unnormalize=False):
+               tile_overlap_factor_width, auto_tile_size, frame_batch_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        if unnormalize:
         samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)

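On fastercache the decode path always converts DiT latents to VAE latents, where main gated the conversion behind the removed unnormalize toggle. dit_latents_to_vae_latents itself is not shown in this diff; helpers of this kind are commonly a per-channel affine transform, sketched here with made-up statistics purely for illustration:

    # Hypothetical shape of a DiT-to-VAE latent "unnormalization" helper.
    import torch

    def dit_latents_to_vae_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
        # hypothetical per-channel mean/std; the real values live in the repo
        mean = torch.zeros(1, latents.shape[1], 1, 1, 1)
        std = torch.ones(1, latents.shape[1], 1, 1, 1)
        return latents * std + mean

    samples = torch.randn(1, 12, 4, 60, 106)  # (B, C, T, H, W) video latent, illustrative
    samples = dit_latents_to_vae_latents_sketch(samples)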
@@ -730,9 +699,6 @@ class MochiDecodeSpatialTiling:
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
                 "per_batch": ("INT", {"default": 6, "min": 1, "max": 256, "step": 1, "tooltip": "Number of samples per batch, in latent space (6 frames in 1 latent)"}),
             },
-            "optional": {
-                "unnormalize": ("BOOLEAN", {"default": True, "tooltip": "Unnormalize the latents before decoding"}),
-            },
         }

     RETURN_TYPES = ("IMAGE",)
@@ -741,12 +707,11 @@ class MochiDecodeSpatialTiling:
     CATEGORY = "MochiWrapper"

     def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
-               min_block_size, per_batch, unnormalize=True):
+               min_block_size, per_batch):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        if unnormalize:
         samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)

@@ -803,17 +768,14 @@ class MochiImageEncode:
                 "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
             },
-            "optional": {
-                "normalize": ("BOOLEAN", {"default": True, "tooltip": "Normalize the images before encoding"}),
-            },
         }

     RETURN_TYPES = ("LATENT",)
     RETURN_NAMES = ("samples",)
-    FUNCTION = "encode"
+    FUNCTION = "decode"
     CATEGORY = "MochiWrapper"

-    def encode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size, normalize=True):
+    def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
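The FUNCTION attribute in a ComfyUI node names the method the executor calls, so renaming it (as this hunk does, from "encode" to "decode") only works because the method is renamed to match. A toy skeleton showing the pairing; the class and inputs are stand-ins, not part of the repository:

    # Minimal ComfyUI-style node: FUNCTION must name an existing method on the class.
    class ToyEncodeNode:
        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"images": ("IMAGE",)}}

        RETURN_TYPES = ("LATENT",)
        RETURN_NAMES = ("samples",)
        FUNCTION = "encode"          # ComfyUI looks the method up by this name
        CATEGORY = "MochiWrapper"

        def encode(self, images):
            latents = images         # placeholder transform
            return ({"samples": latents},)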
@@ -838,7 +800,6 @@ class MochiImageEncode:
             latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
         else:
             latents = encoder(video)
-        if normalize:
         latents = vae_latents_to_dit_latents(latents)
         print("encoder output",latents.shape)

@@ -870,21 +831,8 @@ class MochiLatentPreview:
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()

-        #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
-        latent_rgb_factors =[
-            [-0.0069, -0.0045, 0.0018],
-            [ 0.0154, -0.0692, -0.0274],
-            [ 0.0333, 0.0019, 0.0206],
-            [-0.1390, 0.0628, 0.1678],
-            [-0.0725, 0.0134, -0.1898],
-            [ 0.0074, -0.0270, -0.0209],
-            [-0.0176, -0.0277, -0.0221],
-            [ 0.5294, 0.5204, 0.3852],
-            [-0.0326, -0.0446, -0.0143],
-            [-0.0659, 0.0153, -0.0153],
-            [ 0.0185, -0.0217, 0.0014],
-            [-0.0396, -0.0495, -0.0281]
-        ]
+        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
         # import random
         # random.seed(seed)
         # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
@@ -892,7 +840,7 @@ class MochiLatentPreview:
         # print(latent_rgb_factors)


-        latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
+        latent_rgb_factors_bias = [0,0,0]

         latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
         latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)