diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index bfa1b05..b019fe6 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -151,14 +151,17 @@ class T2VSynthMochiModel: ) with t("dit_load_checkpoint"): params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} + print(f"Loading model state_dict from {dit_checkpoint_path}...") dit_sd = load_torch_file(dit_checkpoint_path) if is_accelerate_available: + print("Using accelerate to load and assign model weights to device...") for name, param in model.named_parameters(): if not any(keyword in name for keyword in params_to_keep): set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name]) else: set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name]) else: + print("Loading state_dict without accelerate...") model.load_state_dict(dit_sd) for name, param in self.dit.named_parameters(): if not any(keyword in name for keyword in params_to_keep):