Update t2v_synth_mochi.py
This commit is contained in:
parent
4b9d43d318
commit
57640ab0f8
@ -151,14 +151,17 @@ class T2VSynthMochiModel:
|
||||
)
|
||||
with t("dit_load_checkpoint"):
|
||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||
dit_sd = load_torch_file(dit_checkpoint_path)
|
||||
if is_accelerate_available:
|
||||
print("Using accelerate to load and assign model weights to device...")
|
||||
for name, param in model.named_parameters():
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
|
||||
else:
|
||||
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
|
||||
else:
|
||||
print("Loading state_dict without accelerate...")
|
||||
model.load_state_dict(dit_sd)
|
||||
for name, param in self.dit.named_parameters():
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user