diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index 5e211b4..a05be43 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -643,7 +643,7 @@ class AsymmDiTJoint(nn.Module): y_mask: List[torch.Tensor], rope_cos: torch.Tensor = None, rope_sin: torch.Tensor = None, - fastercache: Optional[Dict[str, torch.Tensor]] = None, + fastercache: Optional[Dict] = None, fastercache_counter: Optional[int]=0, ): """Forward pass of DiT. diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index b560f6c..3d650fe 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -224,12 +224,10 @@ class T2VSynthMochiModel: sample = { "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)], "y_feat": [args["positive_embeds"]["embeds"].to(self.device)], - "fastercache": args["fastercache"] if args["fastercache"] is not None else None } sample_null = { "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)], "y_feat": [args["negative_embeds"]["embeds"].to(self.device)], - "fastercache": args["fastercache"] if args["fastercache"] is not None else None } print(args["fastercache"]) if args["fastercache"]: @@ -251,6 +249,7 @@ class T2VSynthMochiModel: z, sigma, **sample, + fastercache = args["fastercache"], fastercache_counter=self.fastercache_counter) (bb, cc, tt, hh, ww) = out_cond.shape @@ -275,12 +274,14 @@ class T2VSynthMochiModel: z, sigma, **sample, + fastercache = args["fastercache"], fastercache_counter=self.fastercache_counter) out_uncond = self.dit( z, sigma, **sample_null, + fastercache = args["fastercache"], fastercache_counter=self.fastercache_counter) if self.fastercache_counter >= self.fastercache_start_step + 1: @@ -300,6 +301,7 @@ class T2VSynthMochiModel: z, sigma, **sample, + fastercache = args["fastercache"], fastercache_counter=self.fastercache_counter) return out_cond