From a6e545531c28244d792b6f15ab00fa74af25a6cd Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 3 Nov 2024 00:32:51 +0200
Subject: [PATCH] test

---
 .../dit/joint_model/asymm_models_joint.py |  52 ++++++--
 mochi_preview/t2v_synth_mochi.py          | 118 +++++++++++++-----
 2 files changed, 129 insertions(+), 41 deletions(-)

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index f4fdf6a..a1621e7 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -389,6 +389,9 @@ class AsymmetricJointBlock(nn.Module):
             self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
         else:
             self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
+
+        self.cached_x_attention = [None, None]
+        self.cached_y_attention = [None, None]
 
         # Self-attention:
         self.attn = AsymmetricAttention(
@@ -428,6 +431,8 @@ class AsymmetricJointBlock(nn.Module):
         x: torch.Tensor,
         c: torch.Tensor,
         y: torch.Tensor,
+        fastercache_counter: Optional[int] = 0,
+        fastercache_start_step: Optional[int] = 15,
         **attn_kwargs,
     ):
         """Forward pass of a block.
@@ -453,15 +458,36 @@ class AsymmetricJointBlock(nn.Module):
             scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
         else:
             scale_msa_y = mod_y
+
+        # FasterCache: extrapolate from cached attention outputs on skipped steps.
+        B = x.shape[0]
+        # x: (B, 9540, 3072) at the default resolution.
+        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter % 3 != 0 and self.cached_x_attention[-1].shape[0] >= B:
+            x_attn = (
+                self.cached_x_attention[1][:B]
+                + (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
+                * 0.3
+            ).to(x.device, non_blocking=True)
+            y_attn = (
+                self.cached_y_attention[1][:B]
+                + (self.cached_y_attention[1][:B] - self.cached_y_attention[0][:B])
+                * 0.3
+            ).to(x.device, non_blocking=True)
+        else:
+            # Self-attention block.
+            x_attn, y_attn = self.attn(
+                x,
+                y,
+                scale_x=scale_msa_x,
+                scale_y=scale_msa_y,
+                **attn_kwargs,
+            )
+            if fastercache_counter == fastercache_start_step:
+                self.cached_x_attention = [x_attn, x_attn]
+                self.cached_y_attention = [y_attn, y_attn]
+            elif fastercache_counter > fastercache_start_step:
+                self.cached_x_attention[-1].copy_(x_attn)
+                self.cached_y_attention[-1].copy_(y_attn)
 
         assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
@@ -674,6 +700,8 @@ class AsymmDiTJoint(nn.Module):
         self,
         x: torch.Tensor,
         sigma: torch.Tensor,
+        fastercache_counter: int,
+        fastercache_start_step: int,
         y_feat: List[torch.Tensor],
         y_mask: List[torch.Tensor],
         packed_indices: Dict[str, torch.Tensor] = None,
@@ -707,7 +735,9 @@ class AsymmDiTJoint(nn.Module):
                 rope_cos=rope_cos,
                 rope_sin=rope_sin,
                 packed_indices=packed_indices,
-            )  # (B, M, D), (B, L, D)
+                fastercache_counter=fastercache_counter,
+                fastercache_start_step=fastercache_start_step,
+            )  # (B, M, D), (B, L, D)
         del y_feat  # Final layers don't use dense text features.
 
         x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
@@ -720,6 +750,6 @@ class AsymmDiTJoint(nn.Module):
             p1=self.patch_size,
             p2=self.patch_size,
             c=self.out_channels,
-        )
+        )
 
         return x
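The block change above keeps the attention outputs of the two most recent fully computed steps and, once warmed up (three steps past fastercache_start_step), reuses them on two of every three steps with a 0.3-weighted first-order extrapolation. One caveat: `[x_attn, x_attn]` seeds both cache slots with the same tensor object, and `copy_` then writes into that shared storage, so the two slots always alias each other and the `(new - old)` term is zero; as written, the cache degenerates to plain reuse of the last computed attention. Below is a minimal standalone sketch of the intended two-slot scheme with distinct snapshots; `AttnCache` is a hypothetical helper, not part of the patch, and shapes are illustrative:

    import torch

    class AttnCache:
        # Two-slot cache: reuse and extrapolate attention on skipped steps.
        def __init__(self, start_step: int = 15, momentum: float = 0.3):
            self.start_step = start_step
            self.momentum = momentum
            self.slots = [None, None]  # [older, newer] attention outputs

        def should_skip(self, counter: int, batch: int) -> bool:
            # Mirrors the gate in the patch: warmed up, 2 of every 3 steps,
            # and the cache must cover the current batch size.
            newest = self.slots[-1]
            return (
                counter >= self.start_step + 3
                and counter % 3 != 0
                and newest is not None
                and newest.shape[0] >= batch
            )

        def extrapolate(self, batch: int) -> torch.Tensor:
            # First-order extrapolation from the last two computed outputs.
            old, new = self.slots[0][:batch], self.slots[1][:batch]
            return new + (new - old) * self.momentum

        def update(self, counter: int, attn_out: torch.Tensor) -> None:
            if counter == self.start_step:
                self.slots = [attn_out.clone(), attn_out]  # distinct snapshots
            elif counter > self.start_step:
                self.slots = [self.slots[1], attn_out]     # shift the window

In the block itself, x_attn and y_attn would each use their own cache instance, matching the separate cached_x_attention and cached_y_attention lists.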
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index a1c6fb7..e713141 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -59,6 +59,22 @@ log = logging.getLogger(__name__)
 
 MAX_T5_TOKEN_LENGTH = 256
 
+def fft(tensor):
+    tensor_fft = torch.fft.fft2(tensor)
+    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
+    B, C, H, W = tensor.size()
+    radius = min(H, W) // 5
+
+    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
+    center_x, center_y = W // 2, H // 2
+    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
+    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
+    high_freq_mask = ~low_freq_mask
+
+    low_freq_fft = tensor_fft_shifted * low_freq_mask
+    high_freq_fft = tensor_fft_shifted * high_freq_mask
+
+    return low_freq_fft, high_freq_fft
 
 def unnormalize_latents(
     z: torch.Tensor,
     mean: torch.Tensor,
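The fft helper works entirely in the shifted spectrum: a circular boolean mask of radius min(H, W) // 5 around the centered DC term selects the low band and its complement the high band, so low + high is exactly the full spectrum and the split is lossless. A quick standalone check of that property; split_frequencies below re-implements the same scheme so the snippet runs on its own and is not part of the patch:

    import torch

    def split_frequencies(tensor: torch.Tensor, radius_div: int = 5):
        # Same scheme as fft() above: circular low-pass mask in the shifted spectrum.
        spectrum = torch.fft.fftshift(torch.fft.fft2(tensor))
        _, _, h, w = tensor.shape
        y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
        mask = (x - w // 2) ** 2 + (y - h // 2) ** 2 <= (min(h, w) // radius_div) ** 2
        return spectrum * mask, spectrum * ~mask

    x = torch.randn(2, 12, 60, 106)  # (B*T, C, H, W), as rearranged in the patch
    low, high = split_frequencies(x)
    recovered = torch.fft.ifft2(torch.fft.ifftshift(low + high)).real
    print(torch.allclose(recovered, x, atol=1e-4))  # True: the bands are complementary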
@@ -285,42 +301,84 @@
         sample_null["packed_indices"] = self.get_packed_indices(
             sample_null["y_mask"], **latent_dims
         )
+        self.use_fastercache = True
+        self.fastercache_counter = 0
+        self.fastercache_start_step = 15
+        self.fastercache_lf_step = 40
+        self.fastercache_hf_step = 30
 
-        def model_fn(*, z, sigma, cfg_scale):
-            self.dit.to(self.device)
-            if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-                autocast_dtype = torch.float16
-            else:
-                autocast_dtype = torch.bfloat16
-
+        def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
-            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                if cfg_scale > 1.0:
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
-                else:
-                    out_cond = self.dit(z, sigma, **sample)
-                    return out_cond
+            if self.use_fastercache:
+                self.fastercache_counter += 1
+            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
+                # Skipped step: run only the conditional pass and reconstruct
+                # the unconditional output from the cached frequency deltas.
+                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample)
+
+                (bb, cc, tt, hh, ww) = out_cond.shape
+                cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                lf_c, hf_c = fft(cond.float())
+
+                if self.fastercache_counter <= self.fastercache_lf_step:
+                    self.delta_lf = self.delta_lf * 1.1
+                if self.fastercache_counter >= self.fastercache_hf_step:
+                    self.delta_hf = self.delta_hf * 1.1
 
-            return out_uncond + cfg_scale * (out_cond - out_uncond)
+                new_hf_uc = self.delta_hf + hf_c
+                new_lf_uc = self.delta_lf + lf_c
+
+                combine_uc = new_lf_uc + new_hf_uc
+                combined_fft = torch.fft.ifftshift(combine_uc)
+                recovered_uncond = torch.fft.ifft2(combined_fft).real
+                recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
+            else:
+                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample)
+                out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample_null)
+                # out_cond: (B, C, T, H, W), e.g. (1, 12, 3, 60, 106).
+
+                if self.fastercache_counter >= self.fastercache_start_step + 1:
+                    (bb, cc, tt, hh, ww) = out_cond.shape
+                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                    lf_c, hf_c = fft(cond)
+                    lf_uc, hf_uc = fft(uncond)
+
+                    self.delta_lf = lf_uc - lf_c
+                    self.delta_hf = hf_uc - hf_c
+
+                return out_uncond + cfg_scale * (out_cond - out_uncond)
+
         comfy_pbar = ProgressBar(sample_steps)
-        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
-            sigma = sigma_schedule[i]
-            dsigma = sigma - sigma_schedule[i + 1]
-
-            # `pred` estimates `z_0 - eps`.
-            pred = model_fn(
-                z=z,
-                sigma=torch.full([B], sigma, device=z.device),
-                cfg_scale=cfg_schedule[i],
-            )
-            pred = pred.to(z)
-            z = z + dsigma * pred
-            if callback is not None:
-                callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
-            else:
-                comfy_pbar.update(1)
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+            autocast_dtype = torch.float16
+        else:
+            autocast_dtype = torch.bfloat16
+
+        self.dit.to(self.device)
+
+        with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+            for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
+                sigma = sigma_schedule[i]
+                dsigma = sigma - sigma_schedule[i + 1]
+
+                # `pred` estimates `z_0 - eps`.
+                pred = model_fn(
+                    z=z,
+                    sigma=torch.full([B], sigma, device=z.device),
+                    cfg_scale=cfg_schedule[i],
+                )
+                pred = pred.to(z)
+                z = z + dsigma * pred
+                if callback is not None:
+                    callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
+                else:
+                    comfy_pbar.update(1)
         self.dit.to(self.offload_device)
 
         logging.info(f"samples shape: {z.shape}")
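In model_fn, full steps run both DiT passes and cache the per-band spectral deltas between the unconditional and conditional outputs; once fastercache_counter passes fastercache_start_step + 3, every step whose counter is not a multiple of 5 runs only the conditional pass and rebuilds the unconditional output as the conditional spectrum plus those deltas, inflating delta_lf by 1.1x per step up to fastercache_lf_step and delta_hf from fastercache_hf_step onward. A minimal sketch of the two paths under those assumptions; the helper names below are hypothetical, while the patch itself uses the fft helper and instance attributes:

    import torch

    def low_pass_mask(h: int, w: int, radius_div: int = 5) -> torch.Tensor:
        # Same circular mask as fft() above: True inside radius min(h, w) // radius_div.
        y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
        return (x - w // 2) ** 2 + (y - h // 2) ** 2 <= (min(h, w) // radius_div) ** 2

    def store_deltas(cond: torch.Tensor, uncond: torch.Tensor):
        # Full step: remember per-band spectral differences (uncond - cond).
        diff = torch.fft.fftshift(torch.fft.fft2(uncond)) - torch.fft.fftshift(torch.fft.fft2(cond))
        mask = low_pass_mask(*cond.shape[-2:])
        return diff * mask, diff * ~mask  # (delta_lf, delta_hf)

    def reconstruct_uncond(cond: torch.Tensor, delta_lf, delta_hf) -> torch.Tensor:
        # Skipped step: predicted uncond spectrum = cond spectrum + cached deltas.
        spec_c = torch.fft.fftshift(torch.fft.fft2(cond))
        mask = low_pass_mask(*cond.shape[-2:])
        combined = (spec_c * mask + delta_lf) + (spec_c * ~mask + delta_hf)
        return torch.fft.ifft2(torch.fft.ifftshift(combined)).real

Since the two bands recombine additively, the split only matters because delta_lf and delta_hf are inflated on different step schedules; with identical scaling the reconstruction would collapse to adding a single cached spectral delta. Note also that the sampling loop itself is unchanged Euler integration (z = z + dsigma * pred); FasterCache only changes what model_fn computes per step, plus the autocast/device management hoisted out of model_fn into the loop.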