fix dtype selection
This commit is contained in:
parent
256f552526
commit
34e029bacc
Binary file not shown.
@ -107,6 +107,7 @@ class T2VSynthMochiModel:
|
||||
device_id: int,
|
||||
vae_stats_path: str,
|
||||
dit_checkpoint_path: str,
|
||||
weight_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||
):
|
||||
super().__init__()
|
||||
t = Timer()
|
||||
@ -146,7 +147,7 @@ class T2VSynthMochiModel:
|
||||
for name, param in self.dit.named_parameters():
|
||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
param.data = param.data.to(torch.float8_e4m3fn)
|
||||
param.data = param.data.to(weight_dtype)
|
||||
else:
|
||||
param.data = param.data.to(torch.bfloat16)
|
||||
|
||||
|
||||
3
nodes.py
3
nodes.py
@ -47,7 +47,7 @@ class DownloadAndLoadMochiModel:
|
||||
],
|
||||
),
|
||||
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
|
||||
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
|
||||
{"default": "fp8_e4m3fn", }
|
||||
),
|
||||
},
|
||||
}
|
||||
@ -99,6 +99,7 @@ class DownloadAndLoadMochiModel:
|
||||
device_id=0,
|
||||
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
||||
dit_checkpoint_path=model_path,
|
||||
weight_dtype=dtype,
|
||||
)
|
||||
vae = Decoder(
|
||||
out_channels=3,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user