From 70ad32621b86c719261ce8225b209e7c3bf4033b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 5 Nov 2024 20:21:47 +0200 Subject: [PATCH] reorder --- .../dit/joint_model/asymm_models_joint.py | 3 ++- mochi_preview/t2v_synth_mochi.py | 23 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index fde6041..5e211b4 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -639,12 +639,12 @@ class AsymmDiTJoint(nn.Module): self, x: torch.Tensor, sigma: torch.Tensor, - fastercache_counter: int, y_feat: List[torch.Tensor], y_mask: List[torch.Tensor], rope_cos: torch.Tensor = None, rope_sin: torch.Tensor = None, fastercache: Optional[Dict[str, torch.Tensor]] = None, + fastercache_counter: Optional[int]=0, ): """Forward pass of DiT. @@ -671,6 +671,7 @@ class AsymmDiTJoint(nn.Module): else: fastercache_start_step = 1000 fastercache_device = None + #print(fastercache_counter) for i, block in enumerate(self.blocks): x, y_feat = block( diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index be7a5db..b560f6c 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -231,8 +231,9 @@ class T2VSynthMochiModel: "y_feat": [args["negative_embeds"]["embeds"].to(self.device)], "fastercache": args["fastercache"] if args["fastercache"] is not None else None } - + print(args["fastercache"]) if args["fastercache"]: + print("Using fastercache") self.fastercache_start_step = args["fastercache"]["start_step"] self.fastercache_lf_step = args["fastercache"]["lf_step"] self.fastercache_hf_step = args["fastercache"]["hf_step"] @@ -249,8 +250,8 @@ class T2VSynthMochiModel: out_cond = self.dit( z, sigma, - self.fastercache_counter, - **sample) + **sample, + fastercache_counter=self.fastercache_counter) (bb, cc, tt, hh, ww) = out_cond.shape cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww) @@ -273,14 +274,14 @@ class T2VSynthMochiModel: out_cond = self.dit( z, sigma, - self.fastercache_counter, - **sample) + **sample, + fastercache_counter=self.fastercache_counter) out_uncond = self.dit( z, - sigma, - self.fastercache_counter, - **sample_null) + sigma, + **sample_null, + fastercache_counter=self.fastercache_counter) if self.fastercache_counter >= self.fastercache_start_step + 1: (bb, cc, tt, hh, ww) = out_cond.shape @@ -297,9 +298,9 @@ class T2VSynthMochiModel: else: #handle cfg 1.0 out_cond = self.dit( z, - sigma, - self.fastercache_counter, - **sample) + sigma, + **sample, + fastercache_counter=self.fastercache_counter) return out_cond