restore MochiEdit compatiblity
temporary
This commit is contained in:
parent
0dc011d1b6
commit
4a7458ffd6
@ -584,6 +584,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
y_mask: List[torch.Tensor],
|
||||
rope_cos: torch.Tensor = None,
|
||||
rope_sin: torch.Tensor = None,
|
||||
packed_indices: Optional[dict] = None,
|
||||
):
|
||||
"""Forward pass of DiT.
|
||||
|
||||
|
||||
@ -176,6 +176,10 @@ class T2VSynthMochiModel:
|
||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
||||
|
||||
def get_packed_indices(self, y_mask, **latent_dims):
|
||||
# temporary dummy func for compatibility
|
||||
return []
|
||||
|
||||
def move_to_device_(self, sample):
|
||||
if isinstance(sample, dict):
|
||||
for key in sample.keys():
|
||||
|
||||
@ -6,8 +6,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
#from ..dit.joint_model.context_parallel import get_cp_rank_size
|
||||
#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||
from .latent_dist import LatentDistribution
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user