reorder
This commit is contained in:
parent
70bd024456
commit
70ad32621b
@ -639,12 +639,12 @@ class AsymmDiTJoint(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
fastercache_counter: int,
|
||||
y_feat: List[torch.Tensor],
|
||||
y_mask: List[torch.Tensor],
|
||||
rope_cos: torch.Tensor = None,
|
||||
rope_sin: torch.Tensor = None,
|
||||
fastercache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
fastercache_counter: Optional[int]=0,
|
||||
):
|
||||
"""Forward pass of DiT.
|
||||
|
||||
@ -671,6 +671,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
else:
|
||||
fastercache_start_step = 1000
|
||||
fastercache_device = None
|
||||
#print(fastercache_counter)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
x, y_feat = block(
|
||||
|
||||
@ -231,8 +231,9 @@ class T2VSynthMochiModel:
|
||||
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
|
||||
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
|
||||
}
|
||||
|
||||
print(args["fastercache"])
|
||||
if args["fastercache"]:
|
||||
print("Using fastercache")
|
||||
self.fastercache_start_step = args["fastercache"]["start_step"]
|
||||
self.fastercache_lf_step = args["fastercache"]["lf_step"]
|
||||
self.fastercache_hf_step = args["fastercache"]["hf_step"]
|
||||
@ -249,8 +250,8 @@ class T2VSynthMochiModel:
|
||||
out_cond = self.dit(
|
||||
z,
|
||||
sigma,
|
||||
self.fastercache_counter,
|
||||
**sample)
|
||||
**sample,
|
||||
fastercache_counter=self.fastercache_counter)
|
||||
|
||||
(bb, cc, tt, hh, ww) = out_cond.shape
|
||||
cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
|
||||
@ -273,14 +274,14 @@ class T2VSynthMochiModel:
|
||||
out_cond = self.dit(
|
||||
z,
|
||||
sigma,
|
||||
self.fastercache_counter,
|
||||
**sample)
|
||||
**sample,
|
||||
fastercache_counter=self.fastercache_counter)
|
||||
|
||||
out_uncond = self.dit(
|
||||
z,
|
||||
sigma,
|
||||
self.fastercache_counter,
|
||||
**sample_null)
|
||||
sigma,
|
||||
**sample_null,
|
||||
fastercache_counter=self.fastercache_counter)
|
||||
|
||||
if self.fastercache_counter >= self.fastercache_start_step + 1:
|
||||
(bb, cc, tt, hh, ww) = out_cond.shape
|
||||
@ -297,9 +298,9 @@ class T2VSynthMochiModel:
|
||||
else: #handle cfg 1.0
|
||||
out_cond = self.dit(
|
||||
z,
|
||||
sigma,
|
||||
self.fastercache_counter,
|
||||
**sample)
|
||||
sigma,
|
||||
**sample,
|
||||
fastercache_counter=self.fastercache_counter)
|
||||
return out_cond
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user