cleanup

parent ab3b18a153
commit 7391389276
@@ -146,7 +146,7 @@ class AsymmetricAttention(nn.Module):
         self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)

         # Query and key normalization for stability.
-        assert qk_norm
+        #assert qk_norm
         if rms_norm_func == "flash_attn_triton": #use the same rms_norm_func
             try:
                 from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm #slightly faster
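For context, the Triton RMSNorm import above is wrapped in a try so the attention module can fall back to a plain PyTorch norm when flash-attn is not installed. The hunk only shows the try side; a minimal sketch of the full guarded-import pattern is below (the TorchRMSNorm fallback is illustrative, not the repo's actual fallback code):

import torch
import torch.nn as nn

try:
    # Slightly faster fused kernel, only available when flash-attn is installed.
    from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm
    RMSNorm = FlashTritonRMSNorm
except ImportError:
    class TorchRMSNorm(nn.Module):
        # Plain PyTorch RMSNorm used as a fallback (illustrative).
        def __init__(self, dim: int, eps: float = 1e-6, device=None):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim, device=device))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
            return (x.float() * rms).type_as(x) * self.weight

    RMSNorm = TorchRMSNorm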
@@ -338,7 +338,7 @@ class AsymmetricJointBlock(nn.Module):

         # MLP.
         mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
-        assert mlp_hidden_dim_x == int(1536 * 8)
+        #assert mlp_hidden_dim_x == int(1536 * 8)
         self.mlp_x = FeedForward(
             in_features=hidden_size_x,
             hidden_size=mlp_hidden_dim_x,
@@ -422,7 +422,7 @@ class AsymmetricJointBlock(nn.Module):
             self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
             self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))

-        assert x_attn.size(1) == N
+        #assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
         if self.update_y:
             y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
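residual_tanh_gated_rmsnorm above folds the attention output back into the residual stream through an RMS-normalized, tanh-gated update. A minimal sketch of that kind of update, assuming x and x_attn are (B, N, D) and gate is a per-channel modulation vector:

import torch

def residual_tanh_gated_rmsnorm_sketch(x, x_res, gate, eps: float = 1e-6):
    # RMS-normalize the branch output in fp32 for stability.
    x_res_fp32 = x_res.float()
    scale = torch.rsqrt(x_res_fp32.pow(2).mean(-1, keepdim=True) + eps)
    # Gate the normalized branch with tanh(gate), broadcast over the sequence dim.
    update = x_res_fp32 * scale * torch.tanh(gate).unsqueeze(1).float()
    # Add the gated update back onto the residual stream in the original dtype.
    return x + update.type_as(x)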
@@ -606,7 +606,7 @@ class AsymmDiTJoint(nn.Module):
         T, H, W = x.shape[-3:]
         pH, pW = H // self.patch_size, W // self.patch_size
         x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
-        assert x.ndim == 3
+        #assert x.ndim == 3

         # Construct position array of size [N, 3].
         # pos[:, 0] is the frame index for each location,
@@ -614,7 +614,7 @@ class AsymmDiTJoint(nn.Module):
         # pos[:, 2] is the column index for each location.
         pH, pW = H // self.patch_size, W // self.patch_size
         N = T * pH * pW
-        assert x.size(1) == N
+        #assert x.size(1) == N
         pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3)
         rope_cos, rope_sin = compute_mixed_rotation(freqs=self.pos_frequencies, pos=pos) # Each are (N, num_heads, dim // 2)

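create_position_matrix above produces the (N, 3) array of (frame, row, column) indices that the mixed-rotation RoPE is evaluated at. A minimal sketch of how such a matrix can be built with meshgrid, assuming the same (T, pH, pW) ordering:

import torch

def create_position_matrix_sketch(T: int, pH: int, pW: int, device=None, dtype=torch.float32):
    # Index grids for the frame, row and column of every patch token.
    t = torch.arange(T, device=device)
    h = torch.arange(pH, device=device)
    w = torch.arange(pW, device=device)
    tt, hh, ww = torch.meshgrid(t, h, w, indexing="ij")
    # Flatten to (N, 3) with N = T * pH * pW; column 0 = frame, 1 = row, 2 = column.
    pos = torch.stack([tt, hh, ww], dim=-1).reshape(-1, 3)
    return pos.to(dtype)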
@@ -72,25 +72,6 @@ def fft(tensor):
     high_freq_fft = tensor_fft_shifted * high_freq_mask

     return low_freq_fft, high_freq_fft
-def unnormalize_latents(
-    z: torch.Tensor,
-    mean: torch.Tensor,
-    std: torch.Tensor,
-) -> torch.Tensor:
-    """Unnormalize latents. Useful for decoding DiT samples.
-
-    Args:
-        z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
-
-    Returns:
-        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
-    """
-    mean = mean[:, None, None, None]
-    std = std[:, None, None, None]
-
-    assert z.ndim == 5
-    assert z.size(1) == mean.size(0) == std.size(0)
-    return z * std.to(z) + mean.to(z)

 class T2VSynthMochiModel:
     def __init__(
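The fft helper whose tail appears above splits a (B*T, C, H, W) batch of frames into low- and high-frequency components in the shifted 2D Fourier domain; those components feed the delta_lf / delta_hf bookkeeping in the FasterCache path further down. A minimal sketch of such a split, assuming a centered rectangular low-pass mask of configurable ratio (the mask shape and ratio are assumptions, not the repo's exact values):

import torch

def fft_split_sketch(tensor: torch.Tensor, low_ratio: float = 0.25):
    # 2D FFT over the spatial dims, shifted so the zero frequency sits in the center.
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)

    B, C, H, W = tensor.shape
    radius_h, radius_w = int(H * low_ratio), int(W * low_ratio)

    # Centered mask selects the low frequencies; its complement is the high band.
    low_freq_mask = torch.zeros((B, C, H, W), device=tensor.device)
    low_freq_mask[..., H // 2 - radius_h : H // 2 + radius_h, W // 2 - radius_w : W // 2 + radius_w] = 1
    high_freq_mask = 1 - low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask
    return low_freq_fft, high_freq_fft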
@@ -188,10 +169,6 @@ class T2VSynthMochiModel:

         self.dit = model

-    def get_packed_indices(self, y_mask, **latent_dims):
-        # temporary dummy func for compatibility
-        return []
-
     def move_to_device_(self, sample):
         if isinstance(sample, dict):
             for key in sample.keys():
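move_to_device_ above walks the conditioning dict and moves every tensor it finds onto the model's device. A minimal sketch of that kind of recursive move, assuming values are tensors, lists/tuples of tensors, or nested dicts:

import torch

def move_to_device_sketch(sample, device):
    # Recursively move tensors inside dicts and sequences onto the target device.
    if isinstance(sample, dict):
        return {key: move_to_device_sketch(value, device) for key, value in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(move_to_device_sketch(item, device) for item in sample)
    if isinstance(sample, torch.Tensor):
        return sample.to(device)
    return sample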
@@ -265,57 +242,66 @@ class T2VSynthMochiModel:

         def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
-            if args["fastercache"]:
-                self.fastercache_counter+=1
-            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
-                out_cond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample)
-
-                (bb, cc, tt, hh, ww) = out_cond.shape
-                cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-                lf_c, hf_c = fft(cond.float())
-                if self.fastercache_counter <= self.fastercache_lf_step:
-                    self.delta_lf = self.delta_lf * 1.1
-                if self.fastercache_counter >= self.fastercache_hf_step:
-                    self.delta_hf = self.delta_hf * 1.1
-
-                new_hf_uc = self.delta_hf + hf_c
-                new_lf_uc = self.delta_lf + lf_c
-
-                combine_uc = new_lf_uc + new_hf_uc
-                combined_fft = torch.fft.ifftshift(combine_uc)
-                recovered_uncond = torch.fft.ifft2(combined_fft).real
-                recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-
-                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
-            else:
-                out_cond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample)
-
-                out_uncond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample_null)
-
-                if self.fastercache_counter >= self.fastercache_start_step + 1:
-                    (bb, cc, tt, hh, ww) = out_cond.shape
-                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-
-                    lf_c, hf_c = fft(cond)
-                    lf_uc, hf_uc = fft(uncond)
-
-                    self.delta_lf = lf_uc - lf_c
-                    self.delta_hf = hf_uc - hf_c
-
-                return out_uncond + cfg_scale * (out_cond - out_uncond)
+            if cfg_scale != 1.0:
+                if args["fastercache"]:
+                    self.fastercache_counter+=1
+                if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
+                    out_cond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample)
+
+                    (bb, cc, tt, hh, ww) = out_cond.shape
+                    cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                    lf_c, hf_c = fft(cond.float())
+                    if self.fastercache_counter <= self.fastercache_lf_step:
+                        self.delta_lf = self.delta_lf * 1.1
+                    if self.fastercache_counter >= self.fastercache_hf_step:
+                        self.delta_hf = self.delta_hf * 1.1
+
+                    new_hf_uc = self.delta_hf + hf_c
+                    new_lf_uc = self.delta_lf + lf_c
+
+                    combine_uc = new_lf_uc + new_hf_uc
+                    combined_fft = torch.fft.ifftshift(combine_uc)
+                    recovered_uncond = torch.fft.ifft2(combined_fft).real
+                    recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                    return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
+                else:
+                    out_cond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample)
+
+                    out_uncond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample_null)
+
+                    if self.fastercache_counter >= self.fastercache_start_step + 1:
+                        (bb, cc, tt, hh, ww) = out_cond.shape
+                        cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                        uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                        lf_c, hf_c = fft(cond)
+                        lf_uc, hf_uc = fft(uncond)
+
+                        self.delta_lf = lf_uc - lf_c
+                        self.delta_hf = hf_uc - hf_c
+
+                    return out_uncond + cfg_scale * (out_cond - out_uncond)
+            else: #handle cfg 1.0
+                out_cond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample)
+                return out_cond
+

         comfy_pbar = ProgressBar(sample_steps)

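The restructured model_fn above implements the FasterCache-style CFG approximation: on most steps both the conditional and unconditional passes run and their low/high-frequency differences are cached, while on skipped steps only the conditional pass runs and the unconditional output is reconstructed from those cached deltas; with cfg_scale == 1.0 the unconditional pass is skipped entirely. A condensed sketch of that control flow, assuming dit(z, sigma, cond) returns a (B, C, T, H, W) tensor and reusing the fft_split_sketch helper sketched earlier (the names and step schedule here are illustrative, not the repo's exact API):

import torch
from einops import rearrange

def cfg_step_sketch(dit, z, sigma, cond, uncond, cfg_scale, cache, step, start_step=5):
    out_cond = dit(z, sigma, cond)
    if cfg_scale == 1.0:
        # No guidance needed; skip the unconditional pass entirely.
        return out_cond

    b, c, t, h, w = out_cond.shape
    cond_2d = rearrange(out_cond.float(), "B C T H W -> (B T) C H W")

    if step >= start_step + 3 and step % 5 != 0:
        # Skipped step: rebuild the unconditional output from the cached frequency deltas.
        lf_c, hf_c = fft_split_sketch(cond_2d)
        combined = torch.fft.ifftshift(cache["delta_lf"] + lf_c + cache["delta_hf"] + hf_c)
        recovered = torch.fft.ifft2(combined).real
        out_uncond = rearrange(recovered, "(B T) C H W -> B C T H W", B=b, T=t).to(out_cond.dtype)
    else:
        # Full step: run the unconditional pass and refresh the cached deltas.
        out_uncond = dit(z, sigma, uncond)
        uncond_2d = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W")
        lf_c, hf_c = fft_split_sketch(cond_2d)
        lf_uc, hf_uc = fft_split_sketch(uncond_2d)
        cache["delta_lf"], cache["delta_hf"] = lf_uc - lf_c, hf_uc - hf_c

    return out_uncond + cfg_scale * (out_cond - out_uncond)

The 1.1 scaling of delta_lf and delta_hf on skipped steps, visible in the diff, slightly amplifies the stored corrections as the cache ages; the sketch omits it for brevity.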