test
commit a6e545531c
parent 5ec01cbff4
@@ -389,6 +389,9 @@ class AsymmetricJointBlock(nn.Module):
             self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
         else:
             self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)

+        self.cached_x_attention = [None, None]
+        self.cached_y_attention = [None, None]
+
         # Self-attention:
         self.attn = AsymmetricAttention(
@@ -428,6 +431,8 @@ class AsymmetricJointBlock(nn.Module):
         x: torch.Tensor,
         c: torch.Tensor,
         y: torch.Tensor,
+        fastercache_counter: Optional[int] = 0,
+        fastercache_start_step: Optional[int] = 15,
         **attn_kwargs,
     ):
         """Forward pass of a block.
@@ -453,15 +458,36 @@ class AsymmetricJointBlock(nn.Module):
             scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
         else:
             scale_msa_y = mod_y

-        # Self-attention block.
-        x_attn, y_attn = self.attn(
-            x,
-            y,
-            scale_x=scale_msa_x,
-            scale_y=scale_msa_y,
-            **attn_kwargs,
-        )
+        # fastercache: reuse cached attention on cache-hit steps, otherwise recompute and update the cache.
+        B = x.shape[0]
+        # x: (1, 9540, 3072)
+        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter % 3 != 0 and self.cached_x_attention[-1].shape[0] >= B:
+            x_attn = (
+                self.cached_x_attention[1][:B] +
+                (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
+                * 0.3
+            ).to(x.device, non_blocking=True)
+            y_attn = (
+                self.cached_y_attention[1][:B] +
+                (self.cached_y_attention[1][:B] - self.cached_y_attention[0][:B])
+                * 0.3
+            ).to(x.device, non_blocking=True)
+        else:
+            # Self-attention block.
+            x_attn, y_attn = self.attn(
+                x,
+                y,
+                scale_x=scale_msa_x,
+                scale_y=scale_msa_y,
+                **attn_kwargs,
+            )
+            if fastercache_counter == fastercache_start_step:
+                self.cached_x_attention = [x_attn, x_attn]
+                self.cached_y_attention = [y_attn, y_attn]
+            elif fastercache_counter > fastercache_start_step:
+                self.cached_x_attention[-1].copy_(x_attn)
+                self.cached_y_attention[-1].copy_(y_attn)

         assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
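Note (not part of the commit): the cached branch above is a linear extrapolation of the two stored attention outputs, cached[1] + 0.3 * (cached[1] - cached[0]). A minimal standalone sketch, with made-up tensor sizes and a hypothetical helper name:

import torch

def extrapolate_cached(prev, last, momentum=0.3):
    # Mirrors cached_x_attention[1] + 0.3 * (cached_x_attention[1] - cached_x_attention[0])
    return last + momentum * (last - prev)

prev = torch.randn(1, 64, 128)               # attention output from two steps ago
last = prev + 0.01 * torch.randn_like(prev)  # attention output from the previous step
approx = extrapolate_cached(prev, last)      # stand-in for the skipped step
print(approx.shape)                          # torch.Size([1, 64, 128])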
@@ -674,6 +700,8 @@ class AsymmDiTJoint(nn.Module):
         self,
         x: torch.Tensor,
         sigma: torch.Tensor,
+        fastercache_counter: int,
+        fastercache_start_step: int,
         y_feat: List[torch.Tensor],
         y_mask: List[torch.Tensor],
         packed_indices: Dict[str, torch.Tensor] = None,
@@ -707,7 +735,9 @@ class AsymmDiTJoint(nn.Module):
                 rope_cos=rope_cos,
                 rope_sin=rope_sin,
                 packed_indices=packed_indices,
-            )  # (B, M, D), (B, L, D)
+                fastercache_counter=fastercache_counter,
+                fastercache_start_step=fastercache_start_step,
+            )  # (B, M, D), (B, L, D)
         del y_feat  # Final layers don't use dense text features.

         x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
@@ -720,6 +750,6 @@ class AsymmDiTJoint(nn.Module):
                 p1=self.patch_size,
                 p2=self.patch_size,
                 c=self.out_channels,
             )
         )

         return x
@@ -59,6 +59,22 @@ log = logging.getLogger(__name__)

 MAX_T5_TOKEN_LENGTH = 256

+def fft(tensor):
+    tensor_fft = torch.fft.fft2(tensor)
+    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
+    B, C, H, W = tensor.size()
+    radius = min(H, W) // 5
+
+    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
+    center_x, center_y = W // 2, H // 2
+    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
+    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
+    high_freq_mask = ~low_freq_mask
+
+    low_freq_fft = tensor_fft_shifted * low_freq_mask
+    high_freq_fft = tensor_fft_shifted * high_freq_mask
+
+    return low_freq_fft, high_freq_fft
 def unnormalize_latents(
     z: torch.Tensor,
     mean: torch.Tensor,
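Aside (not part of the commit): the fft helper added above splits the shifted 2D spectrum with complementary circular masks, so the two parts sum back to the full spectrum. A small sanity-check sketch, assuming fft from this file is in scope:

import torch

x = torch.randn(2, 4, 30, 53)   # (B, C, H, W), arbitrary small sizes
lf, hf = fft(x)                 # complex low/high frequency parts, same shape as x
recovered = torch.fft.ifft2(torch.fft.ifftshift(lf + hf)).real
print(torch.allclose(recovered, x, atol=1e-4))  # True: the masks partition the spectrum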
@@ -285,42 +301,84 @@ class T2VSynthMochiModel:
         sample_null["packed_indices"] = self.get_packed_indices(
             sample_null["y_mask"], **latent_dims
         )
+        self.use_fastercache = True
+        self.fastercache_counter = 0
+        self.fastercache_start_step = 15
+        self.fastercache_lf_step = 40
+        self.fastercache_hf_step = 30

-        def model_fn(*, z, sigma, cfg_scale):
-            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                if cfg_scale > 1.0:
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
-                else:
-                    out_cond = self.dit(z, sigma, **sample)
-                    return out_cond
-            return out_uncond + cfg_scale * (out_cond - out_uncond)
+        self.dit.to(self.device)
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+            autocast_dtype = torch.float16
+        else:
+            autocast_dtype = torch.bfloat16
+
+        def model_fn(*, z, sigma, cfg_scale):
+            nonlocal sample, sample_null
+            if self.use_fastercache:
+                self.fastercache_counter += 1
+            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
+                # Cache hit: run only the conditional pass and rebuild the unconditional
+                # output from the cached low/high-frequency deltas.
+                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample)
+
+                (bb, cc, tt, hh, ww) = out_cond.shape
+                cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                lf_c, hf_c = fft(cond.float())
+                #lf_step = 40
+                #hf_step = 30
+                if self.fastercache_counter <= self.fastercache_lf_step:
+                    self.delta_lf = self.delta_lf * 1.1
+                if self.fastercache_counter >= self.fastercache_hf_step:
+                    self.delta_hf = self.delta_hf * 1.1
+
+                new_hf_uc = self.delta_hf + hf_c
+                new_lf_uc = self.delta_lf + lf_c
+
+                combine_uc = new_lf_uc + new_hf_uc
+                combined_fft = torch.fft.ifftshift(combine_uc)
+                recovered_uncond = torch.fft.ifft2(combined_fft).real
+                recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
+            else:
+                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample)
+                out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step, **sample_null)
+                # out_cond: (1, 12, 3, 60, 106)
+
+                if self.fastercache_counter >= self.fastercache_start_step + 1:
+                    (bb, cc, tt, hh, ww) = out_cond.shape
+                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+
+                    lf_c, hf_c = fft(cond)
+                    lf_uc, hf_uc = fft(uncond)
+
+                    self.delta_lf = lf_uc - lf_c
+                    self.delta_hf = hf_uc - hf_c
+
+                return out_uncond + cfg_scale * (out_cond - out_uncond)

         comfy_pbar = ProgressBar(sample_steps)
-        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
-            sigma = sigma_schedule[i]
-            dsigma = sigma - sigma_schedule[i + 1]
-
-            # `pred` estimates `z_0 - eps`.
-            pred = model_fn(
-                z=z,
-                sigma=torch.full([B], sigma, device=z.device),
-                cfg_scale=cfg_schedule[i],
-            )
-            pred = pred.to(z)
-            z = z + dsigma * pred
-            if callback is not None:
-                callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
-            else:
-                comfy_pbar.update(1)
-        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-            autocast_dtype = torch.float16
-        else:
-            autocast_dtype = torch.bfloat16
-
-        self.dit.to(self.device)
+        with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+            for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
+                sigma = sigma_schedule[i]
+                dsigma = sigma - sigma_schedule[i + 1]
+
+                # `pred` estimates `z_0 - eps`.
+                pred = model_fn(
+                    z=z,
+                    sigma=torch.full([B], sigma, device=z.device),
+                    cfg_scale=cfg_schedule[i],
+                )
+                pred = pred.to(z)
+                z = z + dsigma * pred
+                if callback is not None:
+                    callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
+                else:
+                    comfy_pbar.update(1)

         self.dit.to(self.offload_device)
         logging.info(f"samples shape: {z.shape}")
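Aside (not part of the commit): with the constants used in model_fn above (fastercache_start_step = 15, reuse when counter >= start + 3 and counter % 5 != 0), the sampler runs both DiT passes early on and then only every fifth step. A small illustration of that schedule:

def runs_full_cfg(counter, start_step=15):
    # model_fn skips the unconditional pass once counter passes start_step + 3,
    # except every 5th step, which refreshes the cached low/high-frequency deltas.
    return not (counter >= start_step + 3 and counter % 5 != 0)

full_steps = [c for c in range(1, 31) if runs_full_cfg(c)]
print(full_steps)  # [1, 2, ..., 17, 20, 25, 30]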