restore MochiEdit compatiblity

temporary
This commit is contained in:
kijai 2024-11-03 19:18:29 +02:00
parent 0dc011d1b6
commit 4a7458ffd6
3 changed files with 5 additions and 2 deletions

View File

@ -584,6 +584,7 @@ class AsymmDiTJoint(nn.Module):
y_mask: List[torch.Tensor],
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
packed_indices: Optional[dict] = None,
):
"""Forward pass of DiT.

View File

@ -176,6 +176,10 @@ class T2VSynthMochiModel:
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
def get_packed_indices(self, y_mask, **latent_dims):
# temporary dummy func for compatibility
return []
def move_to_device_(self, sample):
if isinstance(sample, dict):
for key in sample.keys():

View File

@ -6,8 +6,6 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
#from ..dit.joint_model.context_parallel import get_cp_rank_size
#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
from .latent_dist import LatentDistribution
def cast_tuple(t, length=1):