From 83097a6b6360fc5aab77d059783f6abd1d061893 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 24 Oct 2024 00:40:52 +0300 Subject: [PATCH] Update t2v_synth_mochi.py --- mochi_preview/t2v_synth_mochi.py | 40 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 2550991..2d7c318 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -125,26 +125,26 @@ class T2VSynthMochiModel: self.offload_device = offload_device print("Initializing model...") - model: nn.Module = torch.nn.utils.skip_init( - AsymmDiTJoint, - depth=48, - patch_size=2, - num_heads=24, - hidden_size_x=3072, - hidden_size_y=1536, - mlp_ratio_x=4.0, - mlp_ratio_y=4.0, - in_channels=12, - qk_norm=True, - qkv_bias=False, - out_bias=True, - patch_embed_bias=True, - timestep_mlp_bias=True, - timestep_scale=1000.0, - t5_feat_dim=4096, - t5_token_length=256, - rope_theta=10000.0, - ) + with (init_empty_weights() if is_accelerate_available else nullcontext()): + model: nn.Module = AsymmDiTJoint( + depth=48, + patch_size=2, + num_heads=24, + hidden_size_x=3072, + hidden_size_y=1536, + mlp_ratio_x=4.0, + mlp_ratio_y=4.0, + in_channels=12, + qk_norm=True, + qkv_bias=False, + out_bias=True, + patch_embed_bias=True, + timestep_mlp_bias=True, + timestep_scale=1000.0, + t5_feat_dim=4096, + t5_token_length=256, + rope_theta=10000.0, + ) params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} print(f"Loading model state_dict from {dit_checkpoint_path}...")