From 4a7458ffd6cb72e7c29ca4e45f4ac099453ed262 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 3 Nov 2024 19:18:29 +0200 Subject: [PATCH] restore MochiEdit compatiblity temporary --- mochi_preview/dit/joint_model/asymm_models_joint.py | 1 + mochi_preview/t2v_synth_mochi.py | 4 ++++ mochi_preview/vae/model.py | 2 -- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index 502a2df..9650f40 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -584,6 +584,7 @@ class AsymmDiTJoint(nn.Module): y_mask: List[torch.Tensor], rope_cos: torch.Tensor = None, rope_sin: torch.Tensor = None, + packed_indices: Optional[dict] = None, ): """Forward pass of DiT. diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index d6f2c60..ccec165 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -176,6 +176,10 @@ class T2VSynthMochiModel: self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device) self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device) + def get_packed_indices(self, y_mask, **latent_dims): + # temporary dummy func for compatibility + return [] + def move_to_device_(self, sample): if isinstance(sample, dict): for key in sample.keys(): diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py index 6cfeeae..56f937e 100644 --- a/mochi_preview/vae/model.py +++ b/mochi_preview/vae/model.py @@ -6,8 +6,6 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -#from ..dit.joint_model.context_parallel import get_cp_rank_size -#from ..vae.cp_conv import cp_pass_frames, gather_all_frames from .latent_dist import LatentDistribution def cast_tuple(t, length=1):