diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc index b7e593d..1be05b3 100644 Binary files a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc and b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc differ diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 7ceb8d9..15fb287 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -107,6 +107,7 @@ class T2VSynthMochiModel: device_id: int, vae_stats_path: str, dit_checkpoint_path: str, + weight_dtype: torch.dtype = torch.float8_e4m3fn, ): super().__init__() t = Timer() @@ -146,7 +147,7 @@ class T2VSynthMochiModel: for name, param in self.dit.named_parameters(): params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} if not any(keyword in name for keyword in params_to_keep): - param.data = param.data.to(torch.float8_e4m3fn) + param.data = param.data.to(weight_dtype) else: param.data = param.data.to(torch.bfloat16) diff --git a/nodes.py b/nodes.py index ac77a53..04ae75f 100644 --- a/nodes.py +++ b/nodes.py @@ -47,7 +47,7 @@ class DownloadAndLoadMochiModel: ], ), "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"], - {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} + {"default": "fp8_e4m3fn", } ), }, } @@ -99,6 +99,7 @@ class DownloadAndLoadMochiModel: device_id=0, vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"), dit_checkpoint_path=model_path, + weight_dtype=dtype, ) vae = Decoder( out_channels=3,