diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index b019fe6..5cd8f01 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -2,7 +2,6 @@ import json import random from typing import Dict, List -from safetensors.torch import load_file import numpy as np import torch import torch.nn as nn @@ -163,7 +162,7 @@ class T2VSynthMochiModel: else: print("Loading state_dict without accelerate...") model.load_state_dict(dit_sd) - for name, param in self.dit.named_parameters(): + for name, param in model.named_parameters(): if not any(keyword in name for keyword in params_to_keep): param.data = param.data.to(weight_dtype) else: @@ -180,7 +179,7 @@ class T2VSynthMochiModel: self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device) self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device) - #t.print_stats() + t.print_stats() def get_conditioning(self, prompts, *, zero_last_n_prompts: int): B = len(prompts)