diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 2550991..2d7c318 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -125,26 +125,26 @@ class T2VSynthMochiModel: self.offload_device = offload_device print("Initializing model...") - model: nn.Module = torch.nn.utils.skip_init( - AsymmDiTJoint, - depth=48, - patch_size=2, - num_heads=24, - hidden_size_x=3072, - hidden_size_y=1536, - mlp_ratio_x=4.0, - mlp_ratio_y=4.0, - in_channels=12, - qk_norm=True, - qkv_bias=False, - out_bias=True, - patch_embed_bias=True, - timestep_mlp_bias=True, - timestep_scale=1000.0, - t5_feat_dim=4096, - t5_token_length=256, - rope_theta=10000.0, - ) + with (init_empty_weights() if is_accelerate_available else nullcontext()): + model: nn.Module = AsymmDiTJoint( + depth=48, + patch_size=2, + num_heads=24, + hidden_size_x=3072, + hidden_size_y=1536, + mlp_ratio_x=4.0, + mlp_ratio_y=4.0, + in_channels=12, + qk_norm=True, + qkv_bias=False, + out_bias=True, + patch_embed_bias=True, + timestep_mlp_bias=True, + timestep_scale=1000.0, + t5_feat_dim=4096, + t5_token_length=256, + rope_theta=10000.0, + ) params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} print(f"Loading model state_dict from {dit_checkpoint_path}...")