From 739138927692cc0dbca121af7729d1bb7a9a98a8 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 5 Nov 2024 12:29:53 +0200
Subject: [PATCH] cleanup

---
 .../dit/joint_model/asymm_models_joint.py |  10 +-
 mochi_preview/t2v_synth_mochi.py          | 132 ++++++++----------
 nodes.py                                  |   3 -
 3 files changed, 64 insertions(+), 81 deletions(-)

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 4cf9625..5eca377 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -146,7 +146,7 @@ class AsymmetricAttention(nn.Module):
         self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)

         # Query and key normalization for stability.
-        assert qk_norm
+        #assert qk_norm
         if rms_norm_func == "flash_attn_triton": #use the same rms_norm_func
             try:
                 from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm #slightly faster
@@ -338,7 +338,7 @@ class AsymmetricJointBlock(nn.Module):

         # MLP.
         mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
-        assert mlp_hidden_dim_x == int(1536 * 8)
+        #assert mlp_hidden_dim_x == int(1536 * 8)
         self.mlp_x = FeedForward(
             in_features=hidden_size_x,
             hidden_size=mlp_hidden_dim_x,
@@ -422,7 +422,7 @@ class AsymmetricJointBlock(nn.Module):
             self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
             self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))

-        assert x_attn.size(1) == N
+        #assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
         if self.update_y:
             y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
@@ -606,7 +606,7 @@ class AsymmDiTJoint(nn.Module):
         T, H, W = x.shape[-3:]
         pH, pW = H // self.patch_size, W // self.patch_size
         x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
-        assert x.ndim == 3
+        #assert x.ndim == 3

         # Construct position array of size [N, 3].
         # pos[:, 0] is the frame index for each location,
@@ -614,7 +614,7 @@ class AsymmDiTJoint(nn.Module):
         # pos[:, 2] is the column index for each location.
         pH, pW = H // self.patch_size, W // self.patch_size
         N = T * pH * pW
-        assert x.size(1) == N
+        #assert x.size(1) == N
         pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32)  # (N, 3)
         rope_cos, rope_sin = compute_mixed_rotation(freqs=self.pos_frequencies, pos=pos)  # Each are (N, num_heads, dim // 2)

diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 9b0d405..be7a5db 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -72,25 +72,6 @@ def fft(tensor):
     high_freq_fft = tensor_fft_shifted * high_freq_mask
     return low_freq_fft, high_freq_fft

-def unnormalize_latents(
-    z: torch.Tensor,
-    mean: torch.Tensor,
-    std: torch.Tensor,
-) -> torch.Tensor:
-    """Unnormalize latents. Useful for decoding DiT samples.
-
-    Args:
-        z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
-
-    Returns:
-        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
-    """
-    mean = mean[:, None, None, None]
-    std = std[:, None, None, None]
-
-    assert z.ndim == 5
-    assert z.size(1) == mean.size(0) == std.size(0)
-    return z * std.to(z) + mean.to(z)

 class T2VSynthMochiModel:
     def __init__(
@@ -188,10 +169,6 @@ class T2VSynthMochiModel:

         self.dit = model

-    def get_packed_indices(self, y_mask, **latent_dims):
-        # temporary dummy func for compatibility
-        return []
-
     def move_to_device_(self, sample):
         if isinstance(sample, dict):
             for key in sample.keys():
@@ -265,57 +242,66 @@ class T2VSynthMochiModel:

         def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
-            if args["fastercache"]:
-                self.fastercache_counter+=1
-            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
-                out_cond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample)
-
-                (bb, cc, tt, hh, ww) = out_cond.shape
-                cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-                lf_c, hf_c = fft(cond.float())
-                if self.fastercache_counter <= self.fastercache_lf_step:
-                    self.delta_lf = self.delta_lf * 1.1
-                if self.fastercache_counter >= self.fastercache_hf_step:
-                    self.delta_hf = self.delta_hf * 1.1
-
-                new_hf_uc = self.delta_hf + hf_c
-                new_lf_uc = self.delta_lf + lf_c
-
-                combine_uc = new_lf_uc + new_hf_uc
-                combined_fft = torch.fft.ifftshift(combine_uc)
-                recovered_uncond = torch.fft.ifft2(combined_fft).real
-                recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-
-                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
-            else:
-                out_cond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample)
-
-                out_uncond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample_null)
-
-                if self.fastercache_counter >= self.fastercache_start_step + 1:
-                    (bb, cc, tt, hh, ww) = out_cond.shape
-                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-
-                    lf_c, hf_c = fft(cond)
-                    lf_uc, hf_uc = fft(uncond)
-
-                    self.delta_lf = lf_uc - lf_c
-                    self.delta_hf = hf_uc - hf_c
+            if cfg_scale != 1.0:
+                if args["fastercache"]:
+                    self.fastercache_counter+=1
+                if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
+                    out_cond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample)

-            return out_uncond + cfg_scale * (out_cond - out_uncond)
+                    (bb, cc, tt, hh, ww) = out_cond.shape
+                    cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                    lf_c, hf_c = fft(cond.float())
+                    if self.fastercache_counter <= self.fastercache_lf_step:
+                        self.delta_lf = self.delta_lf * 1.1
+                    if self.fastercache_counter >= self.fastercache_hf_step:
+                        self.delta_hf = self.delta_hf * 1.1
+
+                    new_hf_uc = self.delta_hf + hf_c
+                    new_lf_uc = self.delta_lf + lf_c
+
+                    combine_uc = new_lf_uc + new_hf_uc
+                    combined_fft = torch.fft.ifftshift(combine_uc)
+                    recovered_uncond = torch.fft.ifft2(combined_fft).real
+                    recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                    return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
+                else:
+                    out_cond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample)
+
+                    out_uncond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample_null)
+
+                    if self.fastercache_counter >= self.fastercache_start_step + 1:
+                        (bb, cc, tt, hh, ww) = out_cond.shape
+                        cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                        uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                        lf_c, hf_c = fft(cond)
+                        lf_uc, hf_uc = fft(uncond)
+
+                        self.delta_lf = lf_uc - lf_c
+                        self.delta_hf = hf_uc - hf_c
+
+                    return out_uncond + cfg_scale * (out_cond - out_uncond)
+            else: #handle cfg 1.0
+                out_cond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample)
+                return out_cond
+


         comfy_pbar = ProgressBar(sample_steps)
diff --git a/nodes.py b/nodes.py
index c36c720..a6e2de8 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,7 +1,4 @@
 import os
-# import torch._dynamo
-# torch._dynamo.config.suppress_errors = True
-
 import torch
 import folder_paths
 import comfy.model_management as mm
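
Note on the restructured model_fn: the classifier-free guidance update out_uncond + cfg_scale * (out_cond - out_uncond) reduces to out_cond when cfg_scale == 1.0, which is why the new "else: #handle cfg 1.0" branch runs a single conditional DiT pass instead of two. The FasterCache path also leans on the fft() helper, of which only the last two lines appear as hunk context above. A minimal sketch of such a low/high-frequency split is given below; the function name, the cutoff_ratio parameter, and the mask construction are assumptions for illustration, not taken from this patch.

    import torch

    def split_frequencies(frames, cutoff_ratio=0.25):
        # Hypothetical stand-in for the fft() helper referenced in model_fn; only
        # its final two lines are visible in the patch, so the mask shape is assumed.
        # frames: (B*T, C, H, W), as produced by the rearrange calls above.
        spectrum = torch.fft.fftshift(torch.fft.fft2(frames))  # 2D FFT over H, W; DC term moved to the center
        H, W = frames.shape[-2:]
        ch, cw = int(H * cutoff_ratio), int(W * cutoff_ratio)  # assumed half-extent of the low-frequency box
        mask = torch.zeros(H, W, device=frames.device)
        mask[H // 2 - ch:H // 2 + ch, W // 2 - cw:W // 2 + cw] = 1.0
        low_freq_fft = spectrum * mask            # keep only the central (low-frequency) band
        high_freq_fft = spectrum * (1.0 - mask)   # keep everything outside the central band
        return low_freq_fft, high_freq_fft

Under that reading, self.delta_lf and self.delta_hf cache the per-band difference between the unconditional and conditional spectra, and the recovered_uncond branch rebuilds an approximate unconditional prediction via ifftshift + ifft2 on cached steps instead of running the second DiT pass.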