fix dtype selection
This commit is contained in:
parent
256f552526
commit
34e029bacc
Binary file not shown.
@ -107,6 +107,7 @@ class T2VSynthMochiModel:
|
|||||||
device_id: int,
|
device_id: int,
|
||||||
vae_stats_path: str,
|
vae_stats_path: str,
|
||||||
dit_checkpoint_path: str,
|
dit_checkpoint_path: str,
|
||||||
|
weight_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
t = Timer()
|
t = Timer()
|
||||||
@ -146,7 +147,7 @@ class T2VSynthMochiModel:
|
|||||||
for name, param in self.dit.named_parameters():
|
for name, param in self.dit.named_parameters():
|
||||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
param.data = param.data.to(torch.float8_e4m3fn)
|
param.data = param.data.to(weight_dtype)
|
||||||
else:
|
else:
|
||||||
param.data = param.data.to(torch.bfloat16)
|
param.data = param.data.to(torch.bfloat16)
|
||||||
|
|
||||||
|
|||||||
3
nodes.py
3
nodes.py
@ -47,7 +47,7 @@ class DownloadAndLoadMochiModel:
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
|
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
|
||||||
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
|
{"default": "fp8_e4m3fn", }
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -99,6 +99,7 @@ class DownloadAndLoadMochiModel:
|
|||||||
device_id=0,
|
device_id=0,
|
||||||
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
||||||
dit_checkpoint_path=model_path,
|
dit_checkpoint_path=model_path,
|
||||||
|
weight_dtype=dtype,
|
||||||
)
|
)
|
||||||
vae = Decoder(
|
vae = Decoder(
|
||||||
out_channels=3,
|
out_channels=3,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user