kijai 2024-11-05 12:29:53 +02:00
parent ab3b18a153
commit 7391389276
3 changed files with 64 additions and 81 deletions


@@ -146,7 +146,7 @@ class AsymmetricAttention(nn.Module):
self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)
# Query and key normalization for stability.
assert qk_norm
#assert qk_norm
if rms_norm_func == "flash_attn_triton": #use the same rms_norm_func
try:
from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm #slightly faster
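The hunk above wraps the Triton RMSNorm import in a try block. A minimal sketch of that import-with-fallback pattern, assuming a plain PyTorch RMSNorm is an acceptable substitute when flash-attn's fused kernel is not installed (SimpleRMSNorm and RMSNormImpl are illustrative names, not the repo's own):

import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    # Illustrative fallback: scale x by rsqrt(mean(x^2) + eps) and a learned weight.
    def __init__(self, dim: int, eps: float = 1e-6, device=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(x.dtype) * self.weight

try:
    # Fused Triton kernel shipped with flash-attn; slightly faster when available.
    from flash_attn.ops.triton.layer_norm import RMSNorm as RMSNormImpl
except ImportError:
    RMSNormImpl = SimpleRMSNorm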
@@ -338,7 +338,7 @@ class AsymmetricJointBlock(nn.Module):
# MLP.
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
assert mlp_hidden_dim_x == int(1536 * 8)
#assert mlp_hidden_dim_x == int(1536 * 8)
self.mlp_x = FeedForward(
in_features=hidden_size_x,
hidden_size=mlp_hidden_dim_x,
@@ -422,7 +422,7 @@ class AsymmetricJointBlock(nn.Module):
self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
assert x_attn.size(1) == N
#assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
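residual_tanh_gated_rmsnorm in the hunk above folds the gated, RMS-normalized attention output back into the residual stream. A hedged sketch of what such a function plausibly does, written with illustrative names (the upstream implementation may differ in dtype handling and gate shape):

import torch

def residual_tanh_gated_rmsnorm_sketch(x, x_res, gate, eps: float = 1e-6):
    # x, x_res: (B, N, D); gate: (B, D). RMS-normalize the branch output in fp32.
    x_res32 = x_res.float()
    scale = torch.rsqrt(x_res32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # tanh keeps each channel's gate in (-1, 1); unsqueeze broadcasts over tokens.
    gated = x_res32 * scale * torch.tanh(gate.float()).unsqueeze(1)
    # Add the gated, normalized update back onto the residual stream.
    return x + gated.type_as(x)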
@@ -606,7 +606,7 @@ class AsymmDiTJoint(nn.Module):
T, H, W = x.shape[-3:]
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3
#assert x.ndim == 3
# Construct position array of size [N, 3].
# pos[:, 0] is the frame index for each location,
@@ -614,7 +614,7 @@ class AsymmDiTJoint(nn.Module):
# pos[:, 2] is the column index for each location.
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
#assert x.size(1) == N
pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(freqs=self.pos_frequencies, pos=pos) # Each are (N, num_heads, dim // 2)
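The comments above define the layout of pos: an (N, 3) array holding the frame, row, and column index of every token, with N = T * pH * pW. A minimal sketch of how such a matrix can be built (the repo's create_position_matrix may differ in argument handling and dtype):

import torch

def position_matrix_sketch(T, pH, pW, device=None, dtype=torch.float32):
    t = torch.arange(T, device=device)
    h = torch.arange(pH, device=device)
    w = torch.arange(pW, device=device)
    # Grid in (frame, row, column) order so flattening matches a (T, pH, pW) patch grid.
    tt, hh, ww = torch.meshgrid(t, h, w, indexing="ij")
    pos = torch.stack([tt, hh, ww], dim=-1).reshape(-1, 3)  # (N, 3)
    return pos.to(dtype)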


@@ -72,25 +72,6 @@ def fft(tensor):
high_freq_fft = tensor_fft_shifted * high_freq_mask
return low_freq_fft, high_freq_fft
def unnormalize_latents(
z: torch.Tensor,
mean: torch.Tensor,
std: torch.Tensor,
) -> torch.Tensor:
"""Unnormalize latents. Useful for decoding DiT samples.
Args:
z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
Returns:
torch.Tensor: [B, C_z, T_z, H_z, W_z], float
"""
mean = mean[:, None, None, None]
std = std[:, None, None, None]
assert z.ndim == 5
assert z.size(1) == mean.size(0) == std.size(0)
return z * std.to(z) + mean.to(z)
class T2VSynthMochiModel:
def __init__(
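Only the tail of the fft helper is visible at the top of this hunk. For context on the FasterCache logic below, here is a hedged sketch of what such a low/high frequency split typically looks like: a centered 2D FFT per frame, masked with a small box around the center (the box radius here is an illustrative choice, not necessarily the repo's):

import torch

def fft_split_sketch(tensor, low_freq_radius: int = 8):
    # tensor: (N, C, H, W). 2D FFT over the spatial dims, shifted so low
    # frequencies sit at the center of the spectrum.
    tensor_fft_shifted = torch.fft.fftshift(torch.fft.fft2(tensor))
    H, W = tensor.shape[-2:]
    cy, cx = H // 2, W // 2
    low_freq_mask = torch.zeros_like(tensor_fft_shifted, dtype=torch.bool)
    low_freq_mask[..., cy - low_freq_radius:cy + low_freq_radius,
                  cx - low_freq_radius:cx + low_freq_radius] = True
    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * ~low_freq_mask
    return low_freq_fft, high_freq_fft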
@@ -188,10 +169,6 @@ class T2VSynthMochiModel:
self.dit = model
def get_packed_indices(self, y_mask, **latent_dims):
# temporary dummy func for compatibility
return []
def move_to_device_(self, sample):
if isinstance(sample, dict):
for key in sample.keys():
@@ -265,57 +242,66 @@ class T2VSynthMochiModel:
def model_fn(*, z, sigma, cfg_scale):
nonlocal sample, sample_null
if args["fastercache"]:
self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
(bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond.float())
if self.fastercache_counter <= self.fastercache_lf_step:
self.delta_lf = self.delta_lf * 1.1
if self.fastercache_counter >= self.fastercache_hf_step:
self.delta_hf = self.delta_hf * 1.1
new_hf_uc = self.delta_hf + hf_c
new_lf_uc = self.delta_lf + lf_c
combine_uc = new_lf_uc + new_hf_uc
combined_fft = torch.fft.ifftshift(combine_uc)
recovered_uncond = torch.fft.ifft2(combined_fft).real
recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
else:
out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
out_uncond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample_null)
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond)
lf_uc, hf_uc = fft(uncond)
self.delta_lf = lf_uc - lf_c
self.delta_hf = hf_uc - hf_c
if cfg_scale != 1.0:
if args["fastercache"]:
self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
return out_uncond + cfg_scale * (out_cond - out_uncond)
(bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond.float())
if self.fastercache_counter <= self.fastercache_lf_step:
self.delta_lf = self.delta_lf * 1.1
if self.fastercache_counter >= self.fastercache_hf_step:
self.delta_hf = self.delta_hf * 1.1
new_hf_uc = self.delta_hf + hf_c
new_lf_uc = self.delta_lf + lf_c
combine_uc = new_lf_uc + new_hf_uc
combined_fft = torch.fft.ifftshift(combine_uc)
recovered_uncond = torch.fft.ifft2(combined_fft).real
recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
else:
out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
out_uncond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample_null)
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond)
lf_uc, hf_uc = fft(uncond)
self.delta_lf = lf_uc - lf_c
self.delta_hf = hf_uc - hf_c
return out_uncond + cfg_scale * (out_cond - out_uncond)
else: #handle cfg 1.0
out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
return out_cond
comfy_pbar = ProgressBar(sample_steps)
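The restructured model_fn above reduces to two cases: with cfg_scale == 1.0 only the conditional forward pass runs, and otherwise the unconditional prediction (either a real second pass through the DiT or one reconstructed from the cached FasterCache frequency deltas) is combined with the conditional one by standard classifier-free guidance. A compact sketch of that final combination step, with illustrative names:

import torch

def cfg_combine_sketch(out_cond, out_uncond=None, cfg_scale: float = 4.5):
    # cfg_scale == 1.0 makes guidance a no-op, so the conditional output is returned as-is.
    if out_uncond is None or cfg_scale == 1.0:
        return out_cond
    # Classifier-free guidance: move the prediction away from the unconditional
    # branch by cfg_scale times the cond/uncond difference.
    return out_uncond + cfg_scale * (out_cond - out_uncond)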


@@ -1,7 +1,4 @@
import os
# import torch._dynamo
# torch._dynamo.config.suppress_errors = True
import torch
import folder_paths
import comfy.model_management as mm