This commit is contained in:
kijai 2024-11-05 20:36:32 +02:00
parent 70ad32621b
commit fbd2252dc4
2 changed files with 5 additions and 3 deletions

View File

@ -643,7 +643,7 @@ class AsymmDiTJoint(nn.Module):
y_mask: List[torch.Tensor],
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
fastercache: Optional[Dict[str, torch.Tensor]] = None,
fastercache: Optional[Dict] = None,
fastercache_counter: Optional[int]=0,
):
"""Forward pass of DiT.

View File

@ -224,12 +224,10 @@ class T2VSynthMochiModel:
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
}
sample_null = {
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
}
print(args["fastercache"])
if args["fastercache"]:
@ -251,6 +249,7 @@ class T2VSynthMochiModel:
z,
sigma,
**sample,
fastercache = args["fastercache"],
fastercache_counter=self.fastercache_counter)
(bb, cc, tt, hh, ww) = out_cond.shape
@ -275,12 +274,14 @@ class T2VSynthMochiModel:
z,
sigma,
**sample,
fastercache = args["fastercache"],
fastercache_counter=self.fastercache_counter)
out_uncond = self.dit(
z,
sigma,
**sample_null,
fastercache = args["fastercache"],
fastercache_counter=self.fastercache_counter)
if self.fastercache_counter >= self.fastercache_start_step + 1:
@ -300,6 +301,7 @@ class T2VSynthMochiModel:
z,
sigma,
**sample,
fastercache = args["fastercache"],
fastercache_counter=self.fastercache_counter)
return out_cond