fix
This commit is contained in:
parent
70ad32621b
commit
fbd2252dc4
@ -643,7 +643,7 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
y_mask: List[torch.Tensor],
|
y_mask: List[torch.Tensor],
|
||||||
rope_cos: torch.Tensor = None,
|
rope_cos: torch.Tensor = None,
|
||||||
rope_sin: torch.Tensor = None,
|
rope_sin: torch.Tensor = None,
|
||||||
fastercache: Optional[Dict[str, torch.Tensor]] = None,
|
fastercache: Optional[Dict] = None,
|
||||||
fastercache_counter: Optional[int]=0,
|
fastercache_counter: Optional[int]=0,
|
||||||
):
|
):
|
||||||
"""Forward pass of DiT.
|
"""Forward pass of DiT.
|
||||||
|
|||||||
@ -224,12 +224,10 @@ class T2VSynthMochiModel:
|
|||||||
sample = {
|
sample = {
|
||||||
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
|
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
|
||||||
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
|
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
|
||||||
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
|
|
||||||
}
|
}
|
||||||
sample_null = {
|
sample_null = {
|
||||||
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
|
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
|
||||||
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
|
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
|
||||||
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
|
|
||||||
}
|
}
|
||||||
print(args["fastercache"])
|
print(args["fastercache"])
|
||||||
if args["fastercache"]:
|
if args["fastercache"]:
|
||||||
@ -251,6 +249,7 @@ class T2VSynthMochiModel:
|
|||||||
z,
|
z,
|
||||||
sigma,
|
sigma,
|
||||||
**sample,
|
**sample,
|
||||||
|
fastercache = args["fastercache"],
|
||||||
fastercache_counter=self.fastercache_counter)
|
fastercache_counter=self.fastercache_counter)
|
||||||
|
|
||||||
(bb, cc, tt, hh, ww) = out_cond.shape
|
(bb, cc, tt, hh, ww) = out_cond.shape
|
||||||
@ -275,12 +274,14 @@ class T2VSynthMochiModel:
|
|||||||
z,
|
z,
|
||||||
sigma,
|
sigma,
|
||||||
**sample,
|
**sample,
|
||||||
|
fastercache = args["fastercache"],
|
||||||
fastercache_counter=self.fastercache_counter)
|
fastercache_counter=self.fastercache_counter)
|
||||||
|
|
||||||
out_uncond = self.dit(
|
out_uncond = self.dit(
|
||||||
z,
|
z,
|
||||||
sigma,
|
sigma,
|
||||||
**sample_null,
|
**sample_null,
|
||||||
|
fastercache = args["fastercache"],
|
||||||
fastercache_counter=self.fastercache_counter)
|
fastercache_counter=self.fastercache_counter)
|
||||||
|
|
||||||
if self.fastercache_counter >= self.fastercache_start_step + 1:
|
if self.fastercache_counter >= self.fastercache_start_step + 1:
|
||||||
@ -300,6 +301,7 @@ class T2VSynthMochiModel:
|
|||||||
z,
|
z,
|
||||||
sigma,
|
sigma,
|
||||||
**sample,
|
**sample,
|
||||||
|
fastercache = args["fastercache"],
|
||||||
fastercache_counter=self.fastercache_counter)
|
fastercache_counter=self.fastercache_counter)
|
||||||
return out_cond
|
return out_cond
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user