From b80cb4a691144646e34ae2916815a839d1b2a759 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 23 Oct 2024 15:34:22 +0300
Subject: [PATCH] initial commit

---
 __init__.py                                   |   3 +
 __pycache__/__init__.cpython-312.pyc          | Bin 0 -> 283 bytes
 __pycache__/nodes.cpython-312.pyc             | Bin 0 -> 16678 bytes
 configs/vae_stats.json                        |   4 +
 infer.py                                      | 213 +++
 mochi_preview/__init__.py                     |   0
 .../__pycache__/__init__.cpython-311.pyc      | Bin 0 -> 158 bytes
 .../__pycache__/__init__.cpython-312.pyc      | Bin 0 -> 142 bytes
 .../t2v_synth_mochi.cpython-311.pyc           | Bin 0 -> 28320 bytes
 .../t2v_synth_mochi.cpython-312.pyc           | Bin 0 -> 18397 bytes
 .../__pycache__/utils.cpython-311.pyc         | Bin 0 -> 2837 bytes
 .../__pycache__/utils.cpython-312.pyc         | Bin 0 -> 2567 bytes
 mochi_preview/dit/joint_model/__init__.py     |   0
 .../__pycache__/__init__.cpython-311.pyc      | Bin 0 -> 174 bytes
 .../__pycache__/__init__.cpython-312.pyc      | Bin 0 -> 144 bytes
 .../asymm_models_joint.cpython-311.pyc        | Bin 0 -> 26766 bytes
 .../asymm_models_joint.cpython-312.pyc        | Bin 0 -> 24501 bytes
 .../context_parallel.cpython-311.pyc          | Bin 0 -> 8755 bytes
 .../context_parallel.cpython-312.pyc          | Bin 0 -> 8103 bytes
 .../__pycache__/layers.cpython-311.pyc        | Bin 0 -> 11080 bytes
 .../__pycache__/layers.cpython-312.pyc        | Bin 0 -> 9941 bytes
 .../__pycache__/mod_rmsnorm.cpython-311.pyc   | Bin 0 -> 1542 bytes
 .../__pycache__/mod_rmsnorm.cpython-312.pyc   | Bin 0 -> 1371 bytes
 ...esidual_tanh_gated_rmsnorm.cpython-311.pyc | Bin 0 -> 1610 bytes
 ...esidual_tanh_gated_rmsnorm.cpython-312.pyc | Bin 0 -> 1449 bytes
 .../__pycache__/rope_mixed.cpython-311.pyc    | Bin 0 -> 3917 bytes
 .../__pycache__/rope_mixed.cpython-312.pyc    | Bin 0 -> 3643 bytes
 .../__pycache__/temporal_rope.cpython-311.pyc | Bin 0 -> 1742 bytes
 .../__pycache__/temporal_rope.cpython-312.pyc | Bin 0 -> 1740 bytes
 .../__pycache__/utils.cpython-311.pyc         | Bin 0 -> 10014 bytes
 .../__pycache__/utils.cpython-312.pyc         | Bin 0 -> 9399 bytes
 .../dit/joint_model/asymm_models_joint.py     | 675 +++++++++++++++
 .../dit/joint_model/context_parallel.py       | 163 ++++
 mochi_preview/dit/joint_model/layers.py       | 178 ++++
 mochi_preview/dit/joint_model/mod_rmsnorm.py  |  23 +
 .../residual_tanh_gated_rmsnorm.py            |  27 +
 mochi_preview/dit/joint_model/rope_mixed.py   |  88 ++
 .../dit/joint_model/temporal_rope.py          |  34 +
 mochi_preview/dit/joint_model/utils.py        | 189 ++++
 mochi_preview/t2v_synth_mochi.py              | 445 ++++++++++
 mochi_preview/utils.py                        |  33 +
 mochi_preview/vae/__init__.py                 |   0
 .../vae/__pycache__/__init__.cpython-311.pyc  | Bin 0 -> 162 bytes
 .../vae/__pycache__/__init__.cpython-312.pyc  | Bin 0 -> 132 bytes
 .../vae/__pycache__/cp_conv.cpython-311.pyc   | Bin 0 -> 8457 bytes
 .../vae/__pycache__/cp_conv.cpython-312.pyc   | Bin 0 -> 7287 bytes
 .../vae/__pycache__/model.cpython-311.pyc     | Bin 0 -> 38113 bytes
 .../vae/__pycache__/model.cpython-312.pyc     | Bin 0 -> 34153 bytes
 mochi_preview/vae/cp_conv.py                  | 152 ++++
 mochi_preview/vae/model.py                    | 815 ++++++++++++++++++
 nodes.py                                      | 356 ++++
 readme.md                                     |  18 +
 52 files changed, 3416 insertions(+)
 create mode 100644 __init__.py
 create mode 100644 __pycache__/__init__.cpython-312.pyc
 create mode 100644 __pycache__/nodes.cpython-312.pyc
 create mode 100644 configs/vae_stats.json
 create mode 100644 infer.py
 create mode 100644 mochi_preview/__init__.py
 create mode 100644 mochi_preview/__pycache__/__init__.cpython-311.pyc
 create mode 100644 mochi_preview/__pycache__/__init__.cpython-312.pyc
 create mode 100644 mochi_preview/__pycache__/t2v_synth_mochi.cpython-311.pyc
 create mode 100644 mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc
 create mode 100644 mochi_preview/__pycache__/utils.cpython-311.pyc
 create mode 100644 mochi_preview/__pycache__/utils.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__init__.py
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/__init__.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/__init__.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/layers.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/layers.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/utils.cpython-311.pyc
 create mode 100644 mochi_preview/dit/joint_model/__pycache__/utils.cpython-312.pyc
 create mode 100644 mochi_preview/dit/joint_model/asymm_models_joint.py
 create mode 100644 mochi_preview/dit/joint_model/context_parallel.py
 create mode 100644 mochi_preview/dit/joint_model/layers.py
 create mode 100644 mochi_preview/dit/joint_model/mod_rmsnorm.py
 create mode 100644 mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py
 create mode 100644 mochi_preview/dit/joint_model/rope_mixed.py
 create mode 100644 mochi_preview/dit/joint_model/temporal_rope.py
 create mode 100644 mochi_preview/dit/joint_model/utils.py
 create mode 100644 mochi_preview/t2v_synth_mochi.py
 create mode 100644 mochi_preview/utils.py
 create mode 100644 mochi_preview/vae/__init__.py
 create mode 100644 mochi_preview/vae/__pycache__/__init__.cpython-311.pyc
 create mode 100644 mochi_preview/vae/__pycache__/__init__.cpython-312.pyc
 create mode 100644 mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc
 create mode 100644 mochi_preview/vae/__pycache__/cp_conv.cpython-312.pyc
 create mode 100644 mochi_preview/vae/__pycache__/model.cpython-311.pyc
 create mode 100644 mochi_preview/vae/__pycache__/model.cpython-312.pyc
 create mode 100644 mochi_preview/vae/cp_conv.py
 create mode 100644 mochi_preview/vae/model.py
 create mode 100644 nodes.py
 create mode 100644 readme.md

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..2e96bd6
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,3 @@
+from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
\ No newline at end of file
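This top-level __init__.py follows the usual ComfyUI custom-node convention: ComfyUI imports NODE_CLASS_MAPPINGS (internal node name to node class) and NODE_DISPLAY_NAME_MAPPINGS (internal name to UI label) from the package root to register the nodes defined in nodes.py. A minimal sketch of the shape nodes.py is expected to export; the class and names below are placeholders, not the actual contents of the nodes.py added later in this patch:

class ExampleNode:
    # stand-in for a real ComfyUI node class defined in nodes.py
    pass

NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}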
diff --git a/__pycache__/__init__.cpython-312.pyc b/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6ee936528336fc0284c62adee6dde7b73bc6115
GIT binary patch
literal 283
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/__pycache__/nodes.cpython-312.pyc b/__pycache__/nodes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fba7b548ef4e2e0113cf63bbfc8c480330a35798
GIT binary patch
literal 16678
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/configs/vae_stats.json b/configs/vae_stats.json
new file mode 100644
index 0000000..e3278af
--- /dev/null
+++ b/configs/vae_stats.json
@@ -0,0 +1,4 @@
+{
+    "mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
+    "std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
+}
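These are per-channel statistics of the 12-channel Mochi VAE latent space. A minimal sketch of how per-channel stats like these are typically applied, normalizing raw VAE latents toward zero mean and unit variance for the diffusion model and inverting the mapping before decode; the (batch, channel, time, height, width) layout and the function names are assumptions for illustration, not code from this commit:

import json

import torch

with open("configs/vae_stats.json") as f:
    stats = json.load(f)

# reshape so the 12 per-channel values broadcast over (B, C, T, H, W) latents
mean = torch.tensor(stats["mean"]).view(1, -1, 1, 1, 1)
std = torch.tensor(stats["std"]).view(1, -1, 1, 1, 1)

def normalize_latents(z: torch.Tensor) -> torch.Tensor:
    # raw VAE latents -> the normalized space the diffusion model samples in
    return (z - mean) / std

def denormalize_latents(z: torch.Tensor) -> torch.Tensor:
    # inverse mapping, applied before handing latents to the VAE decoder
    return z * std + mean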
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..da83061
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,213 @@
+import json
+import os
+import tempfile
+import time
+
+import click
+import numpy as np
+#import ray
+from einops import rearrange
+from PIL import Image
+from tqdm import tqdm
+
+from mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
+
+model = None
+model_path = "weights"
+def noexcept(f):
+    try:
+        return f()
+    except:
+        pass
+# class MochiWrapper:
+#     def __init__(self, *, num_workers, **actor_kwargs):
+#         super().__init__()
+#         RemoteClass = ray.remote(T2VSynthMochiModel)
+#         self.workers = [
+#             RemoteClass.options(num_gpus=1).remote(
+#                 device_id=0, world_size=num_workers, local_rank=i, **actor_kwargs
+#             )
+#             for i in range(num_workers)
+#         ]
+#         # Ensure the __init__ method has finished on all workers
+#         for worker in self.workers:
+#             ray.get(worker.__ray_ready__.remote())
+#         self.is_loaded = True

+#     def __call__(self, args):
+#         work_refs = [
+#             worker.run.remote(args, i == 0) for i, worker in enumerate(self.workers)
+#         ]

+#         try:
+#             for result in work_refs[0]:
+#                 yield ray.get(result)

+#             # Handle the (very unlikely) edge-case where a worker that's not the 1st one
+#             # fails (don't want an uncaught error)
+#             for result in work_refs[1:]:
+#                 ray.get(result)
+#         except Exception as e:
+#             # Get exception from other workers
+#             for ref in work_refs[1:]:
+#                 noexcept(lambda: ray.get(ref))
+#             raise e
+
+def set_model_path(path):
+    global model_path
+    model_path = path
+
+
+def load_model():
+    global model, model_path
+    if model is None:
+        #ray.init()
+        MOCHI_DIR = model_path
+        VAE_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_vae_bf16.safetensors"
+        MODEL_CONFIG_PATH = f"{MOCHI_DIR}/dit-config.yaml"
+        MODEL_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_dit_fp8_e4m3fn.safetensors"
+
+        model = T2VSynthMochiModel(
+            device_id=0,
+            world_size=1,
+            local_rank=0,
+            vae_stats_path=f"{MOCHI_DIR}/vae_stats.json",
+            vae_checkpoint_path=VAE_CHECKPOINT_PATH,
+            dit_config_path=MODEL_CONFIG_PATH,
+            dit_checkpoint_path=MODEL_CHECKPOINT_PATH,
+        )
+
+def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
+    if linear_steps is None:
+        linear_steps = num_steps // 2
+    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
+    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
+    quadratic_steps = num_steps - linear_steps
+    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
+    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
+    const = quadratic_coef * (linear_steps ** 2)
+    quadratic_sigma_schedule = [
+        quadratic_coef * (i ** 2) + linear_coef * i + const
+        for i in range(linear_steps, num_steps)
+    ]
+    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
+    sigma_schedule = [1.0 - x for x in sigma_schedule]
+    return sigma_schedule
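+
+# Illustrative example (comment added for clarity, not in the original file):
+# linear_quadratic_schedule(8, 0.025) returns num_steps + 1 = 9 sigmas,
+#   [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
+# The schedule starts at 1.0, ends at 0.0, and decreases strictly: small
+# linear steps early in sampling, larger quadratic steps late.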
+
+def generate_video(
+    prompt,
+    negative_prompt,
+    width,
+    height,
+    num_frames,
+    seed,
+    cfg_scale,
+    num_inference_steps,
+):
+    load_model()
+
+    # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
+    # such that sigma_schedule[0] == 1.0, sigma_schedule[-1] == 0.0, and the
+    # values decrease monotonically.
+    sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)
+
+    # cfg_schedule should be a list of floats of length num_inference_steps.
+    # For simplicity, we just use the same cfg scale at all timesteps, but a
+    # better schedule may vary the cfg, e.g.:
+    # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
+    cfg_schedule = [cfg_scale] * num_inference_steps
+
+    args = {
+        "height": height,
+        "width": width,
+        "num_frames": num_frames,
+        "mochi_args": {
+            "sigma_schedule": sigma_schedule,
+            "cfg_schedule": cfg_schedule,
+            "num_inference_steps": num_inference_steps,
+            "batch_cfg": False,
+        },
+        "prompt": [prompt],
+        "negative_prompt": [negative_prompt],
+        "seed": seed,
+    }
+
+    final_frames = None
+    for cur_progress, frames, finished in tqdm(model.run(args, stream_results=True), total=num_inference_steps + 1):
+        final_frames = frames
+
+    assert isinstance(final_frames, np.ndarray)
+    assert final_frames.dtype == np.float32
+
+    final_frames = rearrange(final_frames, "t b h w c -> b t h w c")
+    final_frames = final_frames[0]
+
+    os.makedirs("outputs", exist_ok=True)
+    output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4")
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        frame_paths = []
+        for i, frame in enumerate(final_frames):
+            frame = (frame * 255).astype(np.uint8)
+            frame_img = Image.fromarray(frame)
+            frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png")
+            frame_img.save(frame_path)
+            frame_paths.append(frame_path)
+
+        frame_pattern = os.path.join(tmpdir, "frame_%04d.png")
+        ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}"
+        os.system(ffmpeg_cmd)
+
+    json_path = os.path.splitext(output_path)[0] + ".json"
+    with open(json_path, "w") as f:
+        json.dump(args, f, indent=4)
+
+    return output_path
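+
+# Note on the encoding step above (comment added for clarity, not in the
+# original file): os.system assumes an ffmpeg binary on PATH, and the
+# "-r 30" flag hardcodes the output frame rate to 30 fps.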
+
+
+@click.command()
+@click.option("--prompt", default="""
+    a high-motion drone POV flying at high speed through a vast desert environment, with dynamic camera movements capturing sweeping sand dunes,
+    rocky terrain, and the occasional dry brush. The camera smoothly glides over the rugged landscape, weaving between towering rock formations and
+    diving low across the sand. As the drone zooms forward, the motion gradually slows down, shifting into a close-up, hyper-detailed shot of a spider
+    resting on a sunlit rock. The scene emphasizes cinematic motion, natural lighting, and intricate texture details on both the rock and the spider’s body,
+    with a shallow depth of field to focus on the fine details of the spider’s legs and the rough surface beneath it. The atmosphere should feel immersive and alive,
+    with the wind subtly blowing sand grains across the frame."""
+    , required=False, help="Prompt for video generation.")
+@click.option(
+    "--negative_prompt", default="", help="Negative prompt for video generation."
+)
+@click.option("--width", default=848, type=int, help="Width of the video.")
+@click.option("--height", default=480, type=int, help="Height of the video.")
+@click.option("--num_frames", default=163, type=int, help="Number of frames.")
+@click.option("--seed", default=12345, type=int, help="Random seed.")
+@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.")
+@click.option(
+    "--num_steps", default=64, type=int, help="Number of inference steps."
+)
+@click.option("--model_dir", required=True, help="Path to the model directory.")
+def generate_cli(
+    prompt,
+    negative_prompt,
+    width,
+    height,
+    num_frames,
+    seed,
+    cfg_scale,
+    num_steps,
+    model_dir,
+):
+    set_model_path(model_dir)
+    output = generate_video(
+        prompt,
+        negative_prompt,
+        width,
+        height,
+        num_frames,
+        seed,
+        cfg_scale,
+        num_steps,
+    )
+    click.echo(f"Video generated at: {output}")
+
+if __name__ == "__main__":
+    generate_cli()
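That completes infer.py: it can be driven through the click entry point (python infer.py --model_dir weights) or programmatically. A minimal sketch of the programmatic path, assuming the model directory contains the four files load_model() looks for (vae_stats.json, dit-config.yaml, and the two .safetensors checkpoints) and that an ffmpeg binary is on PATH; the prompt string is just an example:

from infer import set_model_path, generate_video

set_model_path("weights")  # directory holding the checkpoints and configs
output_path = generate_video(
    prompt="a drone shot sweeping over desert dunes",
    negative_prompt="",
    width=848,
    height=480,
    num_frames=163,
    seed=12345,
    cfg_scale=4.5,
    num_inference_steps=64,
)
print(f"Video written to {output_path}")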
diff --git a/mochi_preview/__init__.py b/mochi_preview/__init__.py
new file mode 100644
index 0000000..e69de29

diff --git a/mochi_preview/__pycache__/__init__.cpython-311.pyc b/mochi_preview/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8e3adf059de0c622a3d879574df7b40da568dfa
GIT binary patch
literal 158
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/__pycache__/__init__.cpython-312.pyc b/mochi_preview/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0838524a56e259215e6c8b36bc41c81f94614141
GIT binary patch
literal 142
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-311.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d208bdd05d59ac1a4f074904dd395159de051550
GIT binary patch
literal 28320
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b28b7e858f447c64926bb0f6758ff4d8d667955
GIT binary patch
literal 18397
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/__pycache__/utils.cpython-311.pyc b/mochi_preview/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d3564b3f52bf5a00c69aa88e8513f44177bd12f
GIT binary patch
literal 2837
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/__pycache__/utils.cpython-312.pyc b/mochi_preview/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4017fe18fbe017806c32a24ae8809dc47a7926f7
GIT binary patch
literal 2567
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/dit/joint_model/__init__.py b/mochi_preview/dit/joint_model/__init__.py
new file mode 100644
index 0000000..e69de29

diff --git a/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4270e08b11257c258f32f355a0bca68d3bd23fe8
GIT binary patch
literal 174
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-312.pyc
new file mode 100644
GIT binary patch
literal 144
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15bbc006877b610431cc770492989c2b0b209b17
GIT binary patch
literal 26766
[base85 payload omitted]
literal 0
HcmV?d00001

diff --git a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63e495ac1902a884b27ce4adea0d4eb51607406f
GIT binary patch
literal 24501
[base85 payload omitted]
z;Tq_ZxoKLb$p7fE{ADGjl%P>8VU%8FvdVYCq`b+L4OSoc`+u4j)gyrB!PnptF-)`xgnL{Qbq;>5X5e(^%zcQ!cM7BjuS1sl}CvZ-)|ak zmb$_@7u--wxnCtl1f}wYr#w*SRvM$alQK3oo?48@pRuuLdn&y(N&?zvLvFij2VQkw z^-N+AP$6?eDhr^D3go2f0zBmd$?47bt~#%}28K1|p` z@fA&^7>SC5lyqZ9@Ox3_C?lu>FA^isf+1IWYdZkBnL4|o41Ci}> zzr;oj>UDrwcDluQY*dtwX(Ao8OiA-t`;q}_>7=YiB~ZiONHQ}6(kHfw#DIyvpx^-o zjKqFOv7aLd`^#8Wgm(`o>M0;|eu?js_#+bkqG+QmI|b*d-76H}WHpUHqoj{T((q*= zI-E_`@1hQ=Jjt?%VPC@Jt1c!_gtLtNfmfsNQg$1m@6RCsjp_l7nvKnL&+VMudB;+J z52ln~oc`in=cc@K)55WzI(Oc$u9-VOdw%}hd~9ZT`f#=r?P^xm&$4U%hcrkL*@g-qDmjGBfz7hQ`yhay*BB z>fCV zc+uTq*4^p({qvV{iMK{?j4rL;u{gTqK13ZI=KsvoPH(*vhP_VDOnCZK7WJ#<8fP2l zFE4vHWKB%@JM-E3&*iQxxwkA?wmkYgk6Fq)cYgE;Urvr-%AU+Sw}12q6S8y3x#OeM zHZ0ZO?OA_}*Z)Y@^XxH~{tr8UA%0==VMmXL<$gG`|63zWA5g~9HC0A__1XQ108X}?s zw?CwPPN_vD60nk*PC@A<;l}0Ho(Yqc}I4>gxy(?5M709*~<5A8msay{r3rN$n<`aGF zi9Yp-uu1*YbamIWUf zh2}Vv!eNyHTyjEmMjbnma;0sWjJ$V+Nbc1^4rp*Clk-RfsO_(GjNK}ivaS%OBy{G$ z=T4N=7b7ufJw!N#o!GN;r?#&>+qQ0hYURLXxT<}4I7EIi@hJSqWcN4boQQFQ2)sO) zhz*vU;6gYKAv6NQt8^sa%6@-!*S4p~=gMBnq6&6KO)n!bCu=_| zyXN6s0TM38L$X>V)G-+f!woH!{@7_{IBtggZc#%Q?t;|y^H_`WX%Sh9k&IPwfIb;| zUUq=)I2<||4m}Uyh=^i|e}1xSU#O!mG&mNzJQfahjCF^`wuNG0xppkn7wSOKST~9; zQ+fo^Ii+*KD*F!<0XhEw}SBNWb|zN zd`IrcBR%IjLFaZj{qL-Qk?6nLYtB2@x}VhP&T&G=R{uev!ZLT%U`+b^gZr zh3Ji;+}Wk>(@QVBv~;?EssHSf{R^PCP^W;(Fe38Loc=XN3p0=E`NO&Bwc`ukWzQxE zNBuSIXraS-_#9t+*S~eiv31G6wNTe|z3W=n!r^?~6WQa1n&!EyvsZH$7d9*&&u@J8 zcKveA$?TCevfrz1fg|^(8=L;LwrlzbUU>Y^;Bn^hvj1w5v*N#b^FyaO=$by7?U}v= zfpTyQwDnH^eTV0+qap8Tz_~>HvZEC*t|8%J16MtB@vA4vnYwBwx)khQeDSvNj^hx^ zV1|s(kpCYn2oKDinmv_w<0v9++)t82@;Kve<%)phjm@)Udl za*W4+eRH9?wa~I2zVU^PU4=H-`mZCqwe?@*VGmKwRR@bTELzLeH5P*`TF2El7wcKH zfn4?)S+og9=H;rcxr@y#8RD8+i!ChL3h%b!Iu>nPVPGLYM0Ejyh|QPRme&mxJ|@4+BnLB7#I0C3I>)*K+Mjhz`1E;rI%uEJ0hv* z*uj>vZq5$N2IXXI*$Os5kCip}?EGt49s#P9^mfp~jpA`Q=^967kTOd$j;9kaFRP_) ziSv+8#i*wuTlw^A`KXRG<=SyX$EcV-NC$9A{{5#;K7aJY!N{Sby*(!)Cwg8w_(D1J zs5cq91g?iF`q@8F+H{RBZ# z&-47BbJoA)x|X@Fzv3K!!Bzbw*Sy3v{{`3Z*WAXRaZ^9#rXJ~xyzT*q;O7pG-}G~x zj@Lah>5;T>{PuEb(PiM<{=%@k(A2tQsK)}XX;|X?h3a~Q4^2ruADC~w9=;ZS$Wi>^ zm(1fl@15Uv{po8@KjbL>P}JA)-WlIqa5ngmL-gUUi+W`2n>#vt^yeJKe^LCrgR}TQ my;r|82RtG`= literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..964d807010efe1e34151f93820aa62f7eb32347f GIT binary patch literal 8755 zcmb_hU2GFsmcC_I*_D(NlaPN1R1y*}gk<<>wh2Ge2_zYMAkaye8KryHYg|R*;Mk$6 z3?woucEn0<)Y_(7tjNqVGoBXeHVdoaVMbzCE$uw)!>qIqWmE{Oq)14+ti)SbXsy&Q zyXV}p%T;#hLE61>ow{}Jxqs)LbG~zLzN)LMC6NB&KU!z`W=lZ zh{VLnjP+&~JTqPf^PV`nz|C-wXCgj${P6Jb)JSa-E0cNJKlUVXvLN}KCFgyT|2j7l zkofClrcSB>S})ZC4N3yg4IVO1q`-9|)j_L<(zNCqOSvz-gxrV6N;4a!dS~wyW3f{3 zIy=(@9I#F+Wm`%)=le_exNjFo8?F;~$xJJqR|Cwe@vY`nDTj^g$e{`5)=v9vhJIV# zs^5+Z`MPDxO0Cf6CW)CMq4pQJo{11+)+lmBQ6h<1`Gp6HOy-K|m0}B~Pm#5hl89nq zDY1}2>AqS7a+_$Sa(JtrG?^lrEqhWSqmQ7UtDdXgG|4)9w0lT<)64_rOQ^eV*QY)0 zc0J751JhM}P!jT*t>I&*C#FY#G95lMJUM*w7G6~3OELL! 
z#g`P*2IrHpgce>%N^*QKnoMZ&6)n6NQ6llU93NO*GKF|D8i|M1xribuK^Q=QM;!uk zjjRgYYlPt%3c(E-{(0;6d~3*P4dvOUYvY+87lgp4M{gWm*`8H1NAtoiL)euYH-!B< zyZ;rmCQ*m52@eHyYo@7bI=AMpg3fJ`m@k3jK0H?X8IvXvqM=lZY2NNPP29W@*2QPH zsp9jc!R$4hQ!%BPtmCUxN=Vw{9^nGre6{_3a3g8Y-w_Ga0x0Y33%{|lW4%JL7cOE+<&ML{&q8N#T*+@%bED_UU zkvQm%G_V|;1f>%fV(LOfi_Q&+%PiE3%Wa~RlvOd2)Wo^SC0Sfds?aEz5KTs0ZXT!o zq4qgj@?zH!F)k-|e9MGFwMq@F#|%uS5?X9Q9#xd2VzN?9)l44M9d+GoskVs8;fWub z^~iBKsl+AA95z?;QOgNWotPOlJKt1`eNCBMEFsC@AYf)-ggn>}6^F^;6jW1~jVI56 zuG25nA}}STp3Ysh%#yMzA{5~ksjEP)k#7Qo->Um}uGV+xn}+iBhm86|I)A9JrB7!! zt#)?Z`DyOmeCK|n6I8!rwP(B0bKv2Xe9wDE&wDy67Pf_Sw(A=oso9$C{c6`2yK)!u z-9twA&|`i`r@w;#+{z~1e-3`RUhDnjMA*szed#4NhYZU3T-g2(`wD#BORuNj_nH9t zFX9RK)@<_3f9lEBL%RPI{Qhv(dINpQ0T=K)S!p9>1FSF#kDX8~_`Es{#O*_`{Ru;L z8a4HY%WbwoqINlyG322-qcp)F_XwRuaU+WU(b|X}sNMmO+6e^Qk$;Qs@1xuBPQLzt zQGY<^4}g!-+3xCvprz7<;27?ONF)UybB{rXjjF43BSBFVClOQ8M@OPs?2@eXLHPhY zY8?>j6LZ^iVL)dGzGKku0k3;drCzg z(o{@S!9azQy^v8i0ihyVBVMj)HQ4)P`!PLuOy`e13vSG4)pISm$5$iJ@-~sl3?q>l zkL;BgnVm1yf}@jt^QFj&oM$EPb=K{GW_XE(*>mQm;dp#JqRq+5iG-G%rnokiN`M%X z35shbNB1@ASUmB;=yacEC^Do0sKZcXC4gNLwy^-bv_pg?9thMbd?>WR9%<5WgNv!W zsI=GZ*RQhBH_oIgxtDsC(GVhwY2L<+G`4@Mxh;cWXPGoJZ_i8eNaU{_p~M%=B!PN6 z423*GGEI{Aj^}TftDI|9HnOC-G?Vr|VBKLRL(&{>&jeL11??&1F*!+FT9i$n1aU)D zzQ-uLu)NL;D6*7_%3)ef3r&{BDYfSiHK>zTFd$XhfS9!yQp9Fc$&_kxT2hJ5nVh_^ zs4bZsL==m%vKv{^e9RUIIl{A4J;D%@B*RwV6Y`nfsCLDQd(dpGL2Id|SeW2TrnJSB zrs4*H%Z6mxssDUZxeQ^;a?f`t-awHFBN`RMbHc0o&44hGtQCU^T5z&=HP+#d#c48TlgYc-MK(5(;1+HPuBXBJ%+?~K0fy|@* z#-WomS0(^ezQ*exQO0tq9LI}B@mBTEy_VZ8IU zdvPL?+bE`p(hI|dILpV<$#Idz#rTpXIsf^1GNSF<3kY!x_*9&5QR`~UdaR5?@d>Qo z1H_J(TMD>&&xO9HLjPl-zrff3EqAN#$<_ma+hrW)e3nnPT@`d;%OVFKx4{i$;a#2F zPD!nWV3Bc`Vc49lZU`lAKHApfE`wmzF{gQ3`e~R?+Vj9$DajB!u`HfQfQ5?I_KFKB zRTIz2V%a$iEO(uzM*JNt${$kb`5 zIBnZIG*V_)8|5T0Ov0o7Igo2)%|`+aPlYXyg)MnOGz3v+kK?cVc%D6OKt>P?0x)cS zENsmSU53!5vt#({8OyU{24sY;RUruF*}h!w!;~%@$qPpe;fT&2sRrBS%w)2jo2 z5^r6nDO)8f&qRQo-ZZ=HLmV!iUOev;P(UmNEe7P|g7I0Z*Eg4HmEqOyEA6 z0ZsoWOsAm6L*UCZbBXh3&SnQPs^Q8((91*`lmPyUG3`^p|^>C(K_(c3HPtp)ny3a zWbu?3Q$>j7aNu*YA|}sU4U6mQw-+)Y9uXm0rbon`>ov1ZT~+NI(|{hz%NrcDyL-b7 z=+p=73DCG^xqAf0-T@L7$Hh5O3W>XqLhcmgN33Eg^a6Dl{Kjc`nasG!%u$!K++D$k zO|c_r2sI(v9Yb5eSr_9Z_{iovxs916!{7i8*wLO$hPti(w; zI>W<-S!V@0;pkin&a0L24#n){W|zH}qMHw)x%zt`;M6vg<_;Ryv~|I&Znd$+Xzb0( z`Nn>uu|MN41lv~rEFbL7Mji*Z>!n|zqbGYg-!T9!g3s$)3xQxE@V&x@hC+RFp}ygh zRYh5-RW!$eCUUAlD6H zCXf%w3_R*tAlK|a8p3Sh&N3^#cly@|=5m@b_@0&hg<$G*!i2ex$VfDrTKvK}#m(Ln6@Dr4t9<(t` zGFiN-H2rWmSOCvScS{*XCeAtpphQh(#AHsHOjyMUp|31PU+s=SAZj5W#03a_N(4Bs z!lPaQQVhm8ZU=z;#-`hSpZ9&XWQCw7BZk z(BTymsIj;D+?p-LOC>h?b{h(u$L)W!5K=xFcpwR z`NAI+efeKquy3iJ3-^oYWW!T6>|YrYi@`8}YD9ye8E;5jp#K&yy)NRr=B=R9HF=Z* zhA3r}tN=#kN63r;IoLf#MAToaA3T53=)S`GEd*?G+Cm8cKZYjiZ6E-EO(g*A*!1hZ zU-V@gejd6LLSQSjw%xw;`K6V}XG^!1pde6r(q9cU-;C((`yMXqfsuS*#0ZQ$4V-!$ zIF%2aHUg*V>ECGR%68=&wiylEGCoAP1KB;fQJvqL=l2@?UdVuRSp}eL(Yp@igNKda zVVys0L7alOs+Gs?D3&3~z_Z@5yMa4m9lIfun>KbIT_cby6C0*nr4~kEzpmppI(zCR zHd8O*bnNQ+H^I}JV5r)AtbafUn1a#n!(*j4_E&K)_eMU}-5=hnRf3Z9N+64>UX|Ri zR%PqX8?&;ll&%aug6dW8c2s!1*^#^SPK`eF3QImn@BuwsJ<=BSWSW);@+Eq~a$4?C(Pp%yhv<-lvN zc^QT&kia$iS0JM9{uRh(-Tiyku|+4Xh4xOJv=p}XKy5+XrjyP!AIH42Mu6OW=hk7G z$wukxVejuCjx5zWavK_%35XXRbf=f;fs6SP2{6nc{Aaa9c$R5eBQ~k8V_McoHOcng z9Wc5NzQn56<$BS__->|drPl~#a@+&BpMy*#X(S&1^({}C*2he1o@qCj_B9X7fGYt4 z4#d-8m(+_4M!W0No1?cTjHX>LY4vL-(-^}uV@$UB?&dXuxks>IeE*^MUqDKjTZgu6 qWWbSC6A<-LW8==3SoOMG-_J2aS?(^Z5Hb&&e+NTB=8c4E{Qm-(Rwe2H literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-312.pyc 
b/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dce93f93093d129a96599dee7dc4232ab7893c6b GIT binary patch literal 8103 zcmcIpU2GFsmcG}nvMVV&CL#YoQb`EG!GxyirkeoS4nIjV4WT;$^(altZCpj-;Mk$6 z3?womW`s0Bvzi9=Y7^8ui`ji32v);9HKTo)-ly%+K5TiMQLNgugtTZS-U5lVQorn; zbIVnBf?;=Ov{$O@Q}>?xQ|F#@zH_es#qX~qP`-H9FecOy^6&U!7jJ=Beal729Fd7k z<7CwOrIW5vH$~bN=O+149%vfzz~hBSfTu!kklhNIsI{jn{>q#p%6$303m(~XgCF(D z-Wz1pFAGpt$rVrs@#@U-h-S;pD0t z+-NO~>MqV_Qxm&ws`&e`K+1s|#HL3Z*seFhuB-oqT{pf~jx}$T<7U`-6I;0!R<8RK zR&IXn%C9>;POgVlw@lMu<8ze6P>`AxsuEGvNMcMe=~dG$$0kjWs^}>-5e3C^3I`M( z)as9+nj^Z?p}Epzgy@AsB#v0m!RTwQYwk3e;2ma%NW0VY5q(%#A?<2{o)Hps!H!g9 z54z2U@X^ylXO92)O!(ve;r^2+kDm+=4xc{zvAJnX(ZkWHFl_o_Sc}alrcYNAT2c*b zQ}LJ{y|W5KNYs%Wya@0xDQ~CirTkN| zC}EjCgkCH^qiGT$x+C8SkJVmUr-{uS;Yv7}V=3h$q=eIx23MiuZdggviE5iJ{-#~_ z%olfp0P21O?k4T}n#gdrP;uk455e+;Y!MU>Nj7nv7|a5j{HIE6s4w zFU>Sca#GQxL{gW=BbOCvDyhMsWI{5jG*dUo=0k-k9Vtj{ho!iZ*!cqu1}jyRkLepu zCG^;&a$HrDs>#VQO*aM9P*hK|{#7L&?jQQdtU`_}Ni{B88oc#YKBny9krSiG&26vi z#kodIK9-ObuwB?0*o^=x&%|N!xCPA=$KuKJps?%>wF|Ug^|QTemL)NEwH2E1IJ9q| zx=z-7L})d<-OE+Y#+Ke()j>lzm~Yu(a9ftQwcY*v&gTp7=eB_wH!pW|Ep_bAcI8SMIi#6JN4-#2N>nT8Vj7^L z%pMT7azU&G2W5#BVG!(tXw6VvC(GUz!@GkC_1#?6enZ&5EVdh5`>Q)g=g4+m14QZ0 zWs;&J*+QqXN}svc;y5v%pvQ!B6ytGD6NFaiO;>1|mK zbS`!sH3CNs;po%A=2tgoySddN7&1o`GD>AK>QdY?Rk(>V4o!9`o(Xi&8_ESvcHiJ^ zH!>;48ppqJS2NH%Mu7`;a6^$@GHE7%NB;rPqFf*et*8SVitPRy*DvWczT_KGNj8); zpQdT|BhC&EDM|BSj{FeQGqnrYgD#Vo_30_a^vD3-QRM+uF$@g)O`odBsi+cW-K^2% zSfEgO9x;HqTon~hZGg(GM3j$>rIIPlwt-MGYYM3 zrY2#C5|Uvn@(6lNcT~S>1s)g)Sc_INjj=quCz;ZxQo4pJ0}dF9Wvt!{N%aau5;Gld zP{5u7Q!fnG{u-+5UL)9b}rQA z>h{c5KK1#(JbClvFNS96a&`UOmfKtMzWVuUdq~#T4!*gjGT*u5{-^goUHD{SIP=*H z*7E*=-+6!Q{f_>t@O$AUw#<6av-_8Xj;zp;8Cd9leD;rk41u~?KeO4P2Nk9e4ksd$ zN;qs*hB1;##gX=f!slJQ>mVy$vq^2=HmR)q)097Ijh`faVj}j}sKaQ8<`q;*KRTlodmHq3W;r zIsa|{V(Wf@Rtr@NsOh*L80vieFX06o8=WYX4YXmg-HhgiV2#>u7?l2x+^!;bF|eKB zta(f8G_0L=J#rW26y%0xq!S6SCCNHAX)>kh(s@NHnsLue+i7NQZ(vao6d_*nBU(BR zz|y^H6e(CY{%EnpSX=>ALoLHt3sz?4t)ajiJFa2O>&B=b0mCpn8Xn1Yvf?4W>Lsxy zE4D0(lEED_+K=U+7Nxuh?X6j{bx~|HxB;VMAP2SBwk!sg#I~&1mf5k;`8c&G9yYkc zuOg#!7l4cXImieZr2siTSAvce(>!=L|Is9*KT-Wm@**Ik>ArOFvgP6`5~<1XI7Hpr zI$*>dG3#juy?HY*mLw&P&A{kFDJB^JcJ4Gc13WTiDT3=qm;hlL*doBEqdmIljHUom zzyTUr#tH#m3gGu=fvA(X0H_ty83ZubBi`0qN=kOc%O2VN$ZPijunCb_5Y(`O>4O?O zpVAeH2~~dYw0Hfph7VE~ON8IwW0Grlq7RU5><4 zN{=*xNY*QTb|fU73Q7B8Vz%K> zxv8!|RHaC#q?jf_fQ1X6S5zr^!5Ucju8{K~L(*Xh;#Zb^b(b4vB{Z*$E7=CjQ2e|- zdz|TBvjevE9%lnA$e3wA0&DLCiAsagxFiRq-F-lv0{V#6EC-*Xj)QB!h}5KmCLL#9 zWu|=t9}dNdU?B7a*>seRf;19?7WFh%#gL?M$H4Zb;vy;r%N7(=ve+aXk6l!lhY5Pr z_plQ;Q8F3k*rwl#I>OQM6r_S1gAazPGj;21+YU88gW=j8sK9S+C3Vd#$Y^YXAOCVq z{ZdV5wx%d0!y!dn>=G zI$u?nud4nsc{8~tdIFXI=_6v@&F)`>R{W&Gx3-BiwJbI6%r@>^*qm$JJ?mc(?verUOfecnq!REsgU3Z1% z_{49aXI(_m2mT2Uc|)-=@*MQxRT_A}@|S6@G~_^O@O8F7y9WMn!bXLXS!K7($-GQu zkIe~JcaCSAaQtR*2t+c@3tZ6@b_pKKJEVRBJEkh)GYGG_0XhMrT%tZ<P(%HSUkvzn{12+e5?Y$M5ubw}VQFFfTSqj&5UryhghL-;M zOPRLJCpo6O&%{Oxq316xwjX$MASWI-xZ|%vZUH#R|4%aNE12Ym2&|Tf&`YelD|AA% z?s5=I%PG{lY|{x~A-F4}e4tD)FP8s9X6lT&=GOhzXc$icf=hC_iGbZx{IyjbGFIoN zv0Tnsyd}KG>F}3`RWOfT54Jy{CF^sb?uD;K#hv;6I-IpsX&ulO;K0FyS8?`guT%&y z0mULZ1fqCPunuMa7r0m!V%Gz85yeVDnpp=L2s$ighck^uY&m8j)i_|)5$PCY3>X$# zz|@FP#+s3b0UWj*3kF3CI?!CzGfT) zS9dF7H0^slv*bwUi||7`=*8dUk^MOc)a)Fi6``Tp5J<&T>8g~Y^XmMIARD# zo{0@h;@l1Az3xMHR>rw@3JYDyjV%4M~NkpQ<`N& z3>%AAr1<}r6|Q`WR25cuAFsA!)rl2K*AyZtJvpXEWK(c1Cxh6Fi9cNZCnu+3aYZ$~ zuzKV?P-vYN-Xd070!8_3Q0*UtL)JA-pPGbT^!!>Y)GKaE>Aw-*_eA=hY<=3?VvvS> 
z(>8)p9Sw-2olATv?+^SJYy&~HMk*z$5y4IQE{yWOtt$lSCmq)Bx1IkS`X|^Vwyo7~ribA>`fm3vZSKx)?p`C%uoB{FZ-&3`y9f74 SKt8VfX6tXZt|7^E;lBY~-gb2W literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/layers.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6c2ed5725eee08636394485ebe4ad78a375b3f6 GIT binary patch literal 11080 zcmcIqZEO=sn(p>z`z!v8NkRY<2*eyD5awfdm2_G340yCH4 zR@?3F*vZXFb7lAI>Z-S@x~txL-lwYkcOH+Ef)KZ~TzY>UMg0M9s=-=`ynL6Ys9O|E zv2=kakZ?=eMpX%~Z=$lu=Xg zimqB-jxvtvSxTQPdi0t%X=6<*OWvgiqjtra6p|T1jAs&@;!0#wDK3E>Cn)Cl#e~8P zh#XHk)GUa+VihOGQyh;k1-(sADs)D`cZyv9)Bc`xM(oM-^bho;vx(7UY@FvlPI6aQ z+_6cqXDpk{h_Q5*&tmL!kE@dnN&w7OKPZ@-s+6i{l6k44WJAGgT4dh?rCL*T`Ww+VRxRI zq=!9V;XcJTluUDi$c-OKU*uSpxB&xFke{r7U@5>OhxN<%jAH1NO zX;7jPj5;(WR&f^cYxB;KYN?jZnR0Z_oU`7i??Bo;E#084)n{zUpre(9`Wn_~MAsY@ zv+16JirKY03*>3T%2H>kyQaarCbb8#^P&#LDol)Xyu!p{$xKp=#S~K}qd3o!z8916 zlwv!X%mCIYmb1x|=M-}~HO_B>qGSYMu3H2yHNtPi6q0Nl9mNrgB~o!gh{XhatOyC$ z$yDICLc(1(QY{T;f@V#tm(5vNs#aneEPlpn=oRcc5)7%DkIq?*&uh@q{MJgXOD5KG z-)i(uGd40V&_DbehFY{I$k`0(?&_1}Y}L1CZND_>d%*Z?tF&qZAhZwiK-wrC5j$#u zc_zkrQ89BFRG)&}+D~R##WEUCjX-GtTS2kJ`AdT0(C}6_Wh>q*ph35$G`*#62G=X?gm#7h0&$&s z<_*lnZ%*bd%f9A$cVW-c)^E9Q(vSOp|8CKDPV$|TedqGl665*8`MLA!!}ISh3`?Od zIn-4Qb}hXo1$!iBx6JH*%Dn!Bd0k@Ol$ke+%$p^zZ_ZCXhv(MKPCxZ-dg9$Ac{j`6 z&4nXNTmSZio<;vTmRAfjkmb*z@v-B(0~*gkV6A` zN6FPR=fC+TcwO3DtnJn&3wsB7h&%ST%?Gxf4)dcTZ*zq%`(e zL|3bMh5ehL6t+?qawel4&zN)OvlNf>B5F~c>oJI{kRDZ2cVI$42*oDG&|tt&(KieD zE{N+Df-{okuf}pG>r~0*m0cT)t{o+xf9|jx+AjIF=le_k=BNIy zC;qM_|I%T3=YZrtF8hz?Z6#mm7V{Oe5SDxq*%tvb@rLql(z#KZM}Gl7Xll_WbZD7) zngfUCj3F4{ZYf*w#zt+~R@L~i|&5Pc$2k)9Ff9pc?V!Q0ywxW_+ zb!u!YWeb!|PHStKs)`q8NV{=CO>MDSo6FQ%f^;`|u>`!67OL?5atM<2Jd~x8g)Ra- z1N&52s0COkFwim05n&gpVT2)!ley6Z3shJz4m$|5bON9{mFCkhR2+LKPS#;ZD)e$_ zkne!PN-&|O$3|e`Ho~+&Z)}7OPdD#yV*0dRTU^AyhQNr%S-41$QO(@@Ejy zqYD95o!RmV4|xRRMF2nl4m zC@x@aViNW_V%ZTTFfx)saEetDlN_%uWi zmY)ao+Bu7=Z8T-jGb5@@!-KNsX+Tqm1NsrVf_U@4q92m;%5zv?*=#i~exSV89|lxI z>m>o&pxMSHHFQKAi`rBw{v7+0e+|S@;xVF(4%%>bD9*|b1bU86R7P-9Ojq~7)f=^| zv-Km)ARu@FV}F5y@CCvEuw{a827n9c>_daK4?$J917ZzW`##*duzBI~uX~pIrM6vC zxLXc)=UvZS8;Y*za@&UD#@!G0NgMm5w*7M3{``qjcyrxW{f{x5;lcR3K6k1m`5UiH!4O%Ep?pDzZ6q`;6I7=m1THFy2P*B^(Afzwjpv>Z5% zIeGZhgCLFsvC(fJ^20%}F_A@{01IHSpxTz^1-P4`!P2M##Oi~njWc*vq@g6N1l@+# ztM12IDvPu|P>}Bh5w-Gh_))Bh(TU7u{vyWdqp-`7Oie`1HMmwRm*OHP;M>D!tBjAJ zc2Gak6|7EymVmi{)q(J>?O$zQ*!<;=`5jLK9Zv!sumtKT20Em`E;+F4rae#RUn{wN zUmW}V*w_8@M;G>dIWRv^ut=e;lDkuOcjjqW>{MKmO4u7CJq}JIM!N>m4`UG&AHXHt z2BAm#CTd4_*%Ppyq01CX!cWPq3i2e%Al^Gj)JbjVESy_>uk0eX2Gl1hq!zWV2C~a+ z_2YKv>O#eSeBK0HT0p0W;UyxI5&IpmZ&UlQe&C^I*bOzq4yYN9CXAlI#o6E>!p3^w zYJ$ARJET6sSH0jB9IOw#0(c4733UPL1Qh4#xR@9vJoP#Hcs>T=IZm0up`(g(9COv( zP6j+iVl)n0}KO8KwK2L4B5L=yzCS#VUmehGJPpF9%ohep`(Ws z5;-z8BmWQXqi%xfpgv~>;I2(YJKSl2tv^esk$*x>ai`OkV}SAk=-<O zIQ1Ewv+y1h1y0Ef`X(b_WCDgZv>ri}y$6>Xpw{Yg1@^dLsLX?~L*u&4@fj|q4u}no zCX%pst+7{F;UUOJT{WMfIK*tMS3SPKDbA0hxPk(mhhjUF&3xPodo1I8b_|X^U~gMY zic^XO$2ICy!Gsh{W0?$!EDBN&X_-`3b1wMQs&fIUC(yx~rD^U(0A=AfAke*p8?)x+ zwk?I@Qd^JQ2HPXE-qN~_@Z&CR>P8OT{(UGsKTvXqO77-TOGhcV0T$+^U~Ac9Z+5<* zK)gZ$B^+fR=-guk``f&L7VopPO0=7u~xh_iov}8zR3u_}z(rKK0$HvYB#t zW*vF>)Vmb~bz-}}0no?2)iFp=!v(OV8Pkk;#xi4t-94(d!z#ql@@o64EL|gY(_l|5 zeVMorq5U|}o}BfsfwZiq4|$f`R-?^;-C=v08)9{&v95kw&Q@2yPQP*enJb3+bhi?Ki5cSN8^_US>7-3$%`b%RXW;&dwQHzi$oX%?@+la-T60 zKR7C-3h>FQ@I+jg$nwuTzjy70F_+jyYs8YCOM(@g`J z@d+_9$|Wz2ijl7A&S*p!%}%7)$VD!~CO?J~Q8+@5T%3v!)Q{j97pR${-Gq=%d(`q* zldL!j=`9B%hay9fcOxG}cJ)QN4n~HekN~%MM0!x6ho(cPCIm53@d}abgb2-dPq*n* 
zvE;!>7pVnG!~Tg?7>&ceF98*TI~=^2Vnzpo{mtVNn}ClJ9L$OXX;`nM0kb&v2+wDE z9{nk72kbyC;?8mqQ%tc#{0v6ucNO{rJaT1Yqw3F90lQKIEFDV<=Y`TUX657u1v`RA>3 zabW!C3jKwNV(Tu+*Dd?H^Znnu;e5K-x)-*d`-|>A$=xTr`|>obuRZ=R20kBvE%lSL zCrN@YG@HY->b~VbXK~xxQs5mq@J`VU$7H06*B|;H#*6L)lKX({K9HxM0oT-3P$QSH zJVM|lq=ml>vB9CCqk~a*!+izx@jSYHb%)^xZ!RRU2oVqweg{HNIc?PX4W*XWQfoVG zF_c=`p(jg`PTXz?uE*^L5YIyE%2vFwQNeK8jxmOUGzZ3%QEeN`evAcl+Y(y!q>Yv0t!8AUYYCN_;(HmWDS9;wFSmivQ&ZMTXr%du zV;Ht2fJRyfjkMz8Rk6Q+>g*s;hJSxC`iK0tfBd(;!GuPzpdOsM#9jUeM1GWgOkBHb zj!x7;aczxNzyYrd28W=P18QdmVh1(WR(3{khW50T9Du(@`ScCIF7k@MvEZMS2zlct zK#@>2{ca_kwT%nvc77-1D~{7}pa%Q+aN2F<(&IqHY*z_fD*>Lnl!SjQ@W7=jDzFYD zI7im`pJVkk7KX2mpr$8azRDPB?CUNl2^T?tJufq!24~l$U>K+qvm9POKl(J>^(5T& z56$=2Kllii-qF)i^t22zd`1qR0WDXFaX)1?Jz+K#Y)cl2*(Ec(iq-EL3<^yG8B_s& z>%$#LsyL2s6t8gBpJU%>>xZ)*mGZpCO0WD5yS zq(DnEB1~(1a=p?sgGpF*%laYsx0d}f$Bo013Lx!mXpI8;c8Y7)Hzh-JN2ajE#)%?Nyk`7&3T{HJ*H_ySCn{AynMjZ=ocy18{5p zz^!#(t^2Zdz7@{0f~~iYES^}h{nq*S&Ifz$dhUCkcD?na>#c{qkHlivTT<6ax$ER( z@%NKb*9TJjhjROeQt-n(Q}VSeytcSc@@;$S?Rnzuc|c3vUfJ7QWO|7mz#otD(Bxp$ zwi2heLh>4m|6gG?iY2(Je~Y?Ksh_gRV(l(ds|mN+x);~#s|ytnvHEC>;??~ptQ-Ck z#=EfXpMp>vNtlKr9P{FbQBVx#Xf%6*-DN>p>v_$^VJimQC< z>n}4l8a$LH0+c6|KQ`BY^Z4xX>yEOmlZHdSdNIe(PrbnOSGCO3bdZKU_&0BheDnpTzp7=vu@`mg#av*%^#!KCs%55VCk?h$FY?^%=xp`{=Dcc< IO-Aj108}6M(EtDd literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/layers.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9672afbabd6e9178fbac7039a07998ffac5dd90a GIT binary patch literal 9941 zcmb7KYit`=cAg<;$l*(TiWVh5DETECw&F*!Y3;&Py@n`mZ%?KUwK0?Lr2(4H|VNi(45tG6l ziltJqizz|jP7lh-Y$le1e48l5L@}1>7vhy+B2Ab;;$S@Xf%pkbSApuIdHOmxO6Tcw zRK(Q1O5FllepJ5;;&13lO4cjOQ!-XY!T_Vis?Px|TNCb7N{B71>ftmsEnjlT4lOND zY4d!Js;0`WYW^6lm8)vrrn@6n#h#Rs8A*<1;)3FcXHzL5jsqkq=Geu!!ga}lNCwU< z$)du_LxU+n#L= zn$Ge<>Ov|uB8XDQ;E2Kw#zaZT1)~|#K*uYoY&@2d_IHr<253Rr3*y^r)RLodJa)6? 
z)`sZ~i=kbm(5?l??tj^PKlUH~g?*=%{7qBdPyeEDddVI5?9Hjpsmq^Un0L1oxwgkp znz{|07QGM=alS~$-wlg6%ZjzqvKWLn95=}g0@_ze!OXMpvj_}t6I+`0i=y(TJe@b^*}L?e ziZL1V3HVlZF{ZU@(5hD}SoP{!H7LweQQbhHZ7OQj>d!Oqrq?w>-8OaKHmPHT!xOP7 ztTZ$zhzb{tCNoJn8dXf0jAB1WMpsV8Qi|ndG6V3VFz1pd&nxD1YEWDcMadY0Su&E4 z>J!&t3Q1l){)#Ocji+Lg6pc#QsW1uG$XXURLP9Q7yA(RKo?uN<3J;(h`58hjFzKKb%kG=kJS0!)!(#*blJ@>aPIJ+KO znSi~({FDP5iFDEhW`WyWUcbG-{3q8!>_uNH&?04nA6Y7Pu%byS53rhp*r>URbpvQ$ zHGwzf&9a_W4@vWNm(yLdwopcyJTri;*Murs4mfMF;5Bd3(9%KW%{jnZ?R@~@p(bd{ zji&Ori&G!MzTRg-Gaf+DsR^2Pp5ZCf$$?>vvW9U$T&f9%+&VjBp4ED1&_g_{301v* z6VKdbwVh$qLgoXOy!eu#hC2PoTMYS4HPypw-lBa!Z~4%qj)B&J^Eh=w`s-9C)1B`la07wN6L!vgkp)2eO$5jrLr-3 z&u*fxFzKOG6l6?JDCShQUtxsd!KlI#CB+<%K^yTbY;~{&D)yWpW~FEupi|!32_l9|Q>&%haP5*GB zMz(|+J0`E+w@QF;`#&kFD?YS3bwMNVch%C%S%n` zZrP@7WpDG%Ba3Udm)35d{r*2*o@<-iHP?1;sI=qN`L+9tp4aZbv0^rb8VX&@7RnPU za!o2880{NjoZ!LVLIm)D8h~eQ_@NsBx^51!8O*~vY62-!hjZzx9RPhm2lQ&;s?Cp@ z^X7Auh^P@^)V=Qr#8o_QS5qUH@EQokB1dsyz=FiBQ^J-Jlq>Y2FDqV&iF~f*_4R8wYV3>9u#PWWPeedV4`v|o)+XpmKVvmVDA-IWn+p9sZ=z2 zn-;f0s)WXRK;4}X>qHdzlOzz|UZ<8hDzbgq>S1@lHWF+uZ|f*Wc6{xI_lJ#5%Pgi_ zsEu1R4Yz;UfoV=E)H>5R?Ot}_TLa|{O*P(HJH58#+42n*Ufx7md|#XFZ0NBWV#^qN z(!{cjWNb(;;?My-XdL44+mK{ra)ia1ZvtG_a<~B`r!@h;tdAqijl13et{XU0y*9$r z>;QIlrSS}mt-v3O5ntqpPS(U$D5?14YI?K}oR=X{hyvWHOn{I8oCAGRMPtHV6CLd(kBcUb0(~w(ltE8bWbDlA9vubU1#z@=#o>Qx-wpIf- zGw?4l3>}Qv#5Q=4^(mgm7@<#_Ft!=R78GCxlt`@CB1SM?)zpVqU9`yk(1gQAUbVjN z)am5vAe51Q00POC$3U*i8^4F#C|qvYg3qQ}vK5)#H`nt37}eS5ldVsD;>dpIia1b65ro0&uzgH-^WCr|6I0 zFLKQ){y4-FkSCr*aSBB@iqj|vwycH~(x7-2bKd}A#El+IMA3;Ehe052I1w!XDqL>X z{%~D*nT01nMKH8%#h9H!sKA(W6*!>!lQx=FDJrQp4mIjQTi}mzDvdw@E*;0y00JiX znc>F>jKHG7gKNhKmhXfE)seGYcL5y0DU1OQY>MNkAn-@k9YS#=Gn(N`WC1GwM z{O)MBPx1BjWe^aemBgeVsveo{sX$*aP8Nd`7QvtIk?~-|xTW$YBh?)><(VHgRDCG$ zeTbH}lpx=XR+_^>Fy{a{APxvaHfOY7DvEs}!C>um5C-XIr3Vand8|g>p$7fyo>K#O zhb~?(19ydu~*z;#9o3RTQJphWf%N*N35#7^kOXp zWZuEp1r%h9+y*fNAv6L=z%bc5Y~X$)je)3zjW2_tsn)4W(=W_+mYR3W2RjOmhmLhc zM`WpaU2*-+d;3c3_s=)KRya`(w$5C-ccQfAwfn85Er%XBN}JD>f<2(#ycygx_mVF# z8JUPoodDwS&g~6zLl53u@by5BwVKm1_tJylg6~X$`NaA(!l1d5MfLX(c|O9_<7Bvv zf=^c2elEasl!oO)4Pf0hLDr{jKnMD&N)^d!7NdGZ$5i}@7odc=6GVg+--U<5#uG!C zOJWS;^ilA#lc}MIxr%-Y(;t%s3BMD7Uu7%=1pUTot3Y20w4_5I0Q!8v$*mJxr&>SV zw&-gs`PzV~w-tSD^S&Kp)&gC4vF!Aq3;kK=DkwvMk{KK3k4xD`e!h)JRxf~So{?bN!> zGv}vIFFVL<4TNIalXf%fRml{#6NjH3G{5ojFeI(UP#YWsRR8gg??hA!9F|$&u*?F- zWXv;;$vl-~GFPt+j>xE9KwCHza{$?LD)uumIi4USOR*1PTs`aJ0BGZh7@TyfP8_1D zV$Wp7bd(%SqAwYfWg$b(brd&00`xc;k0#Uo(ZLw6qOMKynm|K+5YR_BYO*fM2!EU|nhA<{P)3W3_U}F{Jgz~x>0iK(VQ$M2fj0o2R zfahk=G-+lDnmRU*P=>a&3Dr6UPmRT`E)y|#gA?z#B#0Ryr4F_QPLGptAgNO<@C*t{hqO#8!**4Mhu;Zit_CAW`>Sh4$7pm_qz^!R3)C)Xpms~m8)uHsH-i)X zQ+IjIdU!a?8#<8Xt^G0(M6bH+43wRX<)*f>e;s&mWq)YJWo^V`GuyI<@-!FStwpXC z zz>33`^;KQ%BTp!lq5f&8-;%dHqkdB2sru{JPI~|RZzJ6088u-4M!^!l`wO1`ea~n^ zAA?cud7B^bB%q;rYagRo?I>M6$^;H%1r%WJavFnzn=5KSyQioQJFFF*#i2%nA*&LF zYbO2hlDBEvKU&?Iy-sWXr|VkP8s4Niy2*)OuAwV4I3$M?Lb5+0hud?TBVj3#9ZK=x zi$a)Bz6V!SaCaBJI1(oK9mcx~P%}k32#3qL)bdx7yqtjari0-_;hylZ@LS;>`@`)A z!#xp5fLFY;I;hY+xxlF*Ne)-=J)9kqq4|znvq2S09t^jWTA(yI?W~lD!9gH_5P~v1 ze4%2-tpG7r972JRrTDwEa#tD%dm0c*;E#x6Rur#54tS^$)IxVj96?o6^pKdt2z^YU z-@YwT3tn)!X)O#`IlLL- zU-?_WBMz)tX5k5H{@}6|V;lu(HjLS+#&uI~mjYXs9r)^`e8FWG#v0IRc4N##HLqXx zV$8S7ea5<<>}}MjANDE^KRw`XgyHcP5Ju|9n=5zewq9_dfv7QrsIh*JrvY}g6M$+afMwncgsM`% z28BR*MHTzCvhD!tzX`HJlf+-krHE8HWp zLV6IWj^zr0NW~`#{Ykj!6@iIXRKUX_*h1XVpJ4S>o`Y)+fy790JeASb`OR%ml8%A^ zezU~6v|GWlKL|X9TMD*JCMFV#!S+(H{chvuEw@|ly#wB2U$A9)>#mkTkj&Uy3fJ7W(}@r5B$-I^^zbAf~x^X4tt>`GTjO*N|*7% zYcP8SuB(JtMkF)%5hy5Dy#I>EB)EgnHxI=!9PJz2vs>YYqsV>*e)WZedT{5~rns_6 
zd=Aw~cv1nIgx5E@W@62!A-Fp5hi)FZbz=I&tmSk2ZTr1lcU+6@drR$m=XT$hi|u>o z+fP1_e>?n};Zpls^K0KO`QI*ZWlz)8i?{Yo?}L-}MfVGEgLIFcckeE8yUDaiO#NRa zz9{_jw_on>j#z48a3g+p7mB(EeIH^DidpzcNLm2szA-Uu#}lWS-J^P&m>+3U+=f3` z>43ji5z(l{mq94DB$$T`=j;;sa}~Is85xAXW)KPV!o*cZo+I~%9hr>kWB9e#j=rG` z`9m1Es<&6;B9C=jgk&ZQriuVV{$fN#=RmQXBKIleUV)q`66ae5ZSqx(sE;?dP^30M%GaEQ_I~A~ zT!Hb?p^4B61<{ohZKwAXgd2&m#EtA&_8SUbNk9J&oh9Ld literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5812523d84ffed781107add83920296e207b965b GIT binary patch literal 1542 zcma)6&u<$=6rS~Z*N&ZF;HA-X5c-)91|s#1^P6C_5Jp(2q7n+BTL_1C5f%xm2m|K|tnSHy zh|nRtJt>J;#?r1(3S;38%0|xkbW!W0{%7EOo`GQt5mXW|DusxM!w^R}fZ4;qA9zII zk=U!ZOPE5TO5Whuh9qENX>=WD@>h+yW1c+X+8?+VGJvAX#WTLmvp+%ZW(5{`K!}CQ z`Wa6%8)4E~)l?lOTIL1|>qIq}XE%9{1w4Hom~dUl*MDnOu8SYI|CDO9&vw5}L%_yn_fUC{?EVe_tj zGmw`b<#fZz8MzNW&grJMUREqh9+k<)g%K`0xr$jf9L_|lfMOnm2UxT=S!B(m8!E+H zuU+iD#XiDoZ+`=72X$qXoNA@_(|cE%R~qu^<>}Vy{_5Vn=DkJ?u=LedWxwK0?&(ea zB(-pyT6mItcK0Z?&`uRPsX`J@-o6}d^uy26c z37%7FvKyL^GF>!KFGM6KcR#x@+oSWD_{Oxs-vI34s$r(d>MV; zk4#qRI{@%oWv_zjih>}VqVXO7>Bw>V&2JM&6rb5IZ@h_29I$C9Q4_$*3MnLTiy&%+RMZ1L0KJ%?WwkR&ym-CK?3y^y z1|bn4UnMv6EUW4}ym|9Je(&v@ z_x87mi6mh3_^<5xD*=E%`6VW~FLdrPp#cy;umlP~0CxZqb^#JGuvyF~+0v*6fCI6o zD~L!&(k?7WTToYyxiW7JyZ8M-XNGAT0D}TVpderoi3sk(gT60|7j%i9x$B}z$qX>M zclUnYC5|yIdF-!L1p4UB>^3) zl>Mg_jJGjO1X~_9OkYn@(F&>*R;lb;W(iXfdp=dF#+rBII!$(UJz3+31_Lk&i!YX+ z8N~6aS~be-6)~b()o|js_4$Z?HK31Q=a*N%`EJ#5t&O7L5&W!(H%D%$7_8cEu^ey| zTw+$l6L9baZ;K|@T(W5rw0(YbfEM}|{+9kr#$JH;3P?{ivi0oV7d7SR+;rn%{o&rD zT9WCr7aMlnZcgqw?aX|Xncq*peQ=nWZ)KKi+PieNakYN+^^^VE(ag10da;&Z5of1g z--)I!)zqWROmntrw{y!;Zuys`R_<0bcPq-Q)U-cR)9ut;l$vWUwNh6O)hiuV2FSmG z>p%ScbMwQmZ&!4gs)kWE9c&mhY4ke9<#EGc9WC`O<9=WUMa#j#hKtCT4CD#I6LPL20ww(qmwEx*#TuUe!EP-wO9KRq_vpWNBxJi63h_UH z%^u9_LN_&)P0uTB=^{z9a5BXiuf*rM*o${7rIUkEh#!&Um_5TCvy8m}@3aeTZ8p+o zSx>*;zunRn52eLk>HJ5c%Y?53Rn1D^t`iec_1j9>3W{#I*9HzpE^#*K&qwuUH!WTv rIy3OC@VUo(5kmMo7(Y=`F!y7n@x1=Lon45s3nzdnPA)^Z)CKtr{Kz(^ literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13cd0961d5cbcc852ca0681c3b92412a5de06ee5 GIT binary patch literal 1610 zcmaJ=&1)M+6ra_8=wt2Hj-_=;LSo}M@q#Mh6nY9Og_6*g8tS-E1&6; z(}KS%c@ZmE-V^c?7VaWFdcv0%3O#E$^nFjS!?cbF$_p6fBSgd!oFg2Vhp^rn*cyZMhrA+K{82gtTWesiDqi%=XPw6tim&+64rs%?)##A- zAR}KK_G>6TnxO&9!(P>xFZMIos}^%aF%bP5248d}ZN;r*rZT z$dWFw=of@IxL7{n*P+Llyi&4FkBN@G$)qwdEhc+5Evzw_x*ODEk_S)5;uXu?s1UM4 zSk&8gh+(?A%EW@V#e^*;ZyA)hOj?DPWg>B0mIzmj73Zy+Oy#TwY&R)kvRg1qgvB|l z5&Flr3=1lOV5Zm=E>3Xx5n$@8Ub+48bh_TBfhWxKFeG#pC4ERxN0GhFnt>vqxd zjIxbMDGMQmHD?UY!Uiq7{IHC(&7vzd-83oQnR-=mW_lj{A>Ho)c2P$`+NJvP{pD}( z?cb{lh=-unWJQ;xe+Ai+sXM_9Eg)w8gei(+l&Vj^R*Z(T$uRg z)8o|aQEIkf2C2DrYAzU`t0}*{o~n(WLUd^1fk1hz&~P!su*@12W@CC4HaUf*d2UN1dZ&5Uq_h`;HF=OgUa_`wXeSYjGO*Ev!?~&oep{hgopnN3Yc_F literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39f734c62b34b22c8156d680465222525ff1e51c GIT binary patch literal 1449 zcmaJ=&2Jk;6rb7MS+Bi`+SK+IQfN|zh^z>fTBIHjGD1RB6k$PHB;*DyE6*gcv3EDK 
zWAdSO)k6*z5)u&yhZ8xv3Wr{h`WJBFQp*RUp$8MXRym|9` zzj^cKy&oqhGeF3vznx!2DgeLAO(JMy43~t_0T^Id2W6muYXB=d0ILMllmt9pYueaO ztOLO`~Obs{4i7+w^f4j`ZmF(@lU#VUq7@L+@!@rteS=axuV z#Z9+<)AiP_xs2eY_ivPZ+NdPCOT5fS@(!n=Sf7=IxrHKBOR_ix!th`eRXfc+ zLSO}}-)RSmA zryDNC+q3`GYlRU)wng}bP#Oyj^eq!@l+Mz*_8Ci`}LB zxd+x_-zv7vhiCI0bJyH^uRGU!;nn`xLK_|FAV1x?zI%O-^>Z%`a)n;5&|UB6F1OQy zEq8Gr^|JHbbT50kZ3xG?>{lNT^7Fm?{Jz`IFZ8As+WIeN^6kl?$Oe?|!qqQ-|9=0E zAMO=x#0|&s+y-$Ro^fOenssSUIF6`#eRMJrGM7~=4Z_xZOywxv1$0I%^hK%UK*dTL zVNR$W@DrMPqD&!cr~>`mV`&`+NJvGhM08=q(|AP~9l#Pl__wKxfaqXF^#6!KbSW&^ zN_;s-ZV=SBY>j3G2bF2iyi}G@p$w3X5ogRV{Le4b1p%FwG0zKi2RtoKi4;l547U(b7lp^MpTX=H{2%kSE-A+@lMmLu&VDJmqrrOd8tln_SAUexYkAb tia3xH70wCuNQDso3?`2BEVTBT-`@V__F(2}Z|3R|5RRio2;YoJ{sM4DPiz1H literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3435b9e5cdfc3a955a1a1606596793564725cfdc GIT binary patch literal 3917 zcma)9Z)hCH6`%ck|5hi-mh9Rj*7{$4m&#IPV-w^Muw^HXE3S<#YNDpSu5aei4o7W?9z-|yf)h=qWm?FWB5*D*Eu(!QC~u5@QR zbXGg}-kbN{yf-uNx1&Gx^u!UgU;T4`;XgqAgE2u%ZWGTQ0r3D~goP5CZvP6?0_VkP z5%0kwk@6yz?kdx=hzOcd0vA}`P*w$m*5Pe8tm4S$0*(?D$KZ{_n;_y$Bpla(i-uee zc-C<8uDS}M^+1fKW4IUec5hkoW2Pr(wFexuUA z1lhofIV1YDqER8^dJ2*mPPfB;!nW_B=qo&2U&VaZ%#;5@dm_SE8pe1+el*yJv`X1w z@MXufE4~EYC_5mO9Y$XvxIi371)8sk{@uiA*>Xp%(eodUmhD{8%vLB<0--Y`UW} zY(cYS8$B6N(^6^^2N$!c(3zB+bIu#VOS|_DZ@&60o>~}kVE?c3)R_L;gbsz;h6TEq<7xSW6+K^Wv{`D}2Phi&h_#z&+ z_2OQ_(+Uz+Hq@?F>G672LNE;GVZT_xkyq%(cIyPRCwtOHv|Ag;*wZN$kd-)!#*ib= zinpbY(X1e#k6|ALm*e$-^;MLGkv7K@K=TrJydgU_ft`F;{e*rOq z+;kx_y==JDyv>c%Li^OSuZb_3Gf*;7)57GoX%#42Q4mOw(-U@jDBJE#I!@Wq^COm> zEl>mds=*Q1m%(olAmkQxc}@G#GI5FpYGVEY%H8y31`@vL+UX=6WDfVSZa?eTPVW?; zFSuZnGfKo4W__XPtISQ--Tn>uSzm&r9QLtpBU8}RJ`nl)a!}060HH@!q93KOI^cYN5T#81Me2^)F4<`d5zEwNyh(HMP{D@H~n72k*yfKU&+f zcH`mx#=)^h->GKbsYT@*ZSV5c<%!zy&wjje__Lp^TwOc<#q{RkFMhdswV}P!)ZSSX z{tuPsDmrrFkB2se&9P16%VYJ!A2bes(1bpEXi;5OYbq%84lLcQon1M*cD8=-)FZK; zJl9B`YbMVvN?XbP-xu!{muYS6Q+vs73-b5!wKHq`7W0kd+s)+Li&E=EkRea=DNFd< zFzcQCyQmj0UUW&?>|03dUAW@xfz`Y}(K|6JeLX5qoQ`~bR)F4D@|4_i=rD}tLbwAa zgw!M71F3M05I3aGB4xnpp_$MZmTz*}5wD=%3LOv=!W~vTVI>44m_2M>7pMXcR=a>w z?C=;~dpxlVjHM2}9Z>Anli>ykPiIcw13Yl$%8LL#5#tZAP`qG&O6?hvb9G|miu&o3 zP?|e+4Nzq36+%xIiGhbhiYtv>)Y)cKiIp?StM6ERvj90x59@FN$TBBRo8%)5LQ2G- z7Qs5TU4uhO;PVdLxNNGs!!R^Etmo`fT4w?tCi4_Vxr1^XrZU~&1h*q@{i57J(Q6vMsn&Ibld`2QUh&0D-;E!K=)DIakHT`|Qd= zEaOIrVEP6ew7EnE594C~ z9BGA)fPj~eAKQ+tS`J927D&OIolXva0ro+9RRBMJ-^vCXM#n*I0$%6u&;j=84Eqk% z)HS6Z8)?Kwnz4}?wg7GBY9HPE%_4uk88~o%a^=lU>2K;^)XnOjd;Z$Eo96`mnQ>eup5?4&UgoAD;Km9^bEL4v{ww#a(ZCaIxz6Qk+6ZdFDzk|;}} zGNkGs&`40^V6`Z09~RM}T{x$8U>{uc&^<2NV=p#rv<5dQ7T84r-_$yQfn3@*q$n$j zx7#5&eDBSBZ{C}k_i_H((h@?@e(~k0^dV4RaK>xV8gcLfh^L4kf^%rn{o+aN@WP}( zI*35Ult9EM(xgv76lEQso*VZ|d>hg_uF*2_{|=J?l}QlZ5WHb3WWD(s3`qp!Z4H7N zNm^jW2#G$C*K3l=AZZ2N&SP}jj+du=r2UC+GCYT4C-%8n+?EVeWoA-paN-AAg!kYl z&^<+FEd_+KLW5R~w2*}@f#7-IEz$CA2Em7C?=~@ z;5erhjG~&PikV@0DVsF_UQcGUM3KP+bf$?TnmLu_i_#meOYw5Na%J`Wnz*Ju_ivUy{qb&O{PE4b zNc-|z%a<1KRJtk;E39&QJJPqRZioA~eEkPtjTk?HF>L$u>a3y7Qw=RUd>9T(nwklf zQPsv1wotqU?|mmZH%(qk^;+P|SFeSC0QYsZ;1%W-gd3~I2Y#9wRE=)~uh;5nK{ok3 zPRsI@Zfcff2^OA*ta$CL^SGcb_)Xy4W(i?r`MCxWxDM#Q#;x;Bg|-lAu)$3>_%$0` zw+S{NK*aUa*jus{v;suj@U7P}or2jGLJQ%ANQ31uP-lg#uoX#*MB0$+Ytdr0Wcgue zYD6Ob*K|%aDFAEvEO8@nq(y@K@pz1^mV4+eG=dB~BVhCa>=H(gKsGCWfM#$kG{5VV z;Ol@hF{;4fRfd!anlI{1%_&67Ln&%{p$-nkTb|M_s--jLk&2?G0Lql2YC?si=8i1x zdlvUKVrIVDz_EA(6pFers{C|9xu_Ix@$UYha`|TvBgi!qB2)6J$+TI=IELM`&Hqrg 
zpk<*j0$PF4S*?&}K@fsKgq-&I8DYG*Xz&WTsQG zb|6oUOqyxL`H>{2Z6AjKTQK!llyz~3CwX^@cl=}@m6?rAut};pYU3FjXKb0fNtm^- z$IsXzBxUf^<6@52GyadXPbQ(3O`7l$Pq>M*AsGAyyv9GF133963bihsEuZ~u@0XGG zrSbB3rG53nZe(B)zl@^xuBXA3VCC(#jnRZ;ouLpPt`6_xLBGNcWzWx3QzMdvv;5#SHLDJ7KzV>AH26B^j%5k!9vq4d zn#aM^k*1Hx{PA*vjNsyu5{>}}@sT610fb-uJxt~l)sL8-rAbqvYBHl-9)j{Ysy%=U zT`5vFl%XmatW%EC$a@OkloBl@HG2FF$JsQf?#qJ;{C*SMiQ$sY2ru&-Rat=&h3Tg1 zl)LA14(>ZTQ_Mjani*7*dM>7LffFY8RFBfhg4dD@CGlnefxC%gWa`uO2w`ocO6tn_`7y(ubaHTQoA-N7<6jtGXI5^JS_9aV1{T~1M(}MaS zM+rLv$80_vCxaxD8hR(5NyTG+7v2JD$<&joVcNa|?4T`ExJto|=R8>0BE)BlFh@*c z!ys_V$PU(GE9E_st9eo|=T{F0qC9+p^hy9w&W$JGZ5ZRXoH2Nf1?T{u6pl|l<(`VX zCjBiqya(VgU!JeL|LaFv02z;7be?`XzB0agc2oS5{0Din^v9M@N;{o57lVA{_R8&w zw$nM#%pC?%=v&ngigYb&TjAanP*}8B;ov1TZALu0Ph@pZP~rAI+;F4tb_aDko_uf9WH{U9DRq~ z0G3y;@ey=Y5o7!rYWoa@zd(_H$x%G8eB}Var`m^vuE(eUjytz;XXX7}+*@r!-M#;A z?Z>Tq{!mqf@8^MVRf2C-M!`tc^vK?2fOgPem^XkWTkzE%x$E{Zx%Irax45|paI gWfQmyD|_L#rR(MERT1dVgVCikt3%#Ffq5uE@ literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c9187db654069bbaf96d87cfac912ea3770ee72 GIT binary patch literal 1742 zcmb7EO>7%Q6rS~Zf3j(97x9xIDh)q^I#{-<2&A}(sx^X8qzFh4s6tvcp0T~odUrEB zZsKT*EL3hj_*Nf4%BcvILysIea*U%$jrJ4?F1}fcxby(L+0-jHN%ftMbGZJb@w-@(}?7_fUwM$xEI} zAZjLfJw`$52Kq({ApnvhGD&|c*HZI|`W-obqnz<39m=7!4PL`^m=e!zNY~CZqr^&x z?gOU>Nn-9Oa5ugKcVipdp!*;1s(i{vGd*zBe<0st@i*+zOnb1!YecgzQjFNZ4__D_VZ=@ zS(^d{%Oxd`;ZNwqZcN2Gn`IHNHY}f(@ulkpyjs8u1>x0fP>Tcw{6RkPCw!h<70bnI z6jPv08E&=%jv*kR0PbUHopuUBSo3%T2i&TzQ-UqWgERzQi{dqlR~tfM0Ta7bOq?Z+ ze=5XJ85P-Sb1(#hVY~EH8c&)5rGGT@3?ng_&*NE}O8a!i%lIbLa=ONG1hqK1MP$Tp&Jalj1)9u+(yS z$t_*`xYY8h4cqh?y=~LY6B4m`sp;7+H(MT|PKnc&?=j0Undj4@--+hpmIF}?h6#gV z0u4qjhly{D;28W0TNjOAL+7%+rSFyRGJCUc?oR%P+2{I;18I;RTppzROTDkV%LkbY z{nz^&eLi@#&-?kF+x=*)q6_nbrD5)koy~{y!?()AT=`M9vX`y=QW<8K_p{4zEjRZs z8h@y0HrIQ*zqK>B`{KgG)UL5OG#2-b#Ye`KJ>$yH*N4W{edB7kaxj}Y%@hv1H;y0{ zl$Tis9+7%HlX0X!BK9(zu{qct>+eCf+75l6T?A27KllQ+u`Ef_Av)KMe`EEcls%Zq m^j;aufDW_I_rkFP2!hSM09X@PN9yd|FUK7%Q6rS~Ze`2?`i}*=UD-AyiacUV=6-aRrCEXg%4!4Qel3)@H#fqW3Zt98QFi??>03^yGRm9GDa%Jd5F~;Ge^8PH@Br1PK8{H~ zNueGe%Z}cS1%&+AHT1ODt{ z7HI#gI@&tQ&Y15&umpy)?0I$zGnbpJg=wQi39+22kGYF2$7^!TspGrM$4zKn74sU! 
zRqG~ouxM_RI_1`(2HvuG%>{$9+wg3dYA)-W77!SO_T~@B;uSfo~ z1^ii)0tC|`Z@UbCLdSk1DCWs1^LV{xdbEHSR&sbfhZl1qs!^jR@pJgY>?EFudHk%J z4lYql0XAj0-t;+!ga87BkGKum$_b|A@*4KJS>B)oo3;yO@ZAQ*C6kwHLSYdTtC61s zn~eQ4DSpbRs7{lEBUlT|q0f}@xEoM(I;T9!DKKVH8C1 zElNhn@=T2B+Cba`u@{49#)CA=c66#Pf!!sEVbx3+d$uMxf*s5`JAFr$x zipw7t8*aH~86KlIExL7#BNi{#UCZG{!zI)%a@z1*X4(dGJ(~Ag;cPT+z>1|XU?~iM zz^G|6@i!6p{NLbeqv7l5OnP^zv-G3#Q|gD*pUL^X`21dS{z3YM?&Y4;i}x<|;@zdr zm+hox|?0zd#%g6*^blxXsDudbG@a0=FQt%cjx+V75bUN{q$--z54s=z4TSE z$;>|Q^CJ}{GrJc%7rWcHXZK!Syc_!~aru7YNSzh0^nOO{Rgke+xE|~8L5Z3+eV<(bRSz9)_1ggfSpBuW%X%?uSu z9Rb3^i_(IyQUh(x6h_OmhF7i)R15UMK=Tmvp~%BNkiq~G1~6ct?StQFI13c}(w_4V zIlPGSI_*E=;XnU>{`23@`Ocw!=k>ZMNWc8g&4d5iN>TrTFXiMYWnQ1BDe68YPy!vN zM(KWfl<8+^d}rd8ehc}wj@tTdr84`dEeiFcj(!$OEO)7Xr(hLq4`{(IT82q~C|PjS zys_1919riA*WT}Ds0)ITt>gzme$_ldN&uv65||rQIQSA}hZ&u{l2T&H z1RvLJH`A%O*cUY?1PnI#;QMO<$bCwMrB|68a<@Iuo5|dpk#y z(Vn5qY_Het)sj=qT3}=NlS?+gG`Wce&)jd9lYVf=CEd6s$ZnHUCa^Em&{zGenoF6 zo2lKf4Q1Ey)ySp=i(vURQ>Fu!%BjtnN1=LJDm7eQ2kNP{b_CP33Qw;30Gdq&TChE} zKZ3ST%=S5}SzVP8=%*hse< zUCFq?xF{0rAf>ypG-52jv=vJ5%F6!+WQHofLwUUOU9(*aZTC*joy=PQ-ubuAT-&#v z2c8x8o_zTaZuQ$R`Qhs0bA&v?f=hF6U3KqRaqq}od-z%2y`$hhq`42x^sKr43w!Q& zeAAH~TiUbGQ3!Tw!Oqp-v6bMlXKnf5u|n{)7DOchtUh49H#Gm%>{mH@>3HU=g7>KA zJvwuK%^jHeMDoHkqVO?O*+9L%4n=iS%e> zNaO`scO}zGDy@)Wy8y~HDn?O#m6NH;z{*m;rSDTp)y01VTu;iJRAZG;%R^Htqf!y` zk}GBN4^oEB3$M~(iO6+UtxuU+!wRj<@om%^6iuUc<_d1u&XPhU)u^pGP63sMm2WRA z4LrcLcN|OP6x>igj?2n*fME)s;`-9C{#CafRg4TuLrg?PE-oepm7$6PKpEzKl_m7V zMnPi`dmd9JQE$W&uy-|A78_zRcevcR(mh{37aTbPwVhUuL%HGlUJagdw_J}nz^CJi z33L)%%N1>8Hs?Zk>yB4#yutrFhjJ8v=`o|T57Js;d0Nr zpicf8$bZjJf80hnefJ!5j#YNc3cIDizN4}432-~gtK5#Rw+U3z}a3|K3A;%#o` zv{fv(JmD)v&U=r7f#(u0iPc*@oqd^F-Y z^PZL4;kR0cKBw)#i5P5M0?lEpgR>f9*8+xVhiU{?xYQSr{7!^)%RO;q<8%O1wcBzgH#^tCR6;s!f16rclS(oa2 zns(Ieyy@Dg4%Ml;VE^4uJ&)jQJTcD)Lb;sj_*8dLFFS`;e_eugq3cGAdmmWG2nFl* z9{L6i?B&nASp!dS9RNn6Lx*n?lm!a75%Ahrfg%KJEaLx($WS}|A8;f#TL8E-VMQYjfA)iIEizUC!is=3yrl#08cjr=P( zm;hN@D6en+{OtLKBlmjedNXv=Srxv5IUlTj%1wA zL(PllvfJo}+IGq%u{ zom~E?XkomrR}_$+kig(vxRhk<_b$&}UiEM*9xfLE$KBJedD>BClhKfREG!e zBm!vWAnr75H1(l@Gqz(RBnU|4zXk%JqWr)Fw(eXDA1bJm7->oVC_N=ywS&(C=>D zb(I%xq~bB-Ugia5VhR7$Q zD2|UO!EaWE#KzEDXC@|f2l22)S*N>ox=*Js6PH%Pb&6KQ6NH+0U|#<_oP&*{JZfEp z1zaGDYE^Af03i`PK_pc>MUY3iMme1`-J)_Ce%BjX05#08X`ltVI*y}M!GK|N2F57N zK%PUTD;ucV!R>;6pO}3EKw@$xWO~y`Hh?P@)lp~i8I|1tHd(<*q-((>tQwnZ2%IYO zfF5AoDlyw>x@P9)Dxix8ZZlmEYT?+~`RlMHW*la{V>dB*io423M?}H&e{Kx%DG}g> zl2qUTp8f^b&F$wp`nUt!bCxFAM=rP3avQGIVUN|^Y~%EhV?0O72Pzu9n_1W5qW zKk znFJr$#fcQ$b)5yA06;>D=vIst>mDQ079lzSLL-)_G9i6}BKwfkb%Oyc=?oCvZg{#f zo_kzck_>&1R$(ewQBj9W2w)h><(EJJ0kl(2&nmlph25SVEU@hw+n#6J*LLhg)0Jnp zW-N2Ij69cKh%Sn29^d?h*$WG2?_Hd`NX&85<~cPxw&LgV<-g~Fmc^}r0NQf=cYDA6 zVySJp`(Jv0xRnoFDg-WRflC?tTCjCBxO*kI8%$%cT?@9a1`n+S4}tIn`jv&*ga&n! 
zE8_qt;`Gn^XZ>KGkLS7z>^_a%muL5_c{j~}G5ZBTl(RYd+jF@i1@B(XyEnstm2PU* z{B5~ixv`wEOy~XY75wjM{`d0id&I2t0n9oNvz9I(`51`qhyYHDqMel$2!=#QMP(={ z7&~V-K)KY5++@8`(GXMnm@@1X!Y}D^nFx@P{{~3eq{8wt78_fnkhmy&u*h(>qc1#x zxywZx=ssc}9)YAM zrnDQ>V;Z44_MzMOR4P6hw!irdS7ITh2x}AA%m~~M;~qpNCM8_f-@r%afXq56K!F@t=pU7KJkV!^?(M)9*@bw%N{Y>ak4J2&U00H1cbsYi= za}@wsRJ}EX58!eW_~5yfhz(5MP$ZEbm8W))w{;*vHmVQnma)i$ZW%!`j$|w{iLXG$ zBa=i+L``%9b45c-HWt0~H<+epLjwc=S(P7Tro9{RGpem2zi$J6i)wGk z@2}yn02Qz5Q~iSw53e851~DL9v8rft8R7_5)l&x$np8lrf(_zkKx=;9Y7BMk4P-cK zMs3EB)FuHUFzfnj+5qxcAkyL3LE{An{Oy_H zBD_5_To{A78VhVZjGw^_7x3*Rn^3AD8vAo7X>4R1x@wH$j#gouM)NX)D2Qc9SVejl za(&Vbe8(FZ_DDA|j~J#Na1e*{I8;DvTmkGOeU40@VS#;+2WLf+PQXhzAbo)qgGh#u z#E>B3kVcTik<=j?46R7rND$FTXOVn_qz4G-xsh=h9UFpds!0{f2~}&-GS2xeG?F{u zXo4Yrw;|5iv+s8w|HH@MU3_@4z_wvbAQQ`OG0@LPShIBPyGsu*t;0dhS+?NeGW36e zDeqm7vmY%C=lHy*qu}Y#JRKSO`Ohu)ZJHa&9)UaKeYXDt>C+$`7UHEFM_QQ{lMpD?)~f`E!eRd>|63wYJK|*#OfA`kypO6q03DxK+s@(n$X&NF~jp;)7V(%+b z_fs{;d1fCCAbmYqXkOf*1v*|~)lZfBo3IKemv22g_WkJ}oByya|G_Qo*rzy$ll6q% nN5k=7Phdi4v`za_@NpDeVnV^^=x67?zl?&;!7{vwlC$)`wXcKo literal 0 HcmV?d00001 diff --git a/mochi_preview/dit/joint_model/__pycache__/utils.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..514a8912bd0f9bcc3cad1a17156e848e23cdd7b8 GIT binary patch literal 9399 zcmb_iU2GdycAgpj4gW<-vaQ&bNAaI1%XYku?L_MhR*Gy#i6YyvoivKp+%ab)k)lXy zW~fNo5nvQvl!UXD6m%gvEfBhSh+G?}^0W?GAbHHT1y+m%NFBIX1G_H<`k+K^vDlaP zoI5ilMUO4-wpWgabEZO&0g4Ejf1{VcTDZczPh!7e!djuxDvEzw}I zulx$Gr?j)y?SoFieZ$%BWvFwM;JHBwUeS@*QQ>En?%MQHzn_fs!AL)F+-3GxeiO}= zvRNuq*9ol1Sg*vob&!Frqy`xwaKqjoB69~}?z%r??l8=4xo5yV1l;QomPp9~#=D3{AB*ng%u>s_j9FSvu<1%FPlnTeGG6^`^+9c3-jn*mXngTkN-g?|6 zSV7+-B#h9SJDZZnvZ6R8Mr;~Ap*iH?_@DwmF+M5ExC)Nr{@BsiG0@pFFq)1H$D3Nd3-xT0S47DN+*o+z2xQl(HLq`$o;@>v ze(rp+bGhL_sqR3@dvIa@KOVhv^nS~7+v$?`^pgGb1`MXo-k~+uh$v=+_^5QsJkU$HLd(OMY&LrGbT}Dx$@3~Aq zQ(lS!ULZgyM%MgProwO9rrJP|SbiI9vIk^@yO2&NIVC+JrerRj;wHGP98V2#cuYq; zWJKTj%V#>c_e74D)p3x-?!x52N&^~y}^&^kB`j? zv!0T7=c>24-=UNM63;SkUm%~R_!zW7N6Zcz| z!Y7u)c#mOF00!lUXRhY27U_jkxvR_m(m3OGrFQEUf(Lz%yqf${*S5Zp&)+SJ;)g+Y$wNGxNYQM*rHV){^vu{NYnqe~~ z3M5HeQraeH8IZvf9Lg)6v_ zGTePJ6%)Cnm>N=sEB6?Al+VQE&>bIzi%*PaT$#k{A5X#gY`S?!h^c&vCUIqY-aLAa z9P{Q$D@Se;J~^vC#+;Vx76JoXv=m#zv#@R9uNh zr9IH0F{za1?2V^HUeX*&Iyy40+2LnQb3zUiMx3B2#N?oaAc&Ai9F~x)=8i@|W5{we zD&skm4Y*E}gR~#&Q`t6K-Ccj2C4ns%u+pa|==>U;`zahg?~iN@+yy$nDAS7Je> zvV*qrYMns6-GbUBcs7l}O|MB696EsOa8G+wm*q-aPhAI7p2G#SXuxVxWsYfA)u~#z z2&zkUs~$L8?>*mLvRvCEVz~ z2vdPjmk~mZ^U16z1)z^~+@U4(Q3=F)gvw+H9+CiMxk1CHqQpsX3sO1*>NA>&gr#jT zK(iw%)abLC4e%a;OQ>(*Of(ySLkSQ+C83j|t7Y(!2El;=tf09SiBHLyv@B}QZ0Z`I zNeMNXwu|^n03OmXiv$i&Oo^6FjOH?2EX`%`k=lSL;*}2J9BBIqzSHP#=@@czPDByT zX$%+udQ4-{L=gZ&k7)F`jCu%n$w)c?bPcAZ4|z$LYTT^njpma;Blmz{0vy~%`2#a& z^Jiy|-R#NH>rSe%>61{QySVQk=%uDesiAGP;b^Jh=yJococlpTHp$o?f_PL>DUo#oh&cP;=WPY*`x~b;vIXZW24Iu7Y`L_ymq5IRmTYU>}EbTn8>^qsG%N;lS zax^YBu=Mvn~;kc3hy;aL!yO$@Vi2CACh zu--8-nT$pODYWB(q2B+oZo_vY5oV~TO!@{ksMw;0PvR`dB6r{0ej>M4_NlMp6P5ZcWlxQ;s^h zz`9x-sOkjA#Ry*jI$;v%-PWWk3v7j5Y^tklff<#xL@0J}t7sFa!A`LW3ye}(F>o|? 
zYMQQk%NJ@E232%_z>XSh<})JLc_poYn4SI+ z*TuDQt$o}ft~b(N<)*X|0Yz0GsrU}4<_)Rp;NInvaf7ef$bo4?KPa2Z7_$;wafI76 zYkNcw!XP(mYYoixYoJFFD)bO?>3gnK!i!(7T%S!0AZAnX#w=fxm~=4tEwJ(A(meVf zNW{2HN1-he&_XglE=IwkMX$k|qP3IGRUBpbc?hxY{C0bTuj)_D>`?Gz_6 zaHBOA3^rJFDXQ5qn5p^na8i`q^oSd9fHEPyjV${xtCeQMa9Yrh#;L0-Jrf>Yz`KTpEJD9aE{{-eW97&eDAsg%J44X?en|lcCEXx#!|lEx*JO#EP1iy z!;&9M0eGL8pXGnH9>iK5rxz@ zn3QnO{u&+`k;8TBfwythyA#5Ag|4D&A+*rBFt)gV+55_p{T1?FF&I@&rt7+avikpE z|E+O>!`ik3POu)2*uhs`W4Qp>!LogDGdge)bl|&^iVsd+QY4Wdm8Y6X?FkGZ>)l5* z+gNl$vyEUjj@ejr604AnM<cv3+v~P>323SCdYS3WE4W85fEhre( zQ9}{fg2JXcYbb(M6jtcNuLjiM5X65g>#~fa;C@k7!BrEv2zJ$1Mgw&!Kvcm2F($~3 zPcl}B#kq;fRkc=&Po`>}01=Iz6Q_u?mVFe(K^gfP!4;!3AJa`(C^a@lW zA?X9`7{Y8Avp8l5B%~3{l9<&16^wRCU6>(IkvcJZ9kXu8K#Prx%jlI5z*3!6ftQAA zJ!!(1Ujh++8vt{%{!$Hm*>~{sH*UZ2*_$P{1%2{dys$%uG_Pa(!o|KylsK{nUVYmL>BmBs1R8QtppCO23{%!URn$-^7r?C zwg1cgU&O!I^L5MD7ncL)a<+%zrcYW5my0hK01CCPhL4uQM;DpJ&c(6i@T)o31K^j; zCkw}mbn(=3;84!C77Wi!=cfT-H9=&7JFy%*xx}6%HWR}6HLyqe32yUM%*X*pv4m() z`YC4USm~QQh$YNEhDZJovK!RD+gl$xLhH<4$Ek-9SAlTIx&z7pYhY1Wb774&Yi_K0 z%$gT#KC|Y>TEMIYu~uh*qwhmdGe8ZL^flp1HaV#XYKbXBA^cN7jq1E>IJ=i9=^79N zupuvt03+gh4ndZY-VjT&S(-nKMmuVlWl4hme}+d!IY(r(ebu|WMH=8U%47yU?&?t(w1qknW%Y1uDkU9&Hm%;Q9vRH0V#fbzcq(P- zWu!z5B9W8DV72Can?!v`6h&kCtdho0*&2J+U_o#lfn)kc;PPw4;g9TTJ#_S%gysqH zf&44T)@?LR|CW03ca-lxsK(z?dzPp@-#cmgM<4~A^jZ3QDolrGPt2d5JH3v@V)k?O zwn~cM?Eg*suiH1U^>N#eX*x99g&%DmQ&@g{mN`ftpKY9Po@-tWwU$Dyk11$azYNp8 zx_IUOiLXw7dHS2i-|YJJuB9JdDV=zCgMyC7>_Pez%=X&cYpeBbrTR7`Iz@=i(D%=L P)%#`dV;FaaCY%1hr dim_x) + self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device) + + # Query and key normalization for stability. + assert qk_norm + self.q_norm_x = RMSNorm(self.head_dim, device=device) + self.k_norm_x = RMSNorm(self.head_dim, device=device) + self.q_norm_y = RMSNorm(self.head_dim, device=device) + self.k_norm_y = RMSNorm(self.head_dim, device=device) + + # Output layers. y features go back down from dim_x -> dim_y. + self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device) + self.proj_y = ( + nn.Linear(dim_x, dim_y, bias=out_bias, device=device) + if update_y + else nn.Identity() + ) + + def run_qkv_y(self, y): + cp_rank, cp_size = get_cp_rank_size() + local_heads = self.num_heads // cp_size + + if is_cp_active(): + # Only predict local heads. 
+ assert not self.qkv_bias + W_qkv_y = self.qkv_y.weight.view( + 3, self.num_heads, self.head_dim, self.dim_y + ) + W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) + W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) + qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) + else: + qkv_y = self.qkv_y(y) # (B, L, 3 * dim) + + qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) + q_y, k_y, v_y = qkv_y.unbind(2) + return q_y, k_y, v_y + + def prepare_qkv( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, + scale_y: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + valid_token_indices: torch.Tensor, + ): + # Pre-norm for visual features + x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + #print("x in attn", x.dtype, x.device) + + # Process visual features + qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) + assert qkv_x.dtype == torch.bfloat16 + qkv_x = all_to_all_collect_tokens( + qkv_x, self.num_heads + ) # (3, B, N, local_h, head_dim) + + # Process text features + y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) + q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) + #print("y in attn", y.dtype, y.device) + #print(q_y.dtype, q_y.device) + #print(self.q_norm_y.weight.dtype, self.q_norm_y.weight.device) + # self.q_norm_y.weight = self.q_norm_y.weight.to(q_y.dtype) + # self.q_norm_y.bias = self.q_norm_y.bias.to(q_y.dtype) + # self.k_norm_y.weight = self.k_norm_y.weight.to(k_y.dtype) + # self.k_norm_y.bias = self.k_norm_y.bias.to(k_y.dtype) + q_y = self.q_norm_y(q_y) + k_y = self.k_norm_y(k_y) + + # Split qkv_x into q, k, v + q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) + q_x = self.q_norm_x(q_x) + q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) + k_x = self.k_norm_x(k_x) + k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + + # Unite streams + qkv = unify_streams( + q_x, + k_x, + v_x, + q_y, + k_y, + v_y, + valid_token_indices, + ) + + return qkv + + @torch.compiler.disable() + def run_attention( + self, + qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) + *, + B: int, + L: int, + M: int, + cu_seqlens: torch.Tensor, + max_seqlen_in_batch: int, + valid_token_indices: torch.Tensor, + ): + _, cp_size = get_cp_rank_size() + N = cp_size * M + assert self.num_heads % cp_size == 0 + local_heads = self.num_heads // cp_size + local_dim = local_heads * self.head_dim + total = qkv.size(0) + + if FLASH_ATTN_IS_AVAILABLE: + with torch.autocast("cuda", enabled=False): + out: torch.Tensor = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen_in_batch, + dropout_p=0.0, + softmax_scale=self.softmax_scale, + ) # (total, local_heads, head_dim) + out = out.view(total, local_dim) + else: + raise NotImplementedError("Flash attention is currently required.") + print("qkv: ",qkv.shape, qkv.dtype, qkv.device) + expected_size = 2 * 44520 * 3 * 24 * 128 + actual_size = qkv.numel() + print(f"Expected size: {expected_size}, Actual size: {actual_size}") + q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4) + with torch.autocast("cuda", enabled=False): + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim) + + x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) + assert 
x.size() == (B, N, local_dim) + assert y.size() == (B, L, local_dim) + + x = x.view(B, N, local_heads, self.head_dim) + x = all_to_all_collect_heads(x) # (B, M, dim_x = num_heads * head_dim) + x = self.proj_x(x) # (B, M, dim_x) + + if is_cp_active(): + y = all_gather(y) # (cp_size * B, L, local_heads * head_dim) + y = rearrange( + y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim + ) # (B, L, dim_x) + y = self.proj_y(y) # (B, L, dim_y) + return x, y + + def forward( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. + scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. + packed_indices: Dict[str, torch.Tensor] = None, + **rope_rotation, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of asymmetric multi-modal attention. + + Args: + x: (B, N, dim_x) tensor for visual tokens + y: (B, L, dim_y) tensor of text token features + packed_indices: Dict with keys for Flash Attention + num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + + Returns: + x: (B, N, dim_x) tensor of visual tokens after multi-modal attention + y: (B, L, dim_y) tensor of text token features after multi-modal attention + """ + B, L, _ = y.shape + _, M, _ = x.shape + + # Predict a packed QKV tensor from visual and text features. + # Don't checkpoint the all_to_all. + qkv = self.prepare_qkv( + x=x, + y=y, + scale_x=scale_x, + scale_y=scale_y, + rope_cos=rope_rotation.get("rope_cos"), + rope_sin=rope_rotation.get("rope_sin"), + valid_token_indices=packed_indices["valid_token_indices_kv"], + ) # (total <= B * (N + L), 3, local_heads, head_dim) + + x, y = self.run_attention( + qkv, + B=B, + L=L, + M=M, + cu_seqlens=packed_indices["cu_seqlens_kv"], + max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], + valid_token_indices=packed_indices["valid_token_indices_kv"], + ) + return x, y + +#@torch.compile(disable=not COMPILE_MMDIT_BLOCK) +class AsymmetricJointBlock(nn.Module): + def __init__( + self, + hidden_size_x: int, + hidden_size_y: int, + num_heads: int, + *, + mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens. + mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens. + update_y: bool = True, # Whether to update text tokens in this block. + device: Optional[torch.device] = None, + **block_kwargs, + ): + super().__init__() + self.update_y = update_y + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device) + if self.update_y: + self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device) + else: + self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device) + + # Self-attention: + self.attn = AsymmetricAttention( + hidden_size_x, + hidden_size_y, + num_heads=num_heads, + update_y=update_y, + device=device, + **block_kwargs, + ) + + # MLP. + mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x) + assert mlp_hidden_dim_x == int(1536 * 8) + self.mlp_x = FeedForward( + in_features=hidden_size_x, + hidden_size=mlp_hidden_dim_x, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + ) + + # MLP for text not needed in last block. 
+ if self.update_y: + mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y) + self.mlp_y = FeedForward( + in_features=hidden_size_y, + hidden_size=mlp_hidden_dim_y, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + ) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, + y: torch.Tensor, + **attn_kwargs, + ): + """Forward pass of a block. + + Args: + x: (B, N, dim) tensor of visual tokens + c: (B, dim) tensor of conditioned features + y: (B, L, dim) tensor of text tokens + num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + + Returns: + x: (B, N, dim) tensor of visual tokens after block + y: (B, L, dim) tensor of text tokens after block + """ + N = x.size(1) + + c = F.silu(c) + mod_x = self.mod_x(c) + scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1) + + mod_y = self.mod_y(c) + if self.update_y: + scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1) + else: + scale_msa_y = mod_y + + # Self-attention block. + x_attn, y_attn = self.attn( + x, + y, + scale_x=scale_msa_x, + scale_y=scale_msa_y, + **attn_kwargs, + ) + + assert x_attn.size(1) == N + x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x) + if self.update_y: + y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y) + + # MLP block. + x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x) + if self.update_y: + y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y) + + return x, y + + def ff_block_x(self, x, scale_x, gate_x): + x_mod = modulated_rmsnorm(x, scale_x) + x_res = self.mlp_x(x_mod) + x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm + return x + + def ff_block_y(self, y, scale_y, gate_y): + y_mod = modulated_rmsnorm(y, scale_y) + y_res = self.mlp_y(y_mod) + y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm + return y + + +#@torch.compile(disable=not COMPILE_FINAL_LAYER) +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__( + self, + hidden_size, + patch_size, + out_channels, + device: Optional[torch.device] = None, + ): + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, device=device + ) + self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, device=device + ) + + def forward(self, x, c): + c = F.silu(c) + shift, scale = self.mod(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class AsymmDiTJoint(nn.Module): + """ + Diffusion model with a Transformer backbone. + + Ingests text embeddings instead of a label. + """ + + def __init__( + self, + *, + patch_size=2, + in_channels=4, + hidden_size_x=1152, + hidden_size_y=1152, + depth=48, + num_heads=16, + mlp_ratio_x=8.0, + mlp_ratio_y=4.0, + t5_feat_dim: int = 4096, + t5_token_length: int = 256, + patch_embed_bias: bool = True, + timestep_mlp_bias: bool = True, + timestep_scale: Optional[float] = None, + use_extended_posenc: bool = False, + rope_theta: float = 10000.0, + device: Optional[torch.device] = None, + **block_kwargs, + ): + super().__init__() + + + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.head_dim = ( + hidden_size_x // num_heads + ) # Head dimension and count is determined by visual. 
+ self.use_extended_posenc = use_extended_posenc + self.t5_token_length = t5_token_length + self.t5_feat_dim = t5_feat_dim + self.rope_theta = ( + rope_theta # Scaling factor for frequency computation for temporal RoPE. + ) + + self.x_embedder = PatchEmbed( + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size_x, + bias=patch_embed_bias, + device=device, + ) + # Conditionings + # Timestep + self.t_embedder = TimestepEmbedder( + hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale + ) + + # Caption Pooling (T5) + self.t5_y_embedder = AttentionPool( + t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device + ) + + # Dense Embedding Projection (T5) + self.t5_yproj = nn.Linear( + t5_feat_dim, hidden_size_y, bias=True, device=device + ) + + # Initialize pos_frequencies as an empty parameter. + self.pos_frequencies = nn.Parameter( + torch.empty(3, self.num_heads, self.head_dim // 2, device=device) + ) + + # for depth 48: + # b = 0: AsymmetricJointBlock, update_y=True + # b = 1: AsymmetricJointBlock, update_y=True + # ... + # b = 46: AsymmetricJointBlock, update_y=True + # b = 47: AsymmetricJointBlock, update_y=False. No need to update text features. + blocks = [] + for b in range(depth): + # Joint multi-modal block + update_y = b < depth - 1 + block = AsymmetricJointBlock( + hidden_size_x, + hidden_size_y, + num_heads, + mlp_ratio_x=mlp_ratio_x, + mlp_ratio_y=mlp_ratio_y, + update_y=update_y, + device=device, + **block_kwargs, + ) + + blocks.append(block) + self.blocks = nn.ModuleList(blocks) + + self.final_layer = FinalLayer( + hidden_size_x, patch_size, self.out_channels, device=device + ) + + def embed_x(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, C=12, T, H, W) tensor of visual tokens + + Returns: + x: (B, C=3072, N) tensor of visual tokens with positional embedding. + """ + return self.x_embedder(x) # Convert BcTHW to BCN + + #@torch.compile(disable=not COMPILE_MMDIT_BLOCK) + def prepare( + self, + x: torch.Tensor, + sigma: torch.Tensor, + t5_feat: torch.Tensor, + t5_mask: torch.Tensor, + ): + """Prepare input and conditioning embeddings.""" + #("X", x.shape) + with torch.profiler.record_function("x_emb_pe"): + # Visual patch embeddings with positional encoding. + T, H, W = x.shape[-3:] + pH, pW = H // self.patch_size, W // self.patch_size + x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2 + assert x.ndim == 3 + B = x.size(0) + + with torch.profiler.record_function("rope_cis"): + # Construct position array of size [N, 3]. + # pos[:, 0] is the frame index for each location, + # pos[:, 1] is the row index for each location, and + # pos[:, 2] is the column index for each location. + pH, pW = H // self.patch_size, W // self.patch_size + N = T * pH * pW + assert x.size(1) == N + pos = create_position_matrix( + T, pH=pH, pW=pW, device=x.device, dtype=torch.float32 + ) # (N, 3) + rope_cos, rope_sin = compute_mixed_rotation( + freqs=self.pos_frequencies, pos=pos + ) # Each are (N, num_heads, dim // 2) + + with torch.profiler.record_function("t_emb"): + # Global vector embedding for conditionings. + c_t = self.t_embedder(1 - sigma) # (B, D) + + with torch.profiler.record_function("t5_pool"): + # Pool T5 tokens using attention pooler + # Note y_feat[1] contains T5 token features. + # print("B", B) + # print("t5 feat shape",t5_feat.shape) + # print("t5 mask shape", t5_mask.shape) + assert ( + t5_feat.size(1) == self.t5_token_length + ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat." 
+            t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)
+            assert (
+                t5_y_pool.size(0) == B
+            ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."
+
+        c = c_t + t5_y_pool
+
+        y_feat = self.t5_yproj(t5_feat)  # (B, L, t5_feat_dim) --> (B, L, D)
+
+        return x, c, y_feat, rope_cos, rope_sin
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
+        y_feat: List[torch.Tensor],
+        y_mask: List[torch.Tensor],
+        packed_indices: Dict[str, torch.Tensor] = None,
+        rope_cos: torch.Tensor = None,
+        rope_sin: torch.Tensor = None,
+    ):
+        """Forward pass of DiT.
+
+        Args:
+            x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
+            sigma: (B,) tensor of noise standard deviations
+            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For the T5 encoder used here: L=256, y_feat_dim=4096)
+            y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
+            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
+        """
+        B, _, T, H, W = x.shape
+
+        # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
+        # Have to call sdpa_kernel outside of a torch.compile region.
+        with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+            x, c, y_feat, rope_cos, rope_sin = self.prepare(
+                x, sigma, y_feat[0], y_mask[0]
+            )
+        del y_mask
+
+        cp_rank, cp_size = get_cp_rank_size()
+        N = x.size(1)
+        M = N // cp_size
+        assert (
+            N % cp_size == 0
+        ), f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."
+
+        if cp_size > 1:
+            x = x.narrow(1, cp_rank * M, M)
+
+            assert self.num_heads % cp_size == 0
+            local_heads = self.num_heads // cp_size
+            rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
+            rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)
+
+        for i, block in enumerate(self.blocks):
+            x, y_feat = block(
+                x,
+                c,
+                y_feat,
+                rope_cos=rope_cos,
+                rope_sin=rope_sin,
+                packed_indices=packed_indices,
+            )  # (B, M, D), (B, L, D)
+        del y_feat  # Final layers don't use dense text features.
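+        # Under context parallelism each rank now holds only M = N // cp_size
+        # visual tokens, so the final projection below runs on the local shard;
+        # all_gather then reassembles the full sequence before the patches are
+        # folded back into a (B, C, T, H, W) tensor.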
+ + x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels) + + patch = x.size(2) + x = all_gather(x) + x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch) + x = rearrange( + x, + "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)", + T=T, + hp=H // self.patch_size, + wp=W // self.patch_size, + p1=self.patch_size, + p2=self.patch_size, + c=self.out_channels, + ) + + return x diff --git a/mochi_preview/dit/joint_model/context_parallel.py b/mochi_preview/dit/joint_model/context_parallel.py new file mode 100644 index 0000000..d93145d --- /dev/null +++ b/mochi_preview/dit/joint_model/context_parallel.py @@ -0,0 +1,163 @@ +import torch +import torch.distributed as dist +from einops import rearrange + +_CONTEXT_PARALLEL_GROUP = None +_CONTEXT_PARALLEL_RANK = None +_CONTEXT_PARALLEL_GROUP_SIZE = None +_CONTEXT_PARALLEL_GROUP_RANKS = None + + +def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + return x + + cp_rank, cp_size = get_cp_rank_size() + return x.tensor_split(cp_size, dim=dim)[cp_rank] + + +def set_cp_group(cp_group, ranks, global_rank): + global \ + _CONTEXT_PARALLEL_GROUP, \ + _CONTEXT_PARALLEL_RANK, \ + _CONTEXT_PARALLEL_GROUP_SIZE, \ + _CONTEXT_PARALLEL_GROUP_RANKS + if _CONTEXT_PARALLEL_GROUP is not None: + raise RuntimeError("CP group already initialized.") + _CONTEXT_PARALLEL_GROUP = cp_group + _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group) + _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group) + _CONTEXT_PARALLEL_GROUP_RANKS = ranks + + assert ( + _CONTEXT_PARALLEL_RANK == ranks.index(global_rank) + ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} " + assert _CONTEXT_PARALLEL_GROUP_SIZE == len( + ranks + ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})" + + +def get_cp_group(): + if _CONTEXT_PARALLEL_GROUP is None: + raise RuntimeError("CP group not initialized") + return _CONTEXT_PARALLEL_GROUP + + +def is_cp_active(): + return _CONTEXT_PARALLEL_GROUP is not None + + +def get_cp_rank_size(): + if _CONTEXT_PARALLEL_GROUP: + return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE + else: + return 0, 1 + + +class AllGatherIntoTensorFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup): + ctx.reduce_dtype = reduce_dtype + ctx.group = group + ctx.batch_size = x.size(0) + group_size = dist.get_world_size(group) + + x = x.contiguous() + output = torch.empty( + group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device + ) + dist.all_gather_into_tensor(output, x, group=group) + return output + + +def all_gather(tensor: torch.Tensor) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + return tensor + + return AllGatherIntoTensorFunction.apply( + tensor, torch.float32, _CONTEXT_PARALLEL_GROUP + ) + + +@torch.compiler.disable() +def _all_to_all_single(output, input, group): + # Disable compilation since torch compile changes contiguity. + assert input.is_contiguous(), "Input tensor must be contiguous." + assert output.is_contiguous(), "Output tensor must be contiguous." + return dist.all_to_all_single(output, input, group=group) + + +class CollectTokens(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int): + """Redistribute heads and receive tokens. + + Args: + qkv: query, key or value. 
Shape: [B, M, 3 * num_heads * head_dim] + + Returns: + qkv: shape: [3, B, N, local_heads, head_dim] + + where M is the number of local tokens, + N = cp_size * M is the number of global tokens, + local_heads = num_heads // cp_size is the number of local heads. + """ + ctx.group = group + ctx.num_heads = num_heads + cp_size = dist.get_world_size(group) + assert num_heads % cp_size == 0 + ctx.local_heads = num_heads // cp_size + + qkv = rearrange( + qkv, + "B M (qkv G h d) -> G M h B (qkv d)", + qkv=3, + G=cp_size, + h=ctx.local_heads, + ).contiguous() + + output_chunks = torch.empty_like(qkv) + _all_to_all_single(output_chunks, qkv, group=group) + + return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3) + + +def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + # Move QKV dimension to the front. + # B M (3 H d) -> 3 B M H d + B, M, _ = x.size() + x = x.view(B, M, 3, num_heads, -1) + return x.permute(2, 0, 1, 3, 4) + + return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads) + + +class CollectHeads(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup): + """Redistribute tokens and receive heads. + + Args: + x: Output of attention. Shape: [B, N, local_heads, head_dim] + + Returns: + Shape: [B, M, num_heads * head_dim] + """ + ctx.group = group + ctx.local_heads = x.size(2) + ctx.head_dim = x.size(3) + group_size = dist.get_world_size(group) + x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous() + output = torch.empty_like(x) + _all_to_all_single(output, x, group=group) + del x + return rearrange(output, "G h M B D -> B M (G h D)") + + +def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + # Merge heads. 
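+        # Single-GPU path: fold the heads back in,
+        # (B, N, local_heads, head_dim) -> (B, N, local_heads * head_dim).
+        # This mirrors the reshape-only fast path in all_to_all_collect_tokens,
+        # so no collective communication is issued without a CP group.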
+ return x.view(x.size(0), x.size(1), x.size(2) * x.size(3)) + + return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP) diff --git a/mochi_preview/dit/joint_model/layers.py b/mochi_preview/dit/joint_model/layers.py new file mode 100644 index 0000000..aa40a67 --- /dev/null +++ b/mochi_preview/dit/joint_model/layers.py @@ -0,0 +1,178 @@ +import collections.abc +import math +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +class TimestepEmbedder(nn.Module): + def __init__( + self, + hidden_size: int, + frequency_embedding_size: int = 256, + *, + bias: bool = True, + timestep_scale: Optional[float] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=bias, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.timestep_scale = timestep_scale + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) + freqs.mul_(-math.log(max_period) / half).exp_() + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + if self.timestep_scale is not None: + t = t * self.timestep_scale + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class PooledCaptionEmbedder(nn.Module): + def __init__( + self, + caption_feature_dim: int, + hidden_size: int, + *, + bias: bool = True, + device: Optional[torch.device] = None, + ): + super().__init__() + self.caption_feature_dim = caption_feature_dim + self.hidden_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=bias, device=device), + ) + + def forward(self, x): + return self.mlp(x) + + +class FeedForward(nn.Module): + def __init__( + self, + in_features: int, + hidden_size: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + device: Optional[torch.device] = None, + ): + super().__init__() + # keep parameter count and computation constant compared to standard FFN + hidden_size = int(2 * hidden_size / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_size = int(ffn_dim_multiplier * hidden_size) + hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of) + + self.hidden_dim = hidden_size + self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device) + self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device) + + def forward(self, x): + x, gate = self.w1(x).chunk(2, dim=-1) + x = self.w2(F.silu(x) * gate) + return x + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + bias: bool = True, + 
dynamic_img_pad: bool = False, + device: Optional[torch.device] = None, + ): + super().__init__() + self.patch_size = to_2tuple(patch_size) + self.flatten = flatten + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + device=device, + ) + assert norm_layer is None + self.norm = ( + norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() + ) + + def forward(self, x): + B, _C, T, H, W = x.shape + if not self.dynamic_img_pad: + assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." + else: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T) + #print("x",x.dtype, x.device) + #print(self.proj.weight.dtype, self.proj.weight.device) + x = self.proj(x) + + # Flatten temporal and spatial dimensions. + if not self.flatten: + raise NotImplementedError("Must flatten output.") + x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T) + + x = self.norm(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device)) + self.register_parameter("bias", None) + + def forward(self, x): + x_fp32 = x.float() + x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + return (x_normed * self.weight).type_as(x) diff --git a/mochi_preview/dit/joint_model/mod_rmsnorm.py b/mochi_preview/dit/joint_model/mod_rmsnorm.py new file mode 100644 index 0000000..ffbb4c8 --- /dev/null +++ b/mochi_preview/dit/joint_model/mod_rmsnorm.py @@ -0,0 +1,23 @@ +import torch + + +class ModulatedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale, eps=1e-6): + # Convert to fp32 for precision + x_fp32 = x.float() + scale_fp32 = scale.float() + + # Compute RMS + mean_square = x_fp32.pow(2).mean(-1, keepdim=True) + inv_rms = torch.rsqrt(mean_square + eps) + + # Normalize and modulate + x_normed = x_fp32 * inv_rms + x_modulated = x_normed * (1 + scale_fp32.unsqueeze(1)) + + return x_modulated.type_as(x) + + +def modulated_rmsnorm(x, scale, eps=1e-6): + return ModulatedRMSNorm.apply(x, scale, eps) diff --git a/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py b/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py new file mode 100644 index 0000000..0bb96e2 --- /dev/null +++ b/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py @@ -0,0 +1,27 @@ +import torch + + +class ResidualTanhGatedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, x_res, gate, eps=1e-6): + # Convert to fp32 for precision + x_res_fp32 = x_res.float() + + # Compute RMS + mean_square = x_res_fp32.pow(2).mean(-1, keepdim=True) + scale = torch.rsqrt(mean_square + eps) + + # Apply tanh to gate + tanh_gate = torch.tanh(gate).unsqueeze(1) + + # Normalize and apply gated scaling + x_normed = x_res_fp32 * scale * tanh_gate + + # Apply residual connection + output = x + x_normed.type_as(x) + + return output + + +def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6): + return ResidualTanhGatedRMSNorm.apply(x, x_res, gate, eps) diff --git 
a/mochi_preview/dit/joint_model/rope_mixed.py b/mochi_preview/dit/joint_model/rope_mixed.py new file mode 100644 index 0000000..f2952bd --- /dev/null +++ b/mochi_preview/dit/joint_model/rope_mixed.py @@ -0,0 +1,88 @@ +import functools +import math + +import torch + + +def centers(start: float, stop, num, dtype=None, device=None): + """linspace through bin centers. + + Args: + start (float): Start of the range. + stop (float): End of the range. + num (int): Number of points. + dtype (torch.dtype): Data type of the points. + device (torch.device): Device of the points. + + Returns: + centers (Tensor): Centers of the bins. Shape: (num,). + """ + edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) + return (edges[:-1] + edges[1:]) / 2 + + +@functools.lru_cache(maxsize=1) +def create_position_matrix( + T: int, + pH: int, + pW: int, + device: torch.device, + dtype: torch.dtype, + *, + target_area: float = 36864, +): + """ + Args: + T: int - Temporal dimension + pH: int - Height dimension after patchify + pW: int - Width dimension after patchify + + Returns: + pos: [T * pH * pW, 3] - position matrix + """ + with torch.no_grad(): + # Create 1D tensors for each dimension + t = torch.arange(T, dtype=dtype) + + # Positionally interpolate to area 36864. + # (3072x3072 frame with 16x16 patches = 192x192 latents). + # This automatically scales rope positions when the resolution changes. + # We use a large target area so the model is more sensitive + # to changes in the learned pos_frequencies matrix. + scale = math.sqrt(target_area / (pW * pH)) + w = centers(-pW * scale / 2, pW * scale / 2, pW) + h = centers(-pH * scale / 2, pH * scale / 2, pH) + + # Use meshgrid to create 3D grids + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + # Stack and reshape the grids. + pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3] + pos = pos.view(-1, 3) # [T * pH * pW, 3] + pos = pos.to(dtype=dtype, device=device) + + return pos + + +def compute_mixed_rotation( + freqs: torch.Tensor, + pos: torch.Tensor, +): + """ + Project each 3-dim position into per-head, per-head-dim 1D frequencies. + + Args: + freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position + pos: [N, 3] - position of each token + num_heads: int + + Returns: + freqs_cos: [N, num_heads, num_freqs] - cosine components + freqs_sin: [N, num_heads, num_freqs] - sine components + """ + with torch.autocast("cuda", enabled=False): + assert freqs.ndim == 3 + freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs) + freqs_cos = torch.cos(freqs_sum) + freqs_sin = torch.sin(freqs_sum) + return freqs_cos, freqs_sin diff --git a/mochi_preview/dit/joint_model/temporal_rope.py b/mochi_preview/dit/joint_model/temporal_rope.py new file mode 100644 index 0000000..a8276db --- /dev/null +++ b/mochi_preview/dit/joint_model/temporal_rope.py @@ -0,0 +1,34 @@ +# Based on Llama3 Implementation. +import torch + + +def apply_rotary_emb_qk_real( + xqk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers. + + Args: + xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D) + Can be either just query or just key, or both stacked along some batch or * dim. + freqs_cos (torch.Tensor): Precomputed cosine frequency tensor. + freqs_sin (torch.Tensor): Precomputed sine frequency tensor. 
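+            Together these implement the usual complex rotation on (even, odd)
+            channel pairs, (x_e + i*x_o) * (cos + i*sin), in real arithmetic:
+            out_even = x_e*cos - x_o*sin and out_odd = x_e*sin + x_o*cos.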
+ + Returns: + torch.Tensor: The input tensor with rotary embeddings applied. + """ + assert xqk.dtype == torch.bfloat16 + # Split the last dimension into even and odd parts + xqk_even = xqk[..., 0::2] + xqk_odd = xqk[..., 1::2] + + # Apply rotation + cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk) + sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk) + + # Interleave the results back into the original shape + out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2) + assert out.dtype == torch.bfloat16 + return out diff --git a/mochi_preview/dit/joint_model/utils.py b/mochi_preview/dit/joint_model/utils.py new file mode 100644 index 0000000..502e3ec --- /dev/null +++ b/mochi_preview/dit/joint_model/utils.py @@ -0,0 +1,189 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: + """ + Pool tokens in x using mask. + + NOTE: We assume x does not require gradients. + + Args: + x: (B, L, D) tensor of tokens. + mask: (B, L) boolean tensor indicating which tokens are not padding. + + Returns: + pooled: (B, D) tensor of pooled tokens. + """ + assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. + assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. + mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) + return pooled + + +class AttentionPool(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + output_dim: int = None, + device: Optional[torch.device] = None, + ): + """ + Args: + spatial_dim (int): Number of tokens in sequence length. + embed_dim (int): Dimensionality of input tokens. + num_heads (int): Number of attention heads. + output_dim (int): Dimensionality of output tokens. Defaults to embed_dim. + """ + super().__init__() + self.num_heads = num_heads + self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device) + self.to_q = nn.Linear(embed_dim, embed_dim, device=device) + self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device) + + def forward(self, x, mask): + """ + Args: + x (torch.Tensor): (B, L, D) tensor of input tokens. + mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding. + + NOTE: We assume x does not require gradients. + + Returns: + x (torch.Tensor): (B, D) tensor of pooled tokens. + """ + D = x.size(2) + + # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L). + attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L). + attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L). + + # Average non-padding token features. These will be used as the query. + x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D) + + # Concat pooled features to input sequence. + x = torch.cat([x_pool, x], dim=1) # (B, L+1, D) + + # Compute queries, keys, values. Only the mean token is used to create a query. + kv = self.to_kv(x) # (B, L+1, 2 * D) + q = self.to_q(x[:, 0]) # (B, D) + + # Extract heads. 
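+        # Only the prepended mean token forms a query, so attention reduces the
+        # (1+L)-token sequence to a single pooled vector per head; padding keys
+        # are excluded through attn_mask.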
+ head_dim = D // self.num_heads + kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim) + kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim) + k, v = kv.unbind(2) # (B, H, 1+L, head_dim) + q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim) + q = q.unsqueeze(2) # (B, H, 1, head_dim) + + # Compute attention. + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0.0 + ) # (B, H, 1, head_dim) + + # Concatenate heads and run output. + x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) + x = self.to_out(x) + return x + + +class PadSplitXY(torch.autograd.Function): + """ + Merge heads, pad and extract visual and text tokens, + and split along the sequence length. + """ + + @staticmethod + def forward( + ctx, + xy: torch.Tensor, + indices: torch.Tensor, + B: int, + N: int, + L: int, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim). + indices: Valid token indices out of unpacked tensor. Shape: (total,) + + Returns: + x: Visual tokens. Shape: (B, N, num_heads * head_dim). + y: Text tokens. Shape: (B, L, num_heads * head_dim). + """ + ctx.save_for_backward(indices) + ctx.B, ctx.N, ctx.L = B, N, L + D = xy.size(1) + + # Pad sequences to (B, N + L, dim). + assert indices.ndim == 1 + output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype) + indices = indices.unsqueeze(1).expand( + -1, D + ) # (total,) -> (total, num_heads * head_dim) + output.scatter_(0, indices, xy) + xy = output.view(B, N + L, D) + + # Split visual and text tokens along the sequence length. + return torch.tensor_split(xy, (N,), dim=1) + + +def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]: + return PadSplitXY.apply(xy, indices, B, N, L, dtype) + + +class UnifyStreams(torch.autograd.Function): + """Unify visual and text streams.""" + + @staticmethod + def forward( + ctx, + q_x: torch.Tensor, + k_x: torch.Tensor, + v_x: torch.Tensor, + q_y: torch.Tensor, + k_y: torch.Tensor, + v_y: torch.Tensor, + indices: torch.Tensor, + ): + """ + Args: + q_x: (B, N, num_heads, head_dim) + k_x: (B, N, num_heads, head_dim) + v_x: (B, N, num_heads, head_dim) + q_y: (B, L, num_heads, head_dim) + k_y: (B, L, num_heads, head_dim) + v_y: (B, L, num_heads, head_dim) + indices: (total <= B * (N + L)) + + Returns: + qkv: (total <= B * (N + L), 3, num_heads, head_dim) + """ + ctx.save_for_backward(indices) + B, N, num_heads, head_dim = q_x.size() + ctx.B, ctx.N, ctx.L = B, N, q_y.size(1) + D = num_heads * head_dim + + q = torch.cat([q_x, q_y], dim=1) + k = torch.cat([k_x, k_y], dim=1) + v = torch.cat([v_x, v_y], dim=1) + qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D) + + indices = indices[:, None, None].expand(-1, 3, D) + qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim) + return qkv.unflatten(2, (num_heads, head_dim)) + + +def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor: + return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices) diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py new file mode 100644 index 0000000..066998f --- /dev/null +++ b/mochi_preview/t2v_synth_mochi.py @@ -0,0 +1,445 @@ +import json +import os +import random +from functools import partial +from typing import Dict, List + +from safetensors.torch import load_file +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import 
torch.nn.functional as F +import torch.utils.data +import yaml +from einops import rearrange, repeat +from omegaconf import OmegaConf +from torch import nn +from torch.distributed.fsdp import ( + BackwardPrefetch, + MixedPrecision, + ShardingStrategy, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) +from torch.distributed.fsdp.wrap import ( + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, +) +from transformers import T5EncoderModel, T5Tokenizer +from transformers.models.t5.modeling_t5 import T5Block + +from .dit.joint_model.context_parallel import get_cp_rank_size +from .utils import Timer +from tqdm import tqdm +from comfy.utils import ProgressBar + +T5_MODEL = "weights/T5" +MAX_T5_TOKEN_LENGTH = 256 + +class T5_Tokenizer: + """Wrapper around Hugging Face tokenizer for T5 + + Args: + model_name(str): Name of tokenizer to load. + """ + + def __init__(self): + self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False) + + def __call__(self, prompt, padding, truncation, return_tensors, max_length=None): + """ + Args: + prompt (str): The input text to tokenize. + padding (str): The padding strategy. + truncation (bool): Flag indicating whether to truncate the tokens. + return_tensors (str): Flag indicating whether to return tensors. + max_length (int): The max length of the tokens. + """ + assert ( + not max_length or max_length == MAX_T5_TOKEN_LENGTH + ), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5." + + tokenized_output = self.tokenizer( + prompt, + padding=padding, + max_length=MAX_T5_TOKEN_LENGTH, # Max token length for T5 is set here. + truncation=truncation, + return_tensors=return_tensors, + return_attention_mask=True, + ) + + return tokenized_output + + +def unnormalize_latents( + z: torch.Tensor, + mean: torch.Tensor, + std: torch.Tensor, +) -> torch.Tensor: + """Unnormalize latents. Useful for decoding DiT samples. + + Args: + z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float + + Returns: + torch.Tensor: [B, C_z, T_z, H_z, W_z], float + """ + mean = mean[:, None, None, None] + std = std[:, None, None, None] + + assert z.ndim == 5 + assert z.size(1) == mean.size(0) == std.size(0) + return z * std.to(z) + mean.to(z) + + +def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP: + model = FSDP( + model, + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ), + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + device_id=device_id, + sync_module_states=True, + use_orig_params=True, + ) + torch.cuda.synchronize() + return model + + +def compute_packed_indices( + N: int, + text_mask: List[torch.Tensor], +) -> Dict[str, torch.Tensor]: + """ + Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80 + + Args: + N: Number of visual tokens. + text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding. + + Returns: + packed_indices: Dict with keys for Flash Attention: + - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding) + in the packed sequence. + - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence. + - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch. 
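+
+    Illustrative example (shapes chosen only for exposition): with N=2 visual
+    tokens and a single text mask [[True, False, False]] (B=1, L=3), the
+    padded mask is [[True, True, True, False, False]], giving
+    valid_token_indices_kv = [0, 1, 2], cu_seqlens_kv = [0, 3] and
+    max_seqlen_in_batch_kv = 3.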
+ """ + # Create an expanded token mask saying which tokens are valid across both visual and text tokens. + assert N > 0 and len(text_mask) == 1 + text_mask = text_mask[0] + + mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L) + seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,) + valid_token_indices = torch.nonzero( + mask.flatten(), as_tuple=False + ).flatten() # up to (B * (N + L),) + + assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,) + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + max_seqlen_in_batch = seqlens_in_batch.max().item() + + return { + "cu_seqlens_kv": cu_seqlens, + "max_seqlen_in_batch_kv": max_seqlen_in_batch, + "valid_token_indices_kv": valid_token_indices, + } + + +def shift_sigma( + sigma: np.ndarray, + shift: float, +): + """Shift noise standard deviation toward higher values. + + Useful for training a model at high resolutions, + or sampling more finely at high noise levels. + + Equivalent to: + sigma_shift = shift / (shift + 1 / sigma - 1) + except for sigma = 0. + + Args: + sigma: noise standard deviation in [0, 1] + shift: shift factor >= 1. + For shift > 1, shifts sigma to higher values. + For shift = 1, identity function. + """ + return shift * sigma / (shift * sigma + 1 - sigma) + + +class T2VSynthMochiModel: + def __init__( + self, + *, + device_id: int, + vae_stats_path: str, + dit_checkpoint_path: str, + ): + super().__init__() + t = Timer() + self.device = torch.device(device_id) + + #self.t5_tokenizer = T5_Tokenizer() + + # with t("load_text_encs"): + # t5_enc = T5EncoderModel.from_pretrained(T5_MODEL) + # self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu") + + with t("construct_dit"): + from .dit.joint_model.asymm_models_joint import ( + AsymmDiTJoint, + ) + model: nn.Module = torch.nn.utils.skip_init( + AsymmDiTJoint, + depth=48, + patch_size=2, + num_heads=24, + hidden_size_x=3072, + hidden_size_y=1536, + mlp_ratio_x=4.0, + mlp_ratio_y=4.0, + in_channels=12, + qk_norm=True, + qkv_bias=False, + out_bias=True, + patch_embed_bias=True, + timestep_mlp_bias=True, + timestep_scale=1000.0, + t5_feat_dim=4096, + t5_token_length=256, + rope_theta=10000.0, + ) + with t("dit_load_checkpoint"): + + model.load_state_dict(load_file(dit_checkpoint_path)) + + with t("fsdp_dit"): + self.dit = model + self.dit.eval() + for name, param in self.dit.named_parameters(): + params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} + if not any(keyword in name for keyword in params_to_keep): + param.data = param.data.to(torch.float8_e4m3fn) + else: + param.data = param.data.to(torch.bfloat16) + + + vae_stats = json.load(open(vae_stats_path)) + self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device) + self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device) + + t.print_stats() + + def get_conditioning(self, prompts, *, zero_last_n_prompts: int): + B = len(prompts) + print(f"Getting conditioning for {B} prompts") + assert ( + 0 <= zero_last_n_prompts <= B + ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}" + tokenize_kwargs = dict( + prompt=prompts, + padding="max_length", + return_tensors="pt", + truncation=True, + ) + + t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH) + caption_input_ids_t5 = t5_toks["input_ids"] + caption_attention_mask_t5 = t5_toks["attention_mask"].bool() + del t5_toks + + assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH) + assert 
caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
+
+        if zero_last_n_prompts > 0:
+            # Zero the last N prompts
+            caption_input_ids_t5[-zero_last_n_prompts:] = 0
+            caption_attention_mask_t5[-zero_last_n_prompts:] = False
+
+        caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
+        caption_attention_mask_t5 = caption_attention_mask_t5.to(
+            self.device, non_blocking=True
+        )
+
+        y_mask = [caption_attention_mask_t5]
+        y_feat = []
+
+        self.t5_enc.to(self.device)
+        y_feat.append(
+            self.t5_enc(
+                caption_input_ids_t5, caption_attention_mask_t5
+            ).last_hidden_state.detach().to(torch.float32)
+        )
+        print(y_feat[-1].shape)
+        print(y_feat[0])
+        self.t5_enc.to("cpu")
+        # Sometimes returns a tensor, other times a tuple, not sure why
+        # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
+        assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
+        return dict(y_mask=y_mask, y_feat=y_feat)
+
+    def get_packed_indices(self, y_mask, *, lT, lW, lH):
+        patch_size = 2
+        N = lT * lH * lW // (patch_size**2)
+        assert len(y_mask) == 1
+        packed_indices = compute_packed_indices(N, y_mask)
+        self.move_to_device_(packed_indices)
+        return packed_indices
+
+    def move_to_device_(self, sample):
+        if isinstance(sample, dict):
+            for key in sample.keys():
+                if isinstance(sample[key], torch.Tensor):
+                    sample[key] = sample[key].to(self.device, non_blocking=True)
+
+    @torch.inference_mode(mode=True)
+    def run(self, args, stream_results):
+        random.seed(args["seed"])
+        np.random.seed(args["seed"])
+        torch.manual_seed(args["seed"])
+
+        generator = torch.Generator(device=self.device)
+        generator.manual_seed(args["seed"])
+
+        # assert (
+        #     len(args["prompt"]) == 1
+        # ), f"Expected exactly one prompt, got {len(args['prompt'])}"
+        #prompt = args["prompt"][0]
+        #neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
+        B = 1
+
+        w = args["width"]
+        h = args["height"]
+        t = args["num_frames"]
+        batch_cfg = args["mochi_args"]["batch_cfg"]
+        sample_steps = args["mochi_args"]["num_inference_steps"]
+        cfg_schedule = args["mochi_args"].get("cfg_schedule")
+        assert (
+            len(cfg_schedule) == sample_steps
+        ), f"cfg_schedule must have length {sample_steps}, got {len(cfg_schedule)}"
+        sigma_schedule = args["mochi_args"].get("sigma_schedule")
+        if sigma_schedule:
+            assert (
+                len(sigma_schedule) == sample_steps + 1
+            ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
+        assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
+
+        # if batch_cfg:
+        #     sample_batched = self.get_conditioning(
+        #         [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
+        #     )
+        # else:
+        #     sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
+        #     sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
+
+        spatial_downsample = 8
+        temporal_downsample = 6
+        latent_t = (t - 1) // temporal_downsample + 1
+        latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
+
+        latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
+        in_channels = 12
+        z = torch.randn(
+            (B, in_channels, latent_t, latent_h, latent_w),
+            device=self.device,
+            generator=generator,
+            dtype=torch.float32,
+        )
+
+        # if batch_cfg:
+        #     sample_batched["packed_indices"] = self.get_packed_indices(
+        #         sample_batched["y_mask"], **latent_dims
+        #     )
+        #     z = repeat(z, "b ...
-> (repeat b) ...", repeat=2) + # else: + + sample = { + "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)], + "y_feat": [args["positive_embeds"]["embeds"].to(self.device)] + } + sample_null = { + "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)], + "y_feat": [args["negative_embeds"]["embeds"].to(self.device)] + } + + # print(sample["y_mask"]) + # print(type(sample["y_mask"])) + # print(sample["y_mask"][0].shape) + + # print(sample["y_feat"]) + # print(type(sample["y_feat"])) + # print(sample["y_feat"][0].shape) + + print(sample_null["y_mask"]) + print(type(sample_null["y_mask"])) + print(sample_null["y_mask"][0].shape) + + print(sample_null["y_feat"]) + print(type(sample_null["y_feat"])) + print(sample_null["y_feat"][0].shape) + + sample["packed_indices"] = self.get_packed_indices( + sample["y_mask"], **latent_dims + ) + sample_null["packed_indices"] = self.get_packed_indices( + sample_null["y_mask"], **latent_dims + ) + + def model_fn(*, z, sigma, cfg_scale): + #print("z", z.dtype, z.device) + #print("sigma", sigma.dtype, sigma.device) + self.dit.to(self.device) + # if batch_cfg: + # with torch.autocast("cuda", dtype=torch.bfloat16): + # out = self.dit(z, sigma, **sample_batched) + # out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0) + #else: + + nonlocal sample, sample_null + with torch.autocast("cuda", dtype=torch.bfloat16): + out_cond = self.dit(z, sigma, **sample) + out_uncond = self.dit(z, sigma, **sample_null) + assert out_cond.shape == out_uncond.shape + + return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond + + comfy_pbar = ProgressBar(sample_steps) + for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps): + sigma = sigma_schedule[i] + dsigma = sigma - sigma_schedule[i + 1] + + # `pred` estimates `z_0 - eps`. 
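+            # With dsigma = sigma_i - sigma_{i+1} > 0, the update below,
+            # z <- z + dsigma * pred, is one Euler step of the rectified-flow
+            # ODE, moving the latent from noise (sigma = 1) toward the clean
+            # sample (sigma = 0).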
+            pred, output_cond = model_fn(
+                z=z,
+                sigma=torch.full(
+                    [B] if not batch_cfg else [B * 2], sigma, device=z.device
+                ),
+                cfg_scale=cfg_schedule[i],
+            )
+            pred = pred.to(z)
+            output_cond = output_cond.to(z)
+
+            #if stream_results:
+            #    yield i / sample_steps, None, False
+            z = z + dsigma * pred
+            comfy_pbar.update(1)
+
+        cp_rank, cp_size = get_cp_rank_size()
+        if batch_cfg:
+            z = z[:B]
+        z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim
+        self.dit.to("cpu")
+
+        samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
+        print("samples", samples.shape, samples.dtype, samples.device)
+        return samples
diff --git a/mochi_preview/utils.py b/mochi_preview/utils.py
new file mode 100644
index 0000000..8732472
--- /dev/null
+++ b/mochi_preview/utils.py
@@ -0,0 +1,33 @@
+import time
+
+
+class Timer:
+    def __init__(self):
+        self.times = {}  # Dictionary to store times per stage
+
+    def __call__(self, name):
+        print(f"Timing {name}")
+        return self.TimerContextManager(self, name)
+
+    def print_stats(self):
+        total_time = sum(self.times.values())
+        # Print table header
+        print("{:<20} {:>10} {:>10}".format("Stage", "Time(s)", "Percent"))
+        for name, t in self.times.items():
+            percent = (t / total_time) * 100 if total_time > 0 else 0
+            print("{:<20} {:>10.2f} {:>9.2f}%".format(name, t, percent))
+
+    class TimerContextManager:
+        def __init__(self, outer, name):
+            self.outer = outer  # Reference to the Timer instance
+            self.name = name
+            self.start_time = None
+
+        def __enter__(self):
+            self.start_time = time.perf_counter()
+            return self
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            end_time = time.perf_counter()
+            elapsed = end_time - self.start_time
+            self.outer.times[self.name] = self.outer.times.get(self.name, 0) + elapsed
diff --git a/mochi_preview/vae/__init__.py b/mochi_preview/vae/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mochi_preview/vae/__pycache__/__init__.cpython-311.pyc b/mochi_preview/vae/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c3e80b0e195309bf348e7e859b01f534e6af0e8
Binary files /dev/null and b/mochi_preview/vae/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mochi_preview/vae/__pycache__/__init__.cpython-312.pyc b/mochi_preview/vae/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4aadf57b1dbf3d8b03cca90f7a9a4cf013470d2
Binary files /dev/null and b/mochi_preview/vae/__pycache__/__init__.cpython-312.pyc differ
diff --git a/mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc b/mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0254bf9b8505bd52311e96e11e048a9b1e03278e
Binary files /dev/null and b/mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc differ
diff --git a/mochi_preview/vae/__pycache__/model.cpython-311.pyc b/mochi_preview/vae/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcead8c04467812bca3c54e02e9f5b092c419f79
Binary files /dev/null and b/mochi_preview/vae/__pycache__/model.cpython-311.pyc differ
zvMD#V+Q=hfI7ld$QEpnj>9>|m`hBY$JefSBL^JIb_^&bAe>>0AX7tWz%&$_=X?y);2F~7JNX&ItTJZgh6~VAaw?xD zCtU&RDxHqcN%zRNGAb#*Lb#Vpnb85e6`;<1S{1+OiLxOa3yMY~{=*DaR*3cr8`OM8 z4cNG_(Z_?`O4n}27&Z<&H&#ZWbAUC${LPqzaKlUUILYE7{Ps*pO`4amTqMVcL;!?w|xSvoo<%NoRe$>yC3%!n7&j+(gmg z*>J3T_ABv*JITZI`&@f$YPqO6a`@LpwXuQtr_5nctU09CiV2t? zE$|s3Vj_{0j#N7&SQ^y&3^xp6;}7_6^E`Kj^O$-+Aw-c`^S88em;nLip>WP)Algi& zF2!dkkJ?jNU}1z3>&0<;K0zMIY#8lH!jiyPAk^QlEdCHZ$}36DC>bt5!gM|;S%&<9 z^FtwYzE3h=3iv`p6qIQ{>h}!MmnQh85IJ2}aBw@7D)^8l_;q+7{($t$Ya_aET2sOk z5#3!=?Mj&tuU_yjmpu`&uDGjXJ@M{kcMBGA8H7o5W3;|o7fvNhJ0q4<umOr}63 zA`XH~%0=A0A+K26`eZ&^p#}=2&=Y1Eb2-@?$5NG=^&NJCx(&Hl#D4LEFzM8PiDh0hMZ(l-XM}L zpCMC^Stf>P`Af#Mrb@CSaA{l|@yWsgn}URxaKMZgK{o}wM9jLc|D1Q=f;dA3w!;e& zJ*>dRY&N*IFjP#7*S|tpZva;jN8Hf0+^{`qs*4z-wZtIS+$#t7yLo|6mbU@;m3I;! z?7X)B8~algn<6Jt*w@QFJ$+p z;(WX-9$2hMRPMW7a@W-}+kM|z7HjxMc%`It{^;CMlJCde@$UHKf@kqmqIUO^?{3NA zIX>mCh&>1RT+?yK-SJ`T#?%c9uOw=BE%n_kITYpZ=U1V)Vs1A$=&t4B=7g;okFV{Y z-5;HteKL0Hj5|7`b?Ct*FZrl+`+t$E$R+3RB7Wt$M;SDeBMhVVlvT&J$Mx}nh3$*l_-Intk*MC0EZwpx>9?rPB=Nvke?%Tf3t^JQ$PR0dU%oi2S{djWxLGE&w``XLVo z`6%libs+Ziq>CI@;b=NHE@^_$fHWW#Gz%U#oLZSV!U69O5=MYt3Q z4GCUhBrpaQvoPYl;ukv^Sv2VPhNd8glSOdKT@D8hu8ApgviImTyk>LuBr^g_}i&=!lZ@`muL@9|E>7K2B8h^I_vq>UXx_K%Q zL?i@L4%2#1Q!TOT_(;NZWU=p#>B!Pc%cdibpiF2ih;Uq^QDRgtKde%@h)O%)W{pbC z@)`B`EY#xP0+hNNJ~k(j)61p)LvrJ1rIk)^(#B_NjLidwHL5 zf+?D00c@vmhSZ$KVVX?XqLAw`f#bd29>5kZ<0 zvnK#3imggCLOV)J5Nm zsN6;G2T2BDn7rV%Sj=E^-sV1U*~273m$T$uNBk3^p}uh_ng{%SP*X}pkj+Dv3G)#D zree$aN}boGXoWgdQvR|T8kL9hcv6uMX52`U8` zXyMXg=-%$I5TP5?NGI?Zzmvk*+HGJQaY5G*pka9d<` zGNqDbFsP*}Sp8z}!1x-b|A*ffaPDgtj<*Xhe-WsBS~2?a6@<)NTZ5stU|Z1BF7&kp zLnyZ`h?3iaSpa(ai9` z4D;}?kOBJ30N+!#Cn1@<=WIzjTUeku>1sEu^65-cfTv^L<>E>Ar2zxgpwi}by zMyQ&sEepFBy+1mzZsJPHqm%QO<}Q8jK*ClZJ&cNw#}KAbN<-ei#4kv=i3n4TQNR6uP5r7K zSFjj0^`99L`m}`Czre2=sxs&WL3y>aa!BL^J?r2gXt*3n3oT5x^ob4Kxb(4f(@>g$ zuulY}0?NLDi)4YK6dG8Vwt)bou%d$n^H@RYG$NC9QR1{+pA6ZIV=#s59~={J;>q98 z(_(lSP!ngXiSJrGcl*t|)-ws+85tVmnd+?R(ArxFB=s_pC77v-oU&Ya#OqD4*=WqIO0&JJ1~t$#mKuJOrD{ z3)4Fh3ix3BB%GsrAvEL-39wrX`o)X>Aa4Ca8kc5_f`XlWy&^FjSQc`JdPAYHepY}; z?PJnJ@bYQZHcKS_fF7t%)v=^C&4|dE}F0I~5Vdd7&-KcvtAC9s@<6$CU(uqR~;Nz$DOoP3e7vP^R{0Na5RK1vX} z!e$NL)Ds!x$F_|M_}9fQPKBvpj~p|brq&eDiI za&e`!eEyBOH&$%L(c|%#?%LW{Y)mi^J95|7uwrXS)oe=DHm7QuQ{_Uc){`o&OqJKJ zJ56>6f%lT%HPc#cd6@;Tw>;d&dA7&eBD!xm|3}S7)}LAbg`L8WAb@Dqu$A)?0%Sf5 zdp52RSC-4S$tJ8b{9BZo^5?lsYJdxV`t0wYK#)gD5CCgGz{}-v&SI$~6JQVnOhKw2 zl4mIkUFx8wk!uy;24iBtSJ}9d(ugT#b6tC8_L=D9JI^s*Dr%ZNNr964&;@>F_A9Z; zWv8lA^Ee;PQC*_N;i6<3g5_pESXs$9d7+;WS`KSO0ERkz4Vy)1LKWm0U@+N= zb||U@N}uN-M8J@c10BxHEvHd(+63`;%ZGX}WeJTzo&(0@0s8nsttxP!8mp9gCV8el zucNlIC3#p^&9F#8U@qv;_|6i4lDV(4v(E>!6rvK&z@Oxg-FJO+PDSBpH49b8Y zIqdJeV?KeVn|B1Jb4ZpJ-Q-Guko$AwKPTIT?LwPCBpc1*Ai~OQR(OsLQlQCp;X=EB zQnM-+gw`P>z@U3vNI%XIGpXgb3wUGJ8whC0n2^}4EZd7GaZ9H86CyF&K%)ofVwQAc zBr`gK8_7I1HV8{Z^oC@Z8ao$&X@nVU@>p4OWIx{0Sb3e?;DY zFDFu3FY%M&^AyEj3}%pDnh7rj4M*eA4+z<>4uW_a>01d{{yhRr2xW6#+cUc-I{6K3 zEUAbbOj#Z5zWbg1kj)QWZ% zve>V6%yz{1Sa=S>7?Kxw{{g?Hjz$d56twzPNoke#3IQx?Wlpo8YH{VIA~*pxCv$r zdLKyh8z9OH%d$b>7nX7Bd^T*-%gDDKlDS)1J#1dFSs z0|dhrM^KZ&^Zx~b$CB10+ycCn9g)h+&^b%s1*I+5Pb52YkAYAmgJG~gtj(}v$gjfZ z$K^*2bU8Uyq8Ek`+FrI4vDj{*GEm87Xic_Si2Gz6lSbSpnHGR8CrJxaC55V_6)Uf= z|A|6>s%B$?`OLz)Dia_?%uW}n1EZ3a$7of(h=)X0y0Rs8GFh3btcjbh*Mr1!q{{1X zu{|g*o!>jRH+Cl85Ffa?dC_#YcrW5jD{6ex!tuMsJ3;=*g5LmEEGJe`1#94`EEGX< zswW%gJMh2QJ}O`khR4M>ASX)hbK90CPAMM}jHSt9YAAe@tyd%3|kWX2TUyIw7_Y=2S5>YYW>_5F#-E zb3UWb;#*!h0VjIO3uvHXH#%8eJ>0|mli?Zd1n=Hm0md=%F8uzN`BJ0Of78SGGT zdzLCJm!=}B$&Bx;L5JBC}R`&tym#5-z3pOd{kFLc(F703wZg&r^?+>r033f8^0gX7tcl 
z|EM>Joe@59d;*}qe*&8%g8gZ0L46Q=i~TY+M<9vqzEsLqc3-mhLKsqq1QI_a@s~Ws zqJSIm+vG7K?I?vlCXetGqqVXKH=Qpq5R$EM8MRl#UJF)i9QDDLlc|AF7NJHnU2G(e z=Afh>7@xS(Fa9xwjcGd4Qs?pDYX#a8!aO#9$}f(O#Fj;pN7-|z|QE6RUk z7q%wMo9E0RFi$VLgh+R)sBFGwxh>mi<;NX`f^Le3gZWmo4aOrMH}AP8?)T8Hp0N= zo~twI>V%jbJ5Rk2wHmi0Ro%j_v)${knfOLX)vro- zM0uDNV%KU{Z2J0>iPE;H=5hBc6_xiYb|ouzEs9Iqmc(SmsoSr?i|VNw4bURD&5b2o z&G%eelCCWaM}O(swrb_dcB1|+_qB_&7o*;HuEe_E>$%<&?|S#d9cOdG)SQ66K_#ZU zmTk2%(af~c??7I1A5xU#v&W;S-g%bDZh5!v;^b2IZTDT)L_Y1FL$3k;me^ zRx?}kTPyBl*$Gu+06*2~(xJ64BaqD|$jpIH(QuPV3o!`#Li|3Mh=TDrkA!sI+1v0P z>N;aGdzlyb*LZiYN&-@i^!O1uwC01CN9u3TYG>Imt5%92%~gUcTah@(gS5~PX@L~< z7UmS)w;_x0zKJJbdcozh!GoC#V6UP)V}$lyq%yGlhzH{T5}v zI)V$$KW@L({##gvY>_*e%E?(F3RS*#2v2fX2yqoRa3BR(XdQlf{0#URVHYCnvP{rr z>3n9g3t<{71LR81R+GyMo12l-avLNtD}M$mF|GD zY&A$m2%4>xIy66Tt$=*b)i|B!aOi|O@58E2D1A@_?HwZaE6;6MK?Yhu%De%wN(V(K z9aUJwXZ$wAnHhE&7T(!vr17>j~(GJ^TrXiMe zzHQLbTja%}4x;$)s0h_=MzYCe%e0F4*Yxl-d5l9RwU9=-P9$&)sN^?@v}ewb%~tm$s+rl!?s`$|~kZ=SJf% zCCb_XfvoqlO>*v<>Ohi|)}(CZDO*LNx+7`pNOT>#-JR?@nzZ$->9wV#?lA+Ll{Tm9 zyAnrVTGwhQE)BVIKs$jV`mK$?V^VokKi_456XQB4eV)rpp-$;cvYZq!-IL9mv~r*S z2c#4KhP=Oow-KyW7wYgdZ2ME{Zvs|;Z6GR`K<}MiA~#&J5gd^lxi>0pam)Bn28X7-!*61zwODk6OxPo|qCdEi0P#86< zW(wK3T4B{rAqN>NQ2M_;pi^i9Dz^tV>h>XTFpFycPl#5Hu<%-runmJVPa;6rMzm&5 z1PH6AON?Q6{){&3*cejkL$yJmQpUUhg4*6bXv=P(AX!_cVep%&m4IUZkmBk6;wMUq zAVEA@8ghsH6~&YHXVf8tF0OLF?%4Oey141C+jCoY^_d6O;*`k+i)WhwG}b~JC|ZQ} z9AVd0PNO$4cuxylb^3H@Z5INc2kO&ie|iSq4LRW5C^?UeUmD8??U1NvjQd1CaY#@k zWQ_aNs0{ee0rCVHSm*1|;4cE?N#c27H#$ewP#Fu!kTR-uN=<+>tRKXNE%{2=4Mqbr z>~I-ipM?wxp%0}K{$f`zpG^qJYaxr>Cty zuvnSN$4lKu1^?Ip?UeXju|n1kGaM^3HG0l3()J9t*;7!=KQbn@^j>1Wl_%MG5KqU; zD35Y8<>K`0Y>JjxZU96PO55nFNJ>Btq$kflp-+*XsByxecL3^^+^mWLshVn^0FzU) zSrF8y4H=nHOdpKNGMJ1App0<@6IJ)-$f_Suf-&-%$YYC?7FZgCGNbk2BQWGDrF%#o zTk~ra`T~ftA~Wucj#820T>=bb=S4`dH47=G2q}U$a#hDh?l2-W;cQEp9M|l#c4iB) znM@#Rw*tc@j=Y$tc_~%VmS{hic90x}M&sZCWu-S5$NFdT#{!$F@J5&GMg zcwR5t^RldxLEby9f;N0R@n6Nf2Pp8H+Y_A~kn#an16AOn8j zDhpX#A(lXIo+gE{O_)Mbg^)&2V!(QDSb9eKsz)~Cas|MWEQ=|egtMImuL(Y|bl zKkZyBQc`BF@@`%1o-`H`8;A7@2&WTll=E|ZT>7*eDpM8BXUZhYFev8JHKelz!Y28Y zfvk3sp<$lGafB)&D70%YY*K2Im5w;fJeNs#o|G?YO+Z5!dq1I9$nw?w2q`ktr?buv zaEd>rG)Ks5An!1F2go}JPtxMZL+qRzza(iU0-``s&^=CIPmSc3uh=&n@{1#W??u0) z#r{5+%g|{91LHx!zR#$jA--N?XpsKdIX)jo;;f}RBp{q_< zs|k!TmGnD~tKB323GGYLtP5U*V16DDEVwC!s5eUmxTJ*&Uk{QyF$y|a7o z+lsJFM4_0of~!UywkOl>H2_+hfzFjHK&&Zj>KYNtImDVFOR=pTG%#hh4$5rEukuN? z9=XGN(tD`bP>LKnUF6~AfGL8_3_hPkAddnAkjiDVNfeW%eI{Kr5k2X% zm{FQgnOa~+JCncU#>xyMEdZhpvU$adyhf=gv7RPU6LJ70DqF5-`$cILxHn6xq=J|@ zb}7WS5go|de-?pUBV~Y&G=RV()ge}Gr5Z^#O=*fU69?)tof6BMpZ;IFeLf-G2tip2 zv|(089*?2oKzRi}$L~3+ta7+4oMD%x!=KUJ-`dRj zUS6gYDAxc=SDhk;S^Ey6a#kg{3H(7)hCj&4ZqSJtOjZW#a4k;rh6If=!xCHxS=p%^ zES5d7<=xwn(!icLa2|)#Z}hqFbK_i&3by8pz@C|W#aQ!ZCVjJDc9lrRBYwsVFl}6K zA3%0sWC{ip&*0@!u}WsE8?b^>kwQRc#VIF@k$f12_v2&o#!Q&`O2sny4|~K+qS#x35!y2^u9{r)v*pfq%+0uNc z?g)3yF`zjQ3y(JdvFV6E5OyeIPV-H(j+=4LII_&qoeEJPA9sNAO=t;oS*cV)2C$@n zRfBv?V1@vL$kUavYAAG65lBvF`T)l)E5!5sj60*_hX9ftjH%QR0#htt93SW&R_44? 
zRa`SR3rfj2gV~fm+|cGTLKI@p6QQ<3-|s3|R zQ@B*A8wBFdV>He3L7td@Wm&C)NgdTmNvqDIvVvo;+P+{Ov3zVEu~7K;JCDk8=22$O z38-wI3%iusV}-hvhoNRAq|_A>*fd>$bZUuL01J5hW*nEG4e1r}@rSH~G(?gvia~fV zg*ZePjSTY$$bl64cl4M6K~<(M16O*JrzAr~Ne)IXsL)j`q4$)M$I4)k^Y4&JvJ<&J zHKAx(2%m{Q@|wtdK;B=G_apLt15dKZmh*HRyev+RP^^LzNG0H@knuZ#e6g5bn)mw0 zLv$tx!O0=O1(If9h`mU%4hBRVvd)wUOgp6yTKtk~qi1O(pc^%PJK}@H ztPx#Oz-3dB*qkiO%#>4$wxNTAx8Q+#TyM-4sMNL6_TB1KEp~bkJ=wI2u0pB~ssxHx z?w!@Ki>j{v*Yehs-)ud!Y(JdbdMF->_@ZwtI~o&P4@EQy`{9*ZAzt+E%kjzUud!lR zoJG+mmz|Aq-@TSS$(B7!UCEXMcbo?jrUMD*0jBC8j>pLuis5);{PfM1#lT%saA?T* z0hSXas)?UVmV18XZXU8qJ={wHT z3DfB$^Apa~4{C6DNhQt|r@h4I{V1c|X#?S4Vv*g^Y>dmDN@M1#e!sNoa z#jaa}$@-nCx-AQ*77i`CZ}lYWwqt0Pc-UZce|Rz3^d$CSFM1bGFExDlQnK*?wg5eJ zQVA}uxc;FVp~z93R%I_H9a60Mmo{Ow96j-{iYuv4;HZ|b;8kT65euDl{ua)pDsTOj z+e5uvcF$IuwAH42d2;S32E<*I^6XliOnUa-X0|3psgtje%?Y|#CwFlhv=%vCd^S<` z#9dRDqN(axF55zys+H=-__nzB=FWxoEKOH3Hi+2kEQqS*jip5DrvZih1a`qSG!q|2p zgiQwR!lgHmh%C{3!fPE}?LwEbrGbD>L1qhfb*BMvaA1-M)d=!k#F5A2h2EIOW2r9z zZ>YW}53corPwNs63aufbL+BzVZAY$j9O;ka`4FDa$xAOIe2IlK$z&AuL-Zi{guli` zCcDY>jXjb~*!_Ux<;Kp7w0jGmion@Of%8-2Q^B00HR|!Oz!os40&3G--q#4yPNHyF zo|SQBRr9ZX@3qL$U%Ozk^j*^|NtDYPu-I&_d)B6;wJCn_jnNx4|sad*^zX1EVILr7*mI>jvdr8 z(r{*(W6-LRH;-MQ(1+wrkw<91u-ei_c#pg;YSI6UbR;r1akX`j?(Eo;+P3R{jgYEo zOf`9E%UEqMuit~}kphziO&qkEW)`$?71gU&7PN77jq#@IuddqJt%Iwqjh(q}UUjlt z7gyI3_bqI{8A#SXvFc`VMOt^~W)Vn^ntS4-I~_48^u3sz7C zDp|0Kt7}-TX2BZTmd0xMw_nzj>YtGLU)I8>(J+_?T@D0gVdKBX{ea3s;YN=i-LL&A zyhn-ll?4hg0=1o3o6(@MLoLvY`A*4!V=uISpsM2NV6UK{CanE7`|OAJU!p=5c{#I>cgXvz88}F=({q3Q`tsFt_WR6GAo)27g{0NO=o2H@^>FD4Fv{RI8z5zNljRoY7s8H-Vok?)F#;j_Y!NeZJ|~WQee``EjD=idCJ>r#JgOW? zm+8B6e)Ltla1JU=I)`rBd_;l6%&-^&t$mmwlI11u$dsR*s4H&^)6oZl@bqas6*tk^ zXR>zzx4G(jsc6)D!QZdeiE1g*R<3_SaDwEVfP4aHa(sRItJD7=9G$#ltXOft^5o#s z;77wh8~zs~OJDkUZ=&+WWyebi(@U`IjPAg$svSvd{naBWyCZTqx@q6jX2RkN>FWJ8g3N%*)AYD8`ZJGWq0!A^2!-~YggtvUCP z(|eJk|98`S=`wn+Y^@grC|se9mr7fdtrm2iC+cht)( zL(t1$Q)K9pLIVVaFMGetHeP=D9Io>BU9gIRI8iYJFoY=rHp(e!Omj0a4uKo{{0l2M z7nXPnf)U1JDRZ9XB0zVpVA>#2Ey^116rQ6`$}p(Mj+?;ud9bblhy*WdKn7Sm9s-RE z_}URjf>8#HutWe~Xz{WNeF4zGasZIZ3tN8kbW5MS&q4n7O@WHaXHQP+SywX{fsnSY^dVQmTSgYW3+32_uTF~rYaCO`s4h4^VUWFr^XMB z3s;u(OWu#p|LpwIs|n*VoNM`G_C$BQbMop@P#Shev^3_PtDJovz{28; z?!;FG_9iX0R}Zg1b@rv%FU5|=w=ejYx<5Mpv*RCk|M{_hbL_6Y?`jViO@|OSEa(=z zKlb19zkhz&zUyia3bfgye6(lIlC)OiFvync!%1svLf6WO1<8nW*TfO;L=K!G0Vjhq z?Pw!uzte0)WyZFuuyQ5LpNBvlnrQS3Id(29Q5tzJvt1kpL9e3^WrHBTkU`KR($dE{ zQ8O#PxgsBpja96pd(tAyG1!I<%0W)2#L+$4lSyNMCLXPnUBOmqvrLizolu`N1PmLP z{x_?&>WT48V6+#7cHYv-I^O>aH}c%1i36HFGyAq80(PGLBmThdheR@;Q-j8H z{lM=-p^dZA#b#sw5b*yrD14mxL1%R(XkrpM4E?Op5pDQdcm)TpUzxo^hla+=;-~LA zTN9>M>=$8MFB;$37c0jRqY0gmunMWdu@#lErr4EudE6Ac66?gNrVCvQXBKxawtnbY z_)>DqfyjxgJ*mnn9EfRRi4GYR7BZxGCMwZW@bOh%+Kste3tM3g<1N^ z|BcV`uumFvid_XKrQX{Q3j%RF9*R5Q!3w$m+;}K7KDu!tI^>B6a1tWoR@{&D|CXj9 zZ4~}JPQ&twSaB>6H^wf<0zYymD%+CfBt8l*zPxz+!xMiznb`VtvZFh)7tpf`yBFdW z@y+pyA9W>aU|Za=Fu7Q~c>cqoKj9Ny-H`A_E@!5s|Hg~)y$gY*qUG8HOGEinl5I7R zOld7x?pawnw2;BeO$g+up+QY#f8|WxuC1C$;Uj|hyb!*)DSs~T0D3H2?<%?ITvb0r zrOI2yr;y~YsZ+~Xr(!kXJAbdt)ht`<;uT4&C!zDma9uNqFAk9mITRRyvHqp(n<2@8 zk5i;Yv$W+fXGgmV{J(=HWi2;5$T)dFvxyk7jAREuo_*wv5m2aK`g12%5S!Yt285*! 
diff --git a/mochi_preview/vae/cp_conv.py b/mochi_preview/vae/cp_conv.py
new file mode 100644
index 0000000..e5e96de
--- /dev/null
+++ b/mochi_preview/vae/cp_conv.py
@@ -0,0 +1,152 @@
+from typing import Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
+
+
+def cast_tuple(t, length=1):
+    return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
+    """
+    Forward pass that handles communication between ranks for inference.
+    Args:
+        x: Tensor of shape (B, C, T, H, W)
+        frames_to_send: int, number of frames to communicate between ranks
+    Returns:
+        output: Tensor of shape (B, C, T', H, W)
+    """
+    cp_rank, cp_world_size = get_cp_rank_size()
+    if frames_to_send == 0 or cp_world_size == 1:
+        return x
+
+    group = get_cp_group()
+    global_rank = dist.get_rank()
+
+    # Send to next rank
+    if cp_rank < cp_world_size - 1:
+        assert x.size(2) >= frames_to_send
+        tail = x[:, :, -frames_to_send:].contiguous()
+        dist.send(tail, global_rank + 1, group=group)
+
+    # Receive from previous rank
+    if cp_rank > 0:
+        B, C, _, H, W = x.shape
+        recv_buffer = torch.empty(
+            (B, C, frames_to_send, H, W),
+            dtype=x.dtype,
+            device=x.device,
+        )
+        dist.recv(recv_buffer, global_rank - 1, group=group)
+        x = torch.cat([recv_buffer, x], dim=2)
+
+    return x
+
+
+def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
+    if max_T > x.size(2):
+        pad_T = max_T - x.size(2)
+        pad_dims = (0, 0, 0, 0, 0, pad_T)
+        return F.pad(x, pad_dims)
+    return x
+
+
+def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
+    """
+    Gathers all frames from all processes for inference.
+ Args: + x: Tensor of shape (B, C, T, H, W) + Returns: + output: Tensor of shape (B, C, T_total, H, W) + """ + cp_rank, cp_size = get_cp_rank_size() + cp_group = get_cp_group() + + # Ensure the tensor is contiguous for collective operations + x = x.contiguous() + + # Get the local time dimension size + local_T = x.size(2) + local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64) + + # Gather all T sizes from all processes + all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)] + dist.all_gather(all_T, local_T_tensor, group=cp_group) + all_T = [t.item() for t in all_T] + + # Pad the tensor at the end of the time dimension to match max_T + max_T = max(all_T) + x = _pad_to_max(x, max_T).contiguous() + + # Prepare a list to hold the gathered tensors + gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)] + + # Perform the all_gather operation + dist.all_gather(gathered_x, x, group=cp_group) + + # Slice each gathered tensor back to its original T size + for idx, t_size in enumerate(all_T): + gathered_x[idx] = gathered_x[idx][:, :, :t_size] + + return torch.cat(gathered_x, dim=2) + + +def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool: + """Estimate memory usage based on input tensor size and data type.""" + element_size = input.element_size() # Size in bytes of each element + memory_bytes = input.numel() * element_size + memory_gb = memory_bytes / 1024**3 + return memory_gb > max_gb + + +class ContextParallelCausalConv3d(torch.nn.Conv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + **kwargs, + ): + kernel_size = cast_tuple(kernel_size, 3) + stride = cast_tuple(stride, 3) + height_pad = (kernel_size[1] - 1) // 2 + width_pad = (kernel_size[2] - 1) // 2 + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=(1, 1, 1), + padding=(0, height_pad, width_pad), + **kwargs, + ) + + def forward(self, x: torch.Tensor): + cp_rank, cp_world_size = get_cp_rank_size() + + context_size = self.kernel_size[0] - 1 + if cp_rank == 0: + mode = "constant" if self.padding_mode == "zeros" else self.padding_mode + x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode) + + if cp_world_size == 1: + return super().forward(x) + + if all(s == 1 for s in self.stride): + # Receive some frames from previous rank. + x = cp_pass_frames(x, context_size) + return super().forward(x) + + # Less efficient implementation for strided convs. + # All gather x, infer and chunk. + x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W] + x = super().forward(x) + x_chunks = x.tensor_split(cp_world_size, dim=2) + assert len(x_chunks) == cp_world_size + return x_chunks[cp_rank] diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py new file mode 100644 index 0000000..1263271 --- /dev/null +++ b/mochi_preview/vae/model.py @@ -0,0 +1,815 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ..dit.joint_model.context_parallel import get_cp_rank_size, local_shard +from ..vae.cp_conv import cp_pass_frames, gather_all_frames + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +class GroupNormSpatial(nn.GroupNorm): + """ + GroupNorm applied per-frame. 
+ """ + + def forward(self, x: torch.Tensor, *, chunk_size: int = 8): + B, C, T, H, W = x.shape + x = rearrange(x, "B C T H W -> (B T) C H W") + # Run group norm in chunks. + output = torch.empty_like(x) + for b in range(0, B * T, chunk_size): + output[b : b + chunk_size] = super().forward(x[b : b + chunk_size]) + return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T) + + +class SafeConv3d(torch.nn.Conv3d): + """ + NOTE: No support for padding along time dimension. + Input must already be padded along time. + """ + + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + if memory_count > 2: + part_num = int(memory_count / 2) + 1 + k = self.kernel_size[0] + input_idx = torch.arange(k - 1, input.size(2)) + input_chunks_idx = torch.chunk(input_idx, part_num, dim=0) + + # assert self.kernel_size == (3, 3, 3), f"kernel_size {self.kernel_size} != (3, 3, 3)" + assert self.stride[0] == 1, f"stride {self.stride}" + assert self.dilation[0] == 1, f"dilation {self.dilation}" + assert self.padding[0] == 0, f"padding {self.padding}" + + # Comptue output size + assert not input.requires_grad + B, _, T_in, H_in, W_in = input.shape + output_size = ( + B, + self.out_channels, + T_in - k + 1, + H_in // self.stride[1], + W_in // self.stride[2], + ) + output = torch.empty(output_size, dtype=input.dtype, device=input.device) + for input_chunk_idx in input_chunks_idx: + input_s = input_chunk_idx[0] - k + 1 + input_e = input_chunk_idx[-1] + 1 + input_chunk = input[:, :, input_s:input_e, :, :] + output_chunk = super(SafeConv3d, self).forward(input_chunk) + + output_s = input_s + output_e = output_s + output_chunk.size(2) + output[:, :, output_s:output_e, :, :] = output_chunk + + return output + else: + return super(SafeConv3d, self).forward(input) + + +class StridedSafeConv3d(torch.nn.Conv3d): + def forward(self, input, local_shard: bool = False): + assert self.stride[0] == self.kernel_size[0] + assert self.dilation[0] == 1 + assert self.padding[0] == 0 + + kernel_size = self.kernel_size[0] + stride = self.stride[0] + T_in = input.size(2) + T_out = T_in // kernel_size + + # Parallel implementation. + if local_shard: + idx = torch.arange(T_out) + idx = local_shard(idx, dim=0) + start = idx.min() * stride + end = idx.max() * stride + kernel_size + local_input = input[:, :, start:end, :, :] + return torch.nn.Conv3d.forward(self, local_input) + + raise NotImplementedError + + +class ContextParallelConv3d(SafeConv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + causal: bool = True, + context_parallel: bool = True, + **kwargs, + ): + self.causal = causal + self.context_parallel = context_parallel + kernel_size = cast_tuple(kernel_size, 3) + stride = cast_tuple(stride, 3) + height_pad = (kernel_size[1] - 1) // 2 + width_pad = (kernel_size[2] - 1) // 2 + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=(1, 1, 1), + padding=(0, height_pad, width_pad), + **kwargs, + ) + + def forward(self, x: torch.Tensor): + cp_rank, cp_world_size = get_cp_rank_size() + + # Compute padding amounts. + context_size = self.kernel_size[0] - 1 + if self.causal: + pad_front = context_size + pad_back = 0 + else: + pad_front = context_size // 2 + pad_back = context_size - pad_front + + # Apply padding. 
+        assert self.padding_mode == "replicate"  # DEBUG
+        mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
+        if self.context_parallel and cp_world_size == 1:
+            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+        else:
+            if cp_rank == 0:
+                x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
+            elif cp_rank == cp_world_size - 1 and pad_back:
+                x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
+
+        if self.context_parallel and cp_world_size == 1:
+            return super().forward(x)
+
+        if self.stride[0] == 1:
+            # Receive some frames from previous rank.
+            x = cp_pass_frames(x, context_size)
+            return super().forward(x)
+
+        # Less efficient implementation for strided convs.
+        # All gather x, infer and chunk.
+        assert (
+            x.dtype == torch.bfloat16
+        ), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
+
+        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
+        return StridedSafeConv3d.forward(self, x, use_local_shard=True)
+
+
+class Conv1x1(nn.Linear):
+    """1x1 Conv implemented with a linear layer."""
+
+    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
+        super().__init__(in_features, out_features, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor):
+        """Forward pass.
+
+        Args:
+            x: Input tensor. Shape: [B, C, *] or [B, *, C].
+
+        Returns:
+            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
+        """
+        x = x.movedim(1, -1)
+        x = super().forward(x)
+        x = x.movedim(-1, 1)
+        return x
+
+
+class DepthToSpaceTime(nn.Module):
+    def __init__(
+        self,
+        temporal_expansion: int,
+        spatial_expansion: int,
+    ):
+        super().__init__()
+        self.temporal_expansion = temporal_expansion
+        self.spatial_expansion = spatial_expansion
+
+    # When printed, this module should show the temporal and spatial expansion factors.
+    def extra_repr(self):
+        return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
+
+    def forward(self, x: torch.Tensor):
+        """Forward pass.
+
+        Args:
+            x: Input tensor. Shape: [B, C, T, H, W].
+
+        Returns:
+            x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
+        """
+        x = rearrange(
+            x,
+            "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
+            st=self.temporal_expansion,
+            sh=self.spatial_expansion,
+            sw=self.spatial_expansion,
+        )
+
+        cp_rank, _ = get_cp_rank_size()
+        if self.temporal_expansion > 1 and cp_rank == 0:
+            # Drop the first self.temporal_expansion - 1 frames.
+            # This is because we always want the 3x3x3 conv filter to only apply
+            # to the first frame, and the first frame doesn't need to be repeated.
+ assert all(x.shape) + x = x[:, :, self.temporal_expansion - 1 :] + assert all(x.shape) + + return x + + +def norm_fn( + in_channels: int, + affine: bool = True, +): + return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels) + + +class ResBlock(nn.Module): + """Residual block that preserves the spatial dimensions.""" + + def __init__( + self, + channels: int, + *, + affine: bool = True, + attn_block: Optional[nn.Module] = None, + padding_mode: str = "replicate", + causal: bool = True, + ): + super().__init__() + self.channels = channels + + assert causal + self.stack = nn.Sequential( + norm_fn(channels, affine=affine), + nn.SiLU(inplace=True), + ContextParallelConv3d( + in_channels=channels, + out_channels=channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding_mode=padding_mode, + bias=True, + causal=causal, + ), + norm_fn(channels, affine=affine), + nn.SiLU(inplace=True), + ContextParallelConv3d( + in_channels=channels, + out_channels=channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding_mode=padding_mode, + bias=True, + causal=causal, + ), + ) + + self.attn_block = attn_block if attn_block else nn.Identity() + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + """ + residual = x + x = self.stack(x) + x = x + residual + del residual + + return self.attn_block(x) + + +def prepare_for_attention(qkv: torch.Tensor, head_dim: int, qk_norm: bool = True): + """Prepare qkv tensor for attention and normalize qk. + + Args: + qkv: Input tensor. Shape: [B, L, 3 * num_heads * head_dim]. + + Returns: + q, k, v: qkv tensor split into q, k, v. Shape: [B, num_heads, L, head_dim]. + """ + assert qkv.ndim == 3 # [B, L, 3 * num_heads * head_dim] + assert qkv.size(2) % (3 * head_dim) == 0 + num_heads = qkv.size(2) // (3 * head_dim) + qkv = qkv.unflatten(2, (3, num_heads, head_dim)) + + q, k, v = qkv.unbind(2) # [B, L, num_heads, head_dim] + q = q.transpose(1, 2) # [B, num_heads, L, head_dim] + k = k.transpose(1, 2) # [B, num_heads, L, head_dim] + v = v.transpose(1, 2) # [B, num_heads, L, head_dim] + + if qk_norm: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + + # Mixed precision can change the dtype of normed q/k to float32. + q = q.to(dtype=qkv.dtype) + k = k.to(dtype=qkv.dtype) + + return q, k, v + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + out_bias: bool = True, + qk_norm: bool = True, + ) -> None: + super().__init__() + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.qk_norm = qk_norm + + self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias) + self.out = nn.Linear(dim, dim, bias=out_bias) + + def forward( + self, + x: torch.Tensor, + *, + chunk_size=2**15, + ) -> torch.Tensor: + """Compute temporal self-attention. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + chunk_size: Chunk size for large tensors. + + Returns: + x: Output tensor. Shape: [B, C, T, H, W]. + """ + B, _, T, H, W = x.shape + + if T == 1: + # No attention for single frame. + x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C] + qkv = self.qkv(x) + _, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys. + x = self.out(x) + return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W] + + # 1D temporal attention. 
+ x = rearrange(x, "B C t h w -> (B h w) t C") + qkv = self.qkv(x) + + # Input: qkv with shape [B, t, 3 * num_heads * head_dim] + # Output: x with shape [B, num_heads, t, head_dim] + q, k, v = prepare_for_attention(qkv, self.head_dim, qk_norm=self.qk_norm) + + attn_kwargs = dict( + attn_mask=None, + dropout_p=0.0, + is_causal=True, + scale=self.head_dim**-0.5, + ) + + if q.size(0) <= chunk_size: + x = F.scaled_dot_product_attention( + q, k, v, **attn_kwargs + ) # [B, num_heads, t, head_dim] + else: + # Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.` + # Chunks of 2**16 and up cause an error. + x = torch.empty_like(q) + for i in range(0, q.size(0), chunk_size): + qc = q[i : i + chunk_size] + kc = k[i : i + chunk_size] + vc = v[i : i + chunk_size] + chunk = F.scaled_dot_product_attention(qc, kc, vc, **attn_kwargs) + x[i : i + chunk_size].copy_(chunk) + + assert x.size(0) == q.size(0) + x = x.transpose(1, 2) # [B, t, num_heads, head_dim] + x = x.flatten(2) # [B, t, num_heads * head_dim] + + x = self.out(x) + x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim: int, + **attn_kwargs, + ) -> None: + super().__init__() + self.norm = norm_fn(dim) + self.attn = Attention(dim, **attn_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.attn(self.norm(x)) + + +class CausalUpsampleBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks: int, + *, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + **block_kwargs, + ): + super().__init__() + + blocks = [] + for _ in range(num_res_blocks): + blocks.append(block_fn(in_channels, **block_kwargs)) + self.blocks = nn.Sequential(*blocks) + + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + # Change channels in the final convolution layer. + self.proj = Conv1x1( + in_channels, + out_channels * temporal_expansion * (spatial_expansion**2), + ) + + self.d2st = DepthToSpaceTime( + temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion + ) + + def forward(self, x): + x = self.blocks(x) + x = self.proj(x) + x = self.d2st(x) + return x + + +def block_fn(channels, *, has_attention: bool = False, **block_kwargs): + attn_block = AttentionBlock(channels) if has_attention else None + + return ResBlock( + channels, affine=True, attn_block=attn_block, **block_kwargs + ) + + +class DownsampleBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks, + *, + temporal_reduction=2, + spatial_reduction=2, + **block_kwargs, + ): + """ + Downsample block for the VAE encoder. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks. + temporal_reduction: Temporal reduction factor. + spatial_reduction: Spatial reduction factor. + """ + super().__init__() + layers = [] + + # Change the channel count in the strided convolution. + # This lets the ResBlock have uniform channel count, + # as in ConvNeXt. 
+ assert in_channels != out_channels + layers.append( + ContextParallelConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction), + stride=(temporal_reduction, spatial_reduction, spatial_reduction), + padding_mode="replicate", + bias=True, + ) + ) + + for _ in range(num_res_blocks): + layers.append(block_fn(out_channels, **block_kwargs)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1): + num_freqs = (stop - start) // step + assert inputs.ndim == 5 + C = inputs.size(1) + + # Create Base 2 Fourier features. + freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device) + assert num_freqs == len(freqs) + w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs] + C = inputs.shape[1] + w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1] + + # Interleaved repeat of input channels to match w. + h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W] + # Scale channels by frequency. + h = w * h + + return torch.cat( + [ + inputs, + torch.sin(h), + torch.cos(h), + ], + dim=1, + ) + + +class FourierFeatures(nn.Module): + def __init__(self, start: int = 6, stop: int = 8, step: int = 1): + super().__init__() + self.start = start + self.stop = stop + self.step = step + + def forward(self, inputs): + """Add Fourier features to inputs. + + Args: + inputs: Input tensor. Shape: [B, C, T, H, W] + + Returns: + h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W] + """ + return add_fourier_features(inputs, self.start, self.stop, self.step) + + +class Decoder(nn.Module): + def __init__( + self, + *, + out_channels: int = 3, + latent_dim: int, + base_channels: int, + channel_multipliers: List[int], + num_res_blocks: List[int], + temporal_expansions: Optional[List[int]] = None, + spatial_expansions: Optional[List[int]] = None, + has_attention: List[bool], + output_norm: bool = True, + nonlinearity: str = "silu", + output_nonlinearity: str = "silu", + causal: bool = True, + **block_kwargs, + ): + super().__init__() + self.input_channels = latent_dim + self.base_channels = base_channels + self.channel_multipliers = channel_multipliers + self.num_res_blocks = num_res_blocks + self.output_nonlinearity = output_nonlinearity + assert nonlinearity == "silu" + assert causal + + ch = [mult * base_channels for mult in channel_multipliers] + self.num_up_blocks = len(ch) - 1 + assert len(num_res_blocks) == self.num_up_blocks + 2 + + blocks = [] + + first_block = [ + nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1)) + ] # Input layer. + # First set of blocks preserve channel count. 
+        for _ in range(num_res_blocks[-1]):
+            first_block.append(
+                block_fn(
+                    ch[-1],
+                    has_attention=has_attention[-1],
+                    causal=causal,
+                    **block_kwargs,
+                )
+            )
+        blocks.append(nn.Sequential(*first_block))
+
+        assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
+        assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
+
+        upsample_block_fn = CausalUpsampleBlock
+
+        for i in range(self.num_up_blocks):
+            block = upsample_block_fn(
+                ch[-i - 1],
+                ch[-i - 2],
+                num_res_blocks=num_res_blocks[-i - 2],
+                has_attention=has_attention[-i - 2],
+                temporal_expansion=temporal_expansions[-i - 1],
+                spatial_expansion=spatial_expansions[-i - 1],
+                causal=causal,
+                **block_kwargs,
+            )
+            blocks.append(block)
+
+        assert not output_norm
+
+        # Last block. Preserve channel count.
+        last_block = []
+        for _ in range(num_res_blocks[0]):
+            last_block.append(
+                block_fn(
+                    ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
+                )
+            )
+        blocks.append(nn.Sequential(*last_block))
+
+        self.blocks = nn.ModuleList(blocks)
+        self.output_proj = Conv1x1(ch[0], out_channels)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
+
+        Returns:
+            x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
+               T = (t - 1) * 6 + 1 (the temporal expansions multiply to 6).
+               H = h * 8, W = w * 8 (the spatial expansions multiply to 8).
+        """
+        for block in self.blocks:
+            x = block(x)
+
+        if self.output_nonlinearity == "silu":
+            x = F.silu(x, inplace=not self.training)
+        else:
+            assert (
+                not self.output_nonlinearity
+            )  # StyleGAN3 omits the to-RGB nonlinearity.
+
+        return self.output_proj(x).contiguous()
+
+
+def make_broadcastable(
+    tensor: torch.Tensor,
+    axis: int,
+    ndim: int,
+) -> torch.Tensor:
+    """
+    Reshapes the input tensor to have singleton dimensions in all axes except the specified axis.
+
+    Args:
+        tensor (torch.Tensor): The tensor to reshape. Typically 1D.
+        axis (int): The axis along which the tensor should retain its original size.
+        ndim (int): The total number of dimensions the reshaped tensor should have.
+
+    Returns:
+        torch.Tensor: The reshaped tensor with shape suitable for broadcasting.
+    """
+    if tensor.dim() != 1:
+        raise ValueError(f"Expected tensor to be 1D, but got {tensor.dim()}D tensor.")
+
+    axis = (axis + ndim) % ndim  # Ensure the axis is within the tensor dimensions
+    shape = [1] * ndim  # Start with all dimensions as 1
+    shape[axis] = tensor.size(0)  # Set the specified axis to the size of the tensor
+    return tensor.view(*shape)
+
+
+def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
+    """
+    Blends two tensors `a` and `b` along the specified axis using linear interpolation.
+
+    Args:
+        a (torch.Tensor): The first tensor.
+        b (torch.Tensor): The second tensor. Must have the same shape as `a`.
+        axis (int): The axis along which to perform the blending.
+
+    Returns:
+        torch.Tensor: The blended tensor.
+ """ + assert ( + a.shape == b.shape + ), f"Tensors must have the same shape, got {a.shape} and {b.shape}" + steps = a.size(axis) + + # Create a weight tensor that linearly interpolates from 0 to 1 + start = 1 / (steps + 1) + end = steps / (steps + 1) + weight = torch.linspace(start, end, steps=steps, device=a.device, dtype=a.dtype) + + # Make the weight tensor broadcastable across all dimensions + weight = make_broadcastable(weight, axis, a.dim()) + + # Perform the blending + return a * (1 - weight) + b * weight + + +def blend_horizontal(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor: + if overlap == 0: + return torch.cat([a, b], dim=-1) + + assert a.size(-1) >= overlap + assert b.size(-1) >= overlap + a_left, a_overlap = a[..., :-overlap], a[..., -overlap:] + b_overlap, b_right = b[..., :overlap], b[..., overlap:] + return torch.cat([a_left, blend(a_overlap, b_overlap, -1), b_right], dim=-1) + + +def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor: + if overlap == 0: + return torch.cat([a, b], dim=-2) + + assert a.size(-2) >= overlap + assert b.size(-2) >= overlap + a_top, a_overlap = a[..., :-overlap, :], a[..., -overlap:, :] + b_overlap, b_bottom = b[..., :overlap, :], b[..., overlap:, :] + return torch.cat([a_top, blend(a_overlap, b_overlap, -2), b_bottom], dim=-2) + + +def nearest_multiple(x: int, multiple: int) -> int: + return round(x / multiple) * multiple + + +def apply_tiled( + fn: Callable[[torch.Tensor], torch.Tensor], + x: torch.Tensor, + num_tiles_w: int, + num_tiles_h: int, + overlap: int = 0, # Number of pixel of overlap between adjacent tiles. + # Use a factor of 2 times the latent downsample factor. + min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing. +): + if num_tiles_w == 1 and num_tiles_h == 1: + return fn(x) + + assert ( + num_tiles_w & (num_tiles_w - 1) == 0 + ), f"num_tiles_w={num_tiles_w} must be a power of 2" + assert ( + num_tiles_h & (num_tiles_h - 1) == 0 + ), f"num_tiles_h={num_tiles_h} must be a power of 2" + + H, W = x.shape[-2:] + assert H % min_block_size == 0 + assert W % min_block_size == 0 + ov = overlap // 2 + assert ov % min_block_size == 0 + + if num_tiles_w >= 2: + # Subdivide horizontally. + half_W = nearest_multiple(W // 2, min_block_size) + left = x[..., :, : half_W + ov] + right = x[..., :, half_W - ov :] + + assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even" + left = apply_tiled( + fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size + ) + right = apply_tiled( + fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size + ) + if left is None or right is None: + return None + + # If `fn` changed the resolution, adjust the overlap. + resample_factor = left.size(-1) / (half_W + ov) + out_overlap = int(overlap * resample_factor) + + return blend_horizontal(left, right, out_overlap) + + if num_tiles_h >= 2: + # Subdivide vertically. + half_H = nearest_multiple(H // 2, min_block_size) + top = x[..., : half_H + ov, :] + bottom = x[..., half_H - ov :, :] + + assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even" + top = apply_tiled( + fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size + ) + bottom = apply_tiled( + fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size + ) + if top is None or bottom is None: + return None + + # If `fn` changed the resolution, adjust the overlap. 
+ resample_factor = top.size(-2) / (half_H + ov) + out_overlap = int(overlap * resample_factor) + + return blend_vertical(top, bottom, out_overlap) + + raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}") diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000..e6c5459 --- /dev/null +++ b/nodes.py @@ -0,0 +1,356 @@ +import os +import torch +import torch.nn as nn +import folder_paths +import comfy.model_management as mm +from comfy.utils import ProgressBar, load_torch_file +from einops import rearrange + +from contextlib import nullcontext + +from PIL import Image +import numpy as np +import json + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel +from .mochi_preview.vae.model import Decoder + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2) + const = quadratic_coef * (linear_steps ** 2) + quadratic_sigma_schedule = [ + quadratic_coef * (i ** 2) + linear_coef * i + const + for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + +class DownloadAndLoadMochiModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ( + [ + "mochi_preview_dit_fp8_e4m3fn.safetensors", + ], + ), + "vae": ( + [ + "mochi_preview_vae_bf16.safetensors", + ], + ), + "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"], + {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} + ), + }, + } + + RETURN_TYPES = ("MOCHIMODEL", "MOCHIVAE",) + RETURN_NAMES = ("mochi_model", "mochi_vae" ) + FUNCTION = "loadmodel" + CATEGORY = "MochiWrapper" + DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface" + + def loadmodel(self, model, vae, precision): + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + mm.soft_empty_cache() + + dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] + + # Transformer model + model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi') + model_path = os.path.join(model_download_path, model) + + repo_id = "kijai/Mochi_preview_comfy" + + if not os.path.exists(model_path): + log.info(f"Downloading mochi model to: {model_path}") + from huggingface_hub import snapshot_download + snapshot_download( + repo_id=repo_id, + allow_patterns=[f"*{model}*"], + local_dir=model_download_path, + local_dir_use_symlinks=False, + ) + # VAE + vae_download_path = os.path.join(folder_paths.models_dir, 'vae', 'mochi') + vae_path = os.path.join(vae_download_path, vae) + + if not os.path.exists(vae_path): + log.info(f"Downloading mochi VAE to: {vae_path}") + from huggingface_hub import snapshot_download + 
snapshot_download(
+                repo_id=repo_id,
+                allow_patterns=[f"*{vae}*"],
+                local_dir=vae_download_path,
+                local_dir_use_symlinks=False,
+            )
+
+        model = T2VSynthMochiModel(
+            device_id=0,
+            vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
+            dit_checkpoint_path=model_path,
+        )
+        vae = Decoder(
+            out_channels=3,
+            base_channels=128,
+            channel_multipliers=[1, 2, 4, 6],
+            temporal_expansions=[1, 2, 3],
+            spatial_expansions=[2, 2, 2],
+            num_res_blocks=[3, 3, 4, 6, 3],
+            latent_dim=12,
+            has_attention=[False, False, False, False, False],
+            padding_mode="replicate",
+            output_norm=False,
+            nonlinearity="silu",
+            output_nonlinearity="silu",
+            causal=True,
+        )
+        decoder_sd = load_torch_file(vae_path)
+        vae.load_state_dict(decoder_sd, strict=True)
+        vae.eval().to(torch.bfloat16).to("cpu")
+        del decoder_sd
+
+        return (model, vae,)
+
+class MochiTextEncode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "clip": ("CLIP",),
+            "prompt": ("STRING", {"default": "", "multiline": True}),
+            },
+            "optional": {
+                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                "force_offload": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("CONDITIONING",)
+    RETURN_NAMES = ("conditioning",)
+    FUNCTION = "process"
+    CATEGORY = "MochiWrapper"
+
+    def process(self, clip, prompt, strength=1.0, force_offload=True):
+        max_tokens = 256
+        load_device = mm.text_encoder_device()
+        offload_device = mm.text_encoder_offload_device()
+        #print(clip.tokenizer.t5xxl)
+        clip.tokenizer.t5xxl.pad_to_max_length = True
+        clip.tokenizer.t5xxl.max_length = max_tokens
+        clip.cond_stage_model.t5xxl.return_attention_masks = True
+        clip.cond_stage_model.t5xxl.enable_attention_masks = True
+        clip.cond_stage_model.t5_attention_mask = True
+        clip.cond_stage_model.to(load_device)
+        tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)
+
+        embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
+
+        if embeds.shape[1] > max_tokens:
+            raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
+        embeds *= strength
+        if force_offload:
+            clip.cond_stage_model.to(offload_device)
+
+        t5_embeds = {
+            "embeds": embeds,
+            "attention_mask": attention_mask["attention_mask"].bool(),
+        }
+        return (t5_embeds, )
+
+
+class MochiSampler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MOCHIMODEL",),
+                "positive": ("CONDITIONING", ),
+                "negative": ("CONDITIONING", ),
+                "width": ("INT", {"default": 848, "min": 128, "max": 2048, "step": 8}),
+                "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
+                "num_frames": ("INT", {"default": 49, "min": 7, "max": 1024, "step": 6}),
+                "steps": ("INT", {"default": 50, "min": 2}),
+                "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            },
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    RETURN_NAMES = ("samples",)
+    FUNCTION = "process"
+    CATEGORY = "MochiWrapper"
+
+    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames):
+        mm.soft_empty_cache()
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        args = {
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "mochi_args": {
+                "sigma_schedule": linear_quadratic_schedule(steps, 0.025),
+                "cfg_schedule": [cfg] * steps,
+                "num_inference_steps": steps,
+                "batch_cfg": False,
+            },
"positive_embeds": positive, + "negative_embeds": negative, + "seed": seed, + } + latents = model.run(args, stream_results=False) + + mm.soft_empty_cache() + + return ({"samples": latents},) + +class MochiDecode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("MOCHIVAE",), + "samples": ("LATENT", ), + "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), + }, + "optional": { + "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}), + "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}), + "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}), + "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), + "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "decode" + CATEGORY = "MochiWrapper" + + def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + samples = samples["samples"] + + def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6 + self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5 + #7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199 + self.num_latent_frames_batch_size = 6 + + self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8 + self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8 + + self.tile_latent_min_height = int(self.tile_sample_min_height / 8) + self.tile_latent_min_width = int(self.tile_sample_min_width / 8) + + vae.to(device) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + if not enable_vae_tiling: + samples = vae(samples) + else: + batch_size, num_channels, num_frames, height, width = samples.shape + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - 
blend_extent_width
+            frame_batch_size = self.num_latent_frames_batch_size
+
+            # Split z into overlapping tiles and decode them separately.
+            # The tiles have an overlap to avoid seams between tiles.
+            rows = []
+            for i in range(0, height, overlap_height):
+                row = []
+                for j in range(0, width, overlap_width):
+                    time = []
+                    for k in range(num_frames // frame_batch_size):
+                        remaining_frames = num_frames % frame_batch_size
+                        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                        end_frame = frame_batch_size * (k + 1) + remaining_frames
+                        tile = samples[
+                            :,
+                            :,
+                            start_frame:end_frame,
+                            i : i + self.tile_latent_min_height,
+                            j : j + self.tile_latent_min_width,
+                        ]
+                        tile = vae(tile)
+                        time.append(tile)
+                    row.append(torch.cat(time, dim=2))
+                rows.append(row)
+
+            result_rows = []
+            for i, row in enumerate(rows):
+                result_row = []
+                for j, tile in enumerate(row):
+                    # blend the above tile and the left tile
+                    # to the current tile and add the current tile to the result row
+                    if i > 0:
+                        tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    if j > 0:
+                        tile = blend_h(row[j - 1], tile, blend_extent_width)
+                    result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                result_rows.append(torch.cat(result_row, dim=4))
+
+            samples = torch.cat(result_rows, dim=3)
+        vae.to(offload_device)
+        #print("samples", samples.shape, samples.dtype, samples.device)
+
+        samples = samples.float()
+        samples = (samples + 1.0) / 2.0
+        samples.clamp_(0.0, 1.0)
+
+        frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float()
+        #print(frames.shape)
+
+        return (frames,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
+    "MochiSampler": MochiSampler,
+    "MochiDecode": MochiDecode,
+    "MochiTextEncode": MochiTextEncode,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
+    "MochiSampler": "Mochi Sampler",
+    "MochiDecode": "Mochi Decode",
+    "MochiTextEncode": "Mochi TextEncode",
+}
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..549eee6
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,18 @@
+# ComfyUI wrapper nodes for Mochi video gen: https://github.com/genmoai/models
+
+
+# WORK IN PROGRESS
+
+## Requires flash_attn!
+
+Depending on the frame count, generation can fit under 20GB of VRAM. VAE decoding is the heaviest step, so there is an experimental tiled decoder (adapted from the CogVideoX diffusers code) that allows higher frame counts; the highest I've done so far is 97 frames with the default 2x2 tile grid.
+
+Models:
+
+https://huggingface.co/Kijai/Mochi_preview_comfy/tree/main
+
+model to: `ComfyUI/models/diffusion_models/mochi`
+
+vae to: `ComfyUI/models/vae/mochi`
+
+There is an autodownload node (a normal loader node will also be added).
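+
+For reference, a minimal sketch of how requested frame counts map to latent frames: the VAE decoder's temporal expansions multiply to 6, so t latent frames decode to (t - 1) * 6 + 1 video frames, which is why the sampler's `num_frames` input steps in increments of 6 starting from 7 (7, 13, ..., 49, ..., 97). The `latent_frame_count` helper below is a hypothetical name used only for this illustration, not part of the nodes:
+
+```python
+def latent_frame_count(num_frames: int) -> int:
+    # num_frames must be of the form 6k + 1 (7, 13, 19, ..., 97, ...)
+    assert (num_frames - 1) % 6 == 0, "num_frames must be 6k + 1"
+    return (num_frames - 1) // 6 + 1
+
+print(latent_frame_count(49))  # 9 latent frames
+print(latent_frame_count(97))  # 17 latent frames
+```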