fix dtype selection

This commit is contained in:
kijai 2024-10-23 15:49:37 +03:00
parent 256f552526
commit 34e029bacc
3 changed files with 4 additions and 2 deletions

View File

@ -107,6 +107,7 @@ class T2VSynthMochiModel:
device_id: int,
vae_stats_path: str,
dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
):
super().__init__()
t = Timer()
@ -146,7 +147,7 @@ class T2VSynthMochiModel:
for name, param in self.dit.named_parameters():
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
param.data = param.data.to(weight_dtype)
else:
param.data = param.data.to(torch.bfloat16)

View File

@ -47,7 +47,7 @@ class DownloadAndLoadMochiModel:
],
),
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
{"default": "fp8_e4m3fn", }
),
},
}
@ -99,6 +99,7 @@ class DownloadAndLoadMochiModel:
device_id=0,
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
dit_checkpoint_path=model_path,
weight_dtype=dtype,
)
vae = Decoder(
out_channels=3,