Update t2v_synth_mochi.py

This commit is contained in:
kijai 2024-10-23 19:38:01 +03:00
parent 4b9d43d318
commit 57640ab0f8

View File

@ -151,14 +151,17 @@ class T2VSynthMochiModel:
)
with t("dit_load_checkpoint"):
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
print(f"Loading model state_dict from {dit_checkpoint_path}...")
dit_sd = load_torch_file(dit_checkpoint_path)
if is_accelerate_available:
print("Using accelerate to load and assign model weights to device...")
for name, param in model.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
else:
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
else:
print("Loading state_dict without accelerate...")
model.load_state_dict(dit_sd)
for name, param in self.dit.named_parameters():
if not any(keyword in name for keyword in params_to_keep):