This commit is contained in:
kijai 2024-11-05 20:21:47 +02:00
parent 70bd024456
commit 70ad32621b
2 changed files with 14 additions and 12 deletions

View File

@ -639,12 +639,12 @@ class AsymmDiTJoint(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
sigma: torch.Tensor, sigma: torch.Tensor,
fastercache_counter: int,
y_feat: List[torch.Tensor], y_feat: List[torch.Tensor],
y_mask: List[torch.Tensor], y_mask: List[torch.Tensor],
rope_cos: torch.Tensor = None, rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None, rope_sin: torch.Tensor = None,
fastercache: Optional[Dict[str, torch.Tensor]] = None, fastercache: Optional[Dict[str, torch.Tensor]] = None,
fastercache_counter: Optional[int]=0,
): ):
"""Forward pass of DiT. """Forward pass of DiT.
@ -671,6 +671,7 @@ class AsymmDiTJoint(nn.Module):
else: else:
fastercache_start_step = 1000 fastercache_start_step = 1000
fastercache_device = None fastercache_device = None
#print(fastercache_counter)
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
x, y_feat = block( x, y_feat = block(

View File

@ -231,8 +231,9 @@ class T2VSynthMochiModel:
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)], "y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
"fastercache": args["fastercache"] if args["fastercache"] is not None else None "fastercache": args["fastercache"] if args["fastercache"] is not None else None
} }
print(args["fastercache"])
if args["fastercache"]: if args["fastercache"]:
print("Using fastercache")
self.fastercache_start_step = args["fastercache"]["start_step"] self.fastercache_start_step = args["fastercache"]["start_step"]
self.fastercache_lf_step = args["fastercache"]["lf_step"] self.fastercache_lf_step = args["fastercache"]["lf_step"]
self.fastercache_hf_step = args["fastercache"]["hf_step"] self.fastercache_hf_step = args["fastercache"]["hf_step"]
@ -249,8 +250,8 @@ class T2VSynthMochiModel:
out_cond = self.dit( out_cond = self.dit(
z, z,
sigma, sigma,
self.fastercache_counter, **sample,
**sample) fastercache_counter=self.fastercache_counter)
(bb, cc, tt, hh, ww) = out_cond.shape (bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww) cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@ -273,14 +274,14 @@ class T2VSynthMochiModel:
out_cond = self.dit( out_cond = self.dit(
z, z,
sigma, sigma,
self.fastercache_counter, **sample,
**sample) fastercache_counter=self.fastercache_counter)
out_uncond = self.dit( out_uncond = self.dit(
z, z,
sigma, sigma,
self.fastercache_counter, **sample_null,
**sample_null) fastercache_counter=self.fastercache_counter)
if self.fastercache_counter >= self.fastercache_start_step + 1: if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, cc, tt, hh, ww) = out_cond.shape (bb, cc, tt, hh, ww) = out_cond.shape
@ -297,9 +298,9 @@ class T2VSynthMochiModel:
else: #handle cfg 1.0 else: #handle cfg 1.0
out_cond = self.dit( out_cond = self.dit(
z, z,
sigma, sigma,
self.fastercache_counter, **sample,
**sample) fastercache_counter=self.fastercache_counter)
return out_cond return out_cond