diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 5cd8f01..6507700 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -23,6 +23,8 @@ except: is_accelerate_available = False pass +from .dit.joint_model.asymm_models_joint import AsymmDiTJoint + MAX_T5_TOKEN_LENGTH = 256 def unnormalize_latents( @@ -125,9 +127,7 @@ class T2VSynthMochiModel: self.offload_device = offload_device with t("construct_dit"): - from .dit.joint_model.asymm_models_joint import ( - AsymmDiTJoint, - ) + model: nn.Module = torch.nn.utils.skip_init( AsymmDiTJoint, depth=48,