cleanup

commit 7391389276
parent ab3b18a153
@@ -146,7 +146,7 @@ class AsymmetricAttention(nn.Module):
         self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)
 
         # Query and key normalization for stability.
-        assert qk_norm
+        #assert qk_norm
         if rms_norm_func == "flash_attn_triton": #use the same rms_norm_func
             try:
                 from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm #slightly faster
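For context, the hunk above keeps the optional flash-attn Triton RMSNorm import inside a try block. Below is a minimal, hypothetical sketch of that import-with-fallback pattern; the fallback class, its eps value, and the example dimension are assumptions, not the repository's actual code.

```python
import torch
import torch.nn as nn

try:
    # Fused Triton kernel from flash-attn; the diff notes it is "slightly faster".
    from flash_attn.ops.triton.layer_norm import RMSNorm as RMSNormImpl
except ImportError:
    class RMSNormImpl(nn.Module):
        """Plain PyTorch RMSNorm fallback when the Triton op is unavailable (assumed)."""
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
            return (x.float() * rms).to(x.dtype) * self.weight

# Example: per-head query/key normalization over a head dimension of 128.
q_norm = RMSNormImpl(128)
```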
@@ -338,7 +338,7 @@ class AsymmetricJointBlock(nn.Module):
 
         # MLP.
         mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
-        assert mlp_hidden_dim_x == int(1536 * 8)
+        #assert mlp_hidden_dim_x == int(1536 * 8)
         self.mlp_x = FeedForward(
             in_features=hidden_size_x,
             hidden_size=mlp_hidden_dim_x,
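The assert being commented out above pinned the MLP width to 12288 (1536 * 8). A small arithmetic check, assuming the stock Mochi configuration of hidden_size_x = 3072 and mlp_ratio_x = 4.0, shows why it held for the default model but would block other width/ratio combinations:

```python
# Assumed default Mochi DiT config; other checkpoints may use different values.
hidden_size_x = 3072
mlp_ratio_x = 4.0

mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
print(mlp_hidden_dim_x)               # 12288
print(mlp_hidden_dim_x == 1536 * 8)   # True only for this default config
```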
@@ -422,7 +422,7 @@ class AsymmetricJointBlock(nn.Module):
                 self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
                 self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
 
-        assert x_attn.size(1) == N
+        #assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
         if self.update_y:
             y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
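residual_tanh_gated_rmsnorm, applied to the attention outputs above, is not shown in this diff. The sketch below is a plausible reconstruction of a tanh-gated residual RMSNorm (normalize the attention branch, scale it by tanh(gate), add it back to the residual stream); the exact epsilon, casting, and broadcasting in the repository may differ.

```python
import torch

def residual_tanh_gated_rmsnorm_sketch(x: torch.Tensor,
                                        x_res: torch.Tensor,
                                        gate: torch.Tensor,
                                        eps: float = 1e-6) -> torch.Tensor:
    # RMS-normalize the branch output in fp32 for stability.
    x_res_fp32 = x_res.float()
    scale = torch.rsqrt(x_res_fp32.pow(2).mean(-1, keepdim=True) + eps)
    # Gate with tanh so the per-channel modulation stays in (-1, 1); gate is (B, D).
    gated = x_res_fp32 * scale * torch.tanh(gate).unsqueeze(1).float()
    # Residual connection back onto x, cast to the residual's dtype.
    return x + gated.type_as(x)
```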
@@ -606,7 +606,7 @@ class AsymmDiTJoint(nn.Module):
         T, H, W = x.shape[-3:]
         pH, pW = H // self.patch_size, W // self.patch_size
         x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
-        assert x.ndim == 3
+        #assert x.ndim == 3
 
         # Construct position array of size [N, 3].
         # pos[:, 0] is the frame index for each location,
@@ -614,7 +614,7 @@ class AsymmDiTJoint(nn.Module):
         # pos[:, 2] is the column index for each location.
         pH, pW = H // self.patch_size, W // self.patch_size
         N = T * pH * pW
-        assert x.size(1) == N
+        #assert x.size(1) == N
         pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32)  # (N, 3)
         rope_cos, rope_sin = compute_mixed_rotation(freqs=self.pos_frequencies, pos=pos)  # Each are (N, num_heads, dim // 2)
 
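The comments in the two hunks above describe the position tensor fed to the rotary embedding: an (N, 3) array of (frame, row, column) indices with N = T * pH * pW. A minimal sketch of such a matrix, assuming create_position_matrix simply enumerates the grid in that order; the grid sizes used below are illustrative only:

```python
import torch

def create_position_matrix_sketch(T: int, pH: int, pW: int,
                                   device=None, dtype=torch.float32) -> torch.Tensor:
    """(N, 3) positions: pos[:, 0]=frame, pos[:, 1]=row, pos[:, 2]=column."""
    t = torch.arange(T, device=device)
    h = torch.arange(pH, device=device)
    w = torch.arange(pW, device=device)
    pos = torch.cartesian_prod(t, h, w)  # frame-major, then row, then column
    return pos.to(dtype)

pos = create_position_matrix_sketch(T=28, pH=60, pW=106)
assert pos.shape == (28 * 60 * 106, 3)
```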
@@ -72,25 +72,6 @@ def fft(tensor):
     high_freq_fft = tensor_fft_shifted * high_freq_mask
 
     return low_freq_fft, high_freq_fft
-def unnormalize_latents(
-    z: torch.Tensor,
-    mean: torch.Tensor,
-    std: torch.Tensor,
-) -> torch.Tensor:
-    """Unnormalize latents. Useful for decoding DiT samples.
-
-    Args:
-        z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
-
-    Returns:
-        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
-    """
-    mean = mean[:, None, None, None]
-    std = std[:, None, None, None]
-
-    assert z.ndim == 5
-    assert z.size(1) == mean.size(0) == std.size(0)
-    return z * std.to(z) + mean.to(z)
 
 class T2VSynthMochiModel:
     def __init__(
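The hunk header above points at fft(), which splits a tensor's 2D spectrum into low- and high-frequency parts (tensor_fft_shifted, high_freq_mask, low_freq_fft, high_freq_fft); the FasterCache hunk further down recombines them with ifftshift and ifft2. Below is a hedged sketch of that split, with an assumed radial cutoff and function name that may not match the repository's actual mask:

```python
import torch

def fft_split_sketch(tensor: torch.Tensor, cutoff: float = 0.25):
    """Split a (..., H, W) tensor into low-/high-frequency parts of its 2D FFT."""
    spectrum = torch.fft.fftshift(torch.fft.fft2(tensor.float()))  # DC moved to center
    H, W = tensor.shape[-2:]
    yy = torch.linspace(-0.5, 0.5, H, device=tensor.device).view(-1, 1)
    xx = torch.linspace(-0.5, 0.5, W, device=tensor.device).view(1, -1)
    low_freq_mask = ((yy ** 2 + xx ** 2).sqrt() <= cutoff).to(spectrum.dtype)
    low_freq_fft = spectrum * low_freq_mask
    high_freq_fft = spectrum * (1 - low_freq_mask)
    return low_freq_fft, high_freq_fft

# Round trip: summing the parts and inverting the FFT recovers the input.
x = torch.randn(2, 3, 16, 16)
lf, hf = fft_split_sketch(x)
recon = torch.fft.ifft2(torch.fft.ifftshift(lf + hf)).real
assert torch.allclose(recon, x, atol=1e-4)
```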
@@ -188,10 +169,6 @@ class T2VSynthMochiModel:
 
         self.dit = model
 
-    def get_packed_indices(self, y_mask, **latent_dims):
-        # temporary dummy func for compatibility
-        return []
-
     def move_to_device_(self, sample):
         if isinstance(sample, dict):
             for key in sample.keys():
@@ -265,57 +242,66 @@ class T2VSynthMochiModel:
 
         def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
-            if args["fastercache"]:
-                self.fastercache_counter+=1
-            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
-                out_cond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample)
-
-                (bb, cc, tt, hh, ww) = out_cond.shape
-                cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-                lf_c, hf_c = fft(cond.float())
-                if self.fastercache_counter <= self.fastercache_lf_step:
-                    self.delta_lf = self.delta_lf * 1.1
-                if self.fastercache_counter >= self.fastercache_hf_step:
-                    self.delta_hf = self.delta_hf * 1.1
-
-                new_hf_uc = self.delta_hf + hf_c
-                new_lf_uc = self.delta_lf + lf_c
-
-                combine_uc = new_lf_uc + new_hf_uc
-                combined_fft = torch.fft.ifftshift(combine_uc)
-                recovered_uncond = torch.fft.ifft2(combined_fft).real
-                recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-
-                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
-            else:
-                out_cond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample)
-
-                out_uncond = self.dit(
-                    z,
-                    sigma,
-                    self.fastercache_counter,
-                    **sample_null)
-
-                if self.fastercache_counter >= self.fastercache_start_step + 1:
-                    (bb, cc, tt, hh, ww) = out_cond.shape
-                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
-
-                    lf_c, hf_c = fft(cond)
-                    lf_uc, hf_uc = fft(uncond)
-
-                    self.delta_lf = lf_uc - lf_c
-                    self.delta_hf = hf_uc - hf_c
-
-                return out_uncond + cfg_scale * (out_cond - out_uncond)
+            if cfg_scale != 1.0:
+                if args["fastercache"]:
+                    self.fastercache_counter+=1
+                if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
+                    out_cond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample)
+
+                    (bb, cc, tt, hh, ww) = out_cond.shape
+                    cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                    lf_c, hf_c = fft(cond.float())
+                    if self.fastercache_counter <= self.fastercache_lf_step:
+                        self.delta_lf = self.delta_lf * 1.1
+                    if self.fastercache_counter >= self.fastercache_hf_step:
+                        self.delta_hf = self.delta_hf * 1.1
+
+                    new_hf_uc = self.delta_hf + hf_c
+                    new_lf_uc = self.delta_lf + lf_c
+
+                    combine_uc = new_lf_uc + new_hf_uc
+                    combined_fft = torch.fft.ifftshift(combine_uc)
+                    recovered_uncond = torch.fft.ifft2(combined_fft).real
+                    recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                    return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
+                else:
+                    out_cond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample)
+
+                    out_uncond = self.dit(
+                        z,
+                        sigma,
+                        self.fastercache_counter,
+                        **sample_null)
+
+                    if self.fastercache_counter >= self.fastercache_start_step + 1:
+                        (bb, cc, tt, hh, ww) = out_cond.shape
+                        cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                        uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                        lf_c, hf_c = fft(cond)
+                        lf_uc, hf_uc = fft(uncond)
+
+                        self.delta_lf = lf_uc - lf_c
+                        self.delta_hf = hf_uc - hf_c
+
+                    return out_uncond + cfg_scale * (out_cond - out_uncond)
+            else: #handle cfg 1.0
+                out_cond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample)
+                return out_cond
+
 
         comfy_pbar = ProgressBar(sample_steps)
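The main behavioural change in this hunk is the new cfg_scale != 1.0 guard: when the guidance scale is exactly 1.0 the classifier-free-guidance blend collapses to the conditional output, so the unconditional forward pass (and the FasterCache bookkeeping) can be skipped. A small standalone check of that identity, with illustrative tensor shapes:

```python
import torch

def cfg_blend(out_cond: torch.Tensor, out_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance combination, as returned at the end of model_fn.
    return out_uncond + cfg_scale * (out_cond - out_uncond)

cond = torch.randn(1, 12, 4, 8, 8)   # (B, C, T, H, W), shapes are illustrative
uncond = torch.randn_like(cond)
assert torch.allclose(cfg_blend(cond, uncond, 1.0), cond)  # uncond contribution cancels
```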