kijai 2024-11-03 00:32:51 +02:00
parent 5ec01cbff4
commit a6e545531c
2 changed files with 129 additions and 41 deletions


@@ -389,6 +389,9 @@ class AsymmetricJointBlock(nn.Module):
             self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
         else:
             self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
+        self.cached_x_attention = [None, None]
+        self.cached_y_attention = [None, None]
         # Self-attention:
         self.attn = AsymmetricAttention(
@@ -428,6 +431,8 @@ class AsymmetricJointBlock(nn.Module):
         x: torch.Tensor,
         c: torch.Tensor,
         y: torch.Tensor,
+        fastercache_counter: Optional[int] = 0,
+        fastercache_start_step: Optional[int] = 15,
         **attn_kwargs,
     ):
         """Forward pass of a block.
@@ -453,15 +458,36 @@ class AsymmetricJointBlock(nn.Module):
             scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
         else:
             scale_msa_y = mod_y
-        # Self-attention block.
-        x_attn, y_attn = self.attn(
-            x,
-            y,
-            scale_x=scale_msa_x,
-            scale_y=scale_msa_y,
-            **attn_kwargs,
-        )
+        # FasterCache: on skipped steps, reuse the cached attention outputs instead of recomputing them.
+        B = x.shape[0]
+        #print("x", x.shape) #([1, 9540, 3072])
+        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter % 3 != 0 and self.cached_x_attention[-1].shape[0] >= B:
+            x_attn = (
+                self.cached_x_attention[1][:B] +
+                (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
+                * 0.3
+            ).to(x.device, non_blocking=True)
+            y_attn = (
+                self.cached_y_attention[1][:B] +
+                (self.cached_y_attention[1][:B] - self.cached_y_attention[0][:B])
+                * 0.3
+            ).to(x.device, non_blocking=True)
+        else:
+            # Self-attention block.
+            x_attn, y_attn = self.attn(
+                x,
+                y,
+                scale_x=scale_msa_x,
+                scale_y=scale_msa_y,
+                **attn_kwargs,
+            )
+            if fastercache_counter == fastercache_start_step:
+                self.cached_x_attention = [x_attn, x_attn]
+                self.cached_y_attention = [y_attn, y_attn]
+            elif fastercache_counter > fastercache_start_step:
+                self.cached_x_attention[-1].copy_(x_attn)
+                self.cached_y_attention[-1].copy_(y_attn)
         assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
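In words: once fastercache_counter passes fastercache_start_step + 3, every step whose counter is not a multiple of 3 skips the block's self-attention and instead extrapolates from the two cached outputs as cached[1] + 0.3 * (cached[1] - cached[0]); full steps recompute attention, seed both cache slots at the start step, and afterwards refresh only the newer slot. A minimal sketch of that cache-and-extrapolate pattern, using hypothetical names (CachedAttn, compute) rather than the block's real API:

import torch

class CachedAttn:
    """Toy illustration of the caching scheme used in AsymmetricJointBlock above."""
    def __init__(self, start_step: int = 15):
        self.start_step = start_step
        self.cache = [None, None]  # [older, newer] attention outputs

    def compute(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0  # stand-in for the real self-attention call

    def forward(self, x: torch.Tensor, counter: int) -> torch.Tensor:
        if counter >= self.start_step + 3 and counter % 3 != 0 and self.cache[-1] is not None:
            # Skipped step: linearly extrapolate from the two cached outputs.
            return self.cache[1] + 0.3 * (self.cache[1] - self.cache[0])
        attn = self.compute(x)  # full attention on non-skipped steps
        if counter == self.start_step:
            self.cache = [attn, attn]  # seed both slots at the start step
        elif counter > self.start_step:
            self.cache[-1] = attn      # refresh only the newer slot, as in the commit
        return attn

blk = CachedAttn(start_step=2)
outs = [blk.forward(torch.randn(1, 4), counter=i) for i in range(1, 8)]  # steps 5 and 7 reuse the cache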
@@ -674,6 +700,8 @@ class AsymmDiTJoint(nn.Module):
         self,
         x: torch.Tensor,
         sigma: torch.Tensor,
+        fastercache_counter: int,
+        fastercache_start_step: int,
         y_feat: List[torch.Tensor],
         y_mask: List[torch.Tensor],
         packed_indices: Dict[str, torch.Tensor] = None,
@@ -707,7 +735,9 @@ class AsymmDiTJoint(nn.Module):
                 rope_cos=rope_cos,
                 rope_sin=rope_sin,
                 packed_indices=packed_indices,
-            )  # (B, M, D), (B, L, D)
+                fastercache_counter=fastercache_counter,
+                fastercache_start_step=fastercache_start_step,
+            )  # (B, M, D), (B, L, D)
         del y_feat  # Final layers don't use dense text features.
         x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
@@ -720,6 +750,6 @@ class AsymmDiTJoint(nn.Module):
             p1=self.patch_size,
             p2=self.patch_size,
             c=self.out_channels,
-        )
+        )
         return x
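The two hunks above are plumbing: AsymmDiTJoint.forward now takes fastercache_counter and fastercache_start_step as extra positional arguments and passes them on to every AsymmetricJointBlock, which is where the caching decision is made. A schematic of that threading, with hypothetical Block/Model stand-ins rather than the real classes:

import torch
import torch.nn as nn

class Block(nn.Module):
    def forward(self, x, fastercache_counter: int = 0, fastercache_start_step: int = 15):
        # The real block decides here whether to reuse its cached attention.
        return x

class Model(nn.Module):
    def __init__(self, depth: int = 3):
        super().__init__()
        self.blocks = nn.ModuleList([Block() for _ in range(depth)])

    def forward(self, x, fastercache_counter: int, fastercache_start_step: int):
        for block in self.blocks:
            x = block(x,
                      fastercache_counter=fastercache_counter,
                      fastercache_start_step=fastercache_start_step)
        return x

out = Model()(torch.randn(1, 8), fastercache_counter=0, fastercache_start_step=15)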


@@ -59,6 +59,22 @@ log = logging.getLogger(__name__)
 MAX_T5_TOKEN_LENGTH = 256
+def fft(tensor):
+    # Split a (B, C, H, W) tensor into low- and high-frequency components of its 2D FFT.
+    tensor_fft = torch.fft.fft2(tensor)
+    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
+    B, C, H, W = tensor.size()
+    radius = min(H, W) // 5
+    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
+    center_x, center_y = W // 2, H // 2
+    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
+    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
+    high_freq_mask = ~low_freq_mask
+    low_freq_fft = tensor_fft_shifted * low_freq_mask
+    high_freq_fft = tensor_fft_shifted * high_freq_mask
+    return low_freq_fft, high_freq_fft
 def unnormalize_latents(
     z: torch.Tensor,
     mean: torch.Tensor,
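The new fft helper masks a centered disc of radius min(H, W) // 5 in the shifted 2D spectrum to split a (B, C, H, W) tensor into low- and high-frequency parts; adding the two parts back and inverting recovers the input exactly. A self-contained sketch of the same split, with a hypothetical name (split_frequencies) and with indexing="ij" passed explicitly to torch.meshgrid (the commit's call relies on the default, which behaves the same but warns on recent PyTorch):

import torch

def split_frequencies(x: torch.Tensor, radius_div: int = 5):
    # Same idea as the fft() helper above: a low-pass disc mask in the shifted spectrum.
    B, C, H, W = x.shape
    spec = torch.fft.fftshift(torch.fft.fft2(x))
    radius = min(H, W) // radius_div
    ys, xs = torch.meshgrid(
        torch.arange(H, device=x.device), torch.arange(W, device=x.device), indexing="ij"
    )
    low_mask = ((xs - W // 2) ** 2 + (ys - H // 2) ** 2 <= radius ** 2)[None, None]
    return spec * low_mask, spec * ~low_mask

x = torch.randn(1, 12, 60, 106)  # dummy tensor shaped like the flattened latents
lf, hf = split_frequencies(x)
recon = torch.fft.ifft2(torch.fft.ifftshift(lf + hf)).real
assert torch.allclose(recon, x, atol=1e-4)  # the two bands recombine losslessly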
@@ -285,42 +301,84 @@ class T2VSynthMochiModel:
         sample_null["packed_indices"] = self.get_packed_indices(
             sample_null["y_mask"], **latent_dims
         )
-        def model_fn(*, z, sigma, cfg_scale):
-            self.dit.to(self.device)
-            if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-                autocast_dtype = torch.float16
-            else:
-                autocast_dtype = torch.bfloat16
-            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                if cfg_scale > 1.0:
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
-                else:
-                    out_cond = self.dit(z, sigma, **sample)
-                    return out_cond
-            return out_uncond + cfg_scale * (out_cond - out_uncond)
+        # FasterCache settings
+        self.use_fastercache = True
+        self.fastercache_counter = 0
+        self.fastercache_start_step = 15
+        self.fastercache_lf_step = 40
+        self.fastercache_hf_step = 30
+        def model_fn(*, z, sigma, cfg_scale):
+            nonlocal sample, sample_null
+            if self.use_fastercache:
+                self.fastercache_counter += 1
+            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
+                # Cached step: skip the unconditional pass and rebuild it from the stored frequency deltas.
+                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample)
+                (bb, cc, tt, hh, ww) = out_cond.shape
+                cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                lf_c, hf_c = fft(cond.float())
+                #lf_step = 40
+                #hf_step = 30
+                if self.fastercache_counter <= self.fastercache_lf_step:
+                    self.delta_lf = self.delta_lf * 1.1
+                if self.fastercache_counter >= self.fastercache_hf_step:
+                    self.delta_hf = self.delta_hf * 1.1
+                new_hf_uc = self.delta_hf + hf_c
+                new_lf_uc = self.delta_lf + lf_c
+                combine_uc = new_lf_uc + new_hf_uc
+                combined_fft = torch.fft.ifftshift(combine_uc)
+                recovered_uncond = torch.fft.ifft2(combined_fft).real
+                recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
+            else:
+                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample)
+                out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample_null)
+                #print("out_cond.shape",out_cond.shape) #([1, 12, 3, 60, 106])
+                if self.fastercache_counter >= self.fastercache_start_step + 1:
+                    # Record how the unconditional output differs from the conditional one per frequency band.
+                    (bb, cc, tt, hh, ww) = out_cond.shape
+                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                    lf_c, hf_c = fft(cond)
+                    lf_uc, hf_uc = fft(uncond)
+                    self.delta_lf = lf_uc - lf_c
+                    self.delta_hf = hf_uc - hf_c
+                return out_uncond + cfg_scale * (out_cond - out_uncond)
         comfy_pbar = ProgressBar(sample_steps)
-        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
-            sigma = sigma_schedule[i]
-            dsigma = sigma - sigma_schedule[i + 1]
-            # `pred` estimates `z_0 - eps`.
-            pred = model_fn(
-                z=z,
-                sigma=torch.full([B], sigma, device=z.device),
-                cfg_scale=cfg_schedule[i],
-            )
-            pred = pred.to(z)
-            z = z + dsigma * pred
-            if callback is not None:
-                callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
-            else:
-                comfy_pbar.update(1)
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+            autocast_dtype = torch.float16
+        else:
+            autocast_dtype = torch.bfloat16
+        self.dit.to(self.device)
+        # Run the whole sampling loop under a single autocast context (previously opened per model_fn call).
+        with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+            for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
+                sigma = sigma_schedule[i]
+                dsigma = sigma - sigma_schedule[i + 1]
+                # `pred` estimates `z_0 - eps`.
+                pred = model_fn(
+                    z=z,
+                    sigma=torch.full([B], sigma, device=z.device),
+                    cfg_scale=cfg_schedule[i],
+                )
+                pred = pred.to(z)
+                z = z + dsigma * pred
+                if callback is not None:
+                    callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
+                else:
+                    comfy_pbar.update(1)
         self.dit.to(self.offload_device)
         logging.info(f"samples shape: {z.shape}")