From 4efb7c85df9141f1b620f8c53fa8b8f52b2c2b65 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:45:20 +0300 Subject: [PATCH] cleanup --- .gitignore | 2 + .../t2v_synth_mochi.cpython-312.pyc | Bin 18397 -> 15531 bytes .../asymm_models_joint.cpython-312.pyc | Bin 24501 -> 24249 bytes mochi_preview/t2v_synth_mochi.py | 101 ++---------------- nodes.py | 19 ++-- 5 files changed, 17 insertions(+), 105 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a295864 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.pyc +__pycache__ diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc index 4b28b7e858f447c64926bb0f6758ff4d8d667955..b7e593da6e72aadaa72c2bf773deb7a8095e2f15 100644 GIT binary patch delta 2880 zcmZ`*du$ZP8K1e^y{GTa#&_qpeKxpZA-obB45r5;I5;XoqCz5^N>|J6-g)uv-gR~l zFfANusbr873=2V`^Y$G22zaPEnL+FVQ zp(h`u*O!%*B6{kaROrx~jZVGAkRBCVmg<11o-~3+*Z=320w_@m=_^22x87<*j%{1s z{a~i2joy-^XU^@G3Vi{xm+0AZL~na7&=4r(rl`_>4;TBDfROf>MNdfkOxIJWh*}j* zS4*a4z#`x{{L{g$!Fzm84lB1ipkX-YWAp%}S39}uXKT%ID$Ld_5LPYCGaS!lZNN|STk*?*Kau&N)CQS`xt`t#cqM&gmyo>*D^3X6 zbi5~@5Mig{oHAHNgiEj=6tFaG7olR!c5Ai>+S*AWnVw-QVTL|poqSW~lQhd`Gu`qK z5LVA$&rnL+=jXFZU=tcLN0a~?M%m2(E}E6gu`CQC?oOB`Pj$_*t+A~@Il!Ac*3fo7 z+A$XwQQbIh77f+Z`Mr)&nx7x+EK~lauFq#$zd}z(ygu0AKhv=i)zf^W`zFuz9HAM0 zre~6$?%B=P_U4DDPX-T@6LclnL|l3TQu;8V@SCMG#~K`^ZNEqqD~{{2TG3N=(_>-& zdhb5kG#~5>P@c-2loh(4jqn$8{b1>La`}BP!gpOX;r?&Q{j=mocI>+I>Xz9PFRyyB ze%<+-D*@|s$ridCCZ$WrZmnuC5m8msF+EjfQIKM*`EdWn&@P;Y_+ag{y^YN3I_E=J2hx`M1~R-|5*gD_@D-eU$U< z{2b7?A$>7Uy0?7LOfnrWJ#*<9{=1dcZnquI^@Cl3aXI*-RM+^5;2W|$o(jLQLLTp) zpILQ4+9_g?aH?fEWp6?Rs@_3G5}_)q*i}!2tD3HxPFWU)AG#n@; z##LW3}7P^rN$C>uVq8aG3|R5Yw0*{fBds$=Y=Pi`F6Wp%!we7y4o z+;tJa#SWf(_^>SByJj1GdH%C&{sqOupL_h3v8PZL3kXUc&&RO5M9QulH{fLIQgJ8| z|3eHli-2_y-b3(?vvJ(g!OQC~XV0(82W}Alvvr3+dH?zk4qpJS5*9JjH6iger)aQC zs6m=Erd);=F%O|(CJF}|28CSB2Kz)11skoWQOxYCAa!a0^RDH8&Ul;f%Nq_uPx$?Y z3z3k4dc*uPBUAJ_{+E%-`qz-~b%bvqi~+#cP)tP2hGQ@dS`@=-7pZE+H9ZqNVAw|u z-DT&I8d_RB?~6D_8DhUI^3dn(Rp1xl^&dQY2;SEDyIXC#Q%J>{t|`?24m7u8z*@H6 z_lSsSHLp_Cpixwjdj$a8@3-S&J~KKv_${D7Q}aEfIyKAkedb1h?1%d7*|QMoo12zJY?T>9Zq0%V>ia17 zMF2rh2s**0aQ!;2o0eSZdPdb{-$nZOcxV3EzWq4$9sg65ixn$&X-cgg)tI0r8|Iht zo8mu58H@<~34bfEKo$Hop9}n&%y;ZKBZaW_C<*@W$Y4E)RUd>?`K^IEq zIcYj2gBegJD)wA<4>%&Bszpn4-J)SxstRE{15_9OijkfE^~OeVKSCK9M*;4iBa4B} za{FS4q}vw5IE$o7NB3eBr!mso|K~W)s(k03(NEwR{ryTm`xMUK1P~zz7qmTw3Ctv< zw76QySn7W3{p*~`NSB?OU7cdTKtU{N5%HPp9S5$Oma)O?AQGD!#F}SX?s|x@#uHDk ztYefJp1W~N`RNe%-MrqHY-9;X0l!gT^w}Q}-be6zRoKxr#Ky|4nD0}MxOlxmgqnCExha&gYZXUfy5U`;$7179(inmtXLpTcSz=4(szUOEs(Z5 pB>f)A-y!WlyhB#r&34=%t&0-hx_5-1-g_YRltkP9ND%$+{RhH}_VfS% delta 5321 zcmaJlYiu0nao_IkJ$QF~KkkmNTfQV6Z6Ea@C7F^bQx+vkB56UQV|ihFy}aEc&)VHP ze7k3f46UOPp;cNtR=fxtr?nk71rk+i8-OwLLej)&0Rd?Nu~MYNN`wUzMr@=7NK1xf z|7bh&-I0|U5%ex0lOU1epEK)dyEbo%Ddx2vLs{X8LW5|yZ& zPNdS4lQ`DrB_5uM!nR)6-rY!mQ|XhCQ$Y25hE%^jhGZ;Dn7m0XpveN?$^WXLk09! zqfKfv+NE|QF2xzY0%)St!SJBoX>>_lMz_?>`XRl?*dT2%HcA^=Kdf&$ZX~1xbSl*< zz1Mh1dWhlG`oqQ}(j!c$Mo$_^Dar6!z0cSzZ3etf+M?D=TX}MTs10X`+IRu%h)vq2 zHfiA(XPi~O-WyrRqH9rpkXgh zx&E9&?Tn%)cvsk$N!hM2ma%MLROnenmD3qr16Bpq6iO9yT629=%V`Rec+$|Ol_#=h z8pL_iOnBUyy-Mmt;iN*P({}@TvbF@j44%-mAsvmPf{f(=dyYxHS4y5zG zY~EDG(fsr@>|Wfjq%_ekuO~o7*oZhe;0#wa95QgZvZ)x_Cd;OYL2&|5F`Hf&r+|rg zKFOd6Ed}dcZvtQ*_}$(I!0Fx~7HlN3rg9hgH9wdqGK#S_OA?-m{a_QH4krr66KEYU z-0C#V8gh{HTH=x{3Y=Qx+4?*BVT78N7R}k;s_FO1+obU6o!aO(6 z-&dOV%nPdLqHv)+(aXMh-!YQ#Dqn-WF;ki4V&i8lw!j*5;sj{vpTrOeqEV{+0+cu zVXTugnmq&FhO?HW!OLJl?6|Bi4@b3vou{U3Yo?W@*4jG$pCQ2$Kmqo}6uvx~p`4O+ z&78Jp#7!B~hP9y~fGU=$=)nJHB?Yv@T>bccAtP_u;*=)N18h5GT~<1)J7lZQtpzpXVkLKIiWq?K$k1Wi^|U zW!eQ~nnWI9B5cL91*r&dpB7&HGh~$)yn#={oOhhN<7@JYCGi%r*ugjOlMSeWK@mc} z0KR8H1c0dO@^&`ip##X>j$j7>3oq%qb?Wn`2@zY-L0s0s#ZB9yLm)BcToc~tW7DM8 z!Yw{S@MG8fQ$XG%_IjwBCpPc=o42*f#v@=2u5!Q`_qUcFZI#>w)Py$gvBS=P)wDLL z{5+@fDu2;)p`yfknSc&gyV?1l+He8xx-bEI3PDm-3|X~jbDA5bnwn2(YrUy_It>Yj z@g@@ix7I4>GkNXKy441zJdB%bIqy%BSM3xeO3afeRZ5h#dLd&#h(za^R`3#AZiTA7 zl!3W3s#|BxnkgAlLtfWlDlnbpR_85E&eF`Z%oc1Vg08Ut*uEpQ5;o@s_dhi@CLbFe zKDv*Fk>Itn6f&|u4XMN4ww0o7-SFPwCk{UiGV+na(Ggcje(bs97h=JcAZhMpayyoqcNCl3i*+rT>yAN42nY1linP~$GC1AYcj8yk>-A=Chyq6iy>8{zu%`LuAGtE`=1oBM%g|aUX zIR9LGwz1&{+{*zw4yvr34W>D@(NXHI)mC5;4S#Eyn{&3;KmSI_0?aW@;pOMa*61KTbat5;FkH6Evp%uHaPD$jB&9=zLDoD+yFZ(HRE`i zX8p7erltMPD@`3>h95Ky9ryzfSW3S4F}b%uE;k?idGbdmzq{l6>R<2qdGd31{NJdm zjbCkk{qrvZH*Q0u48-$r}At2Lym+F6R8EvyCrJ3IH|xzxre`-u3ROw@zJ;?_7%SycQq)cIcdNPIEh+KUZ}#(tHKC z2H{$bmDZ5Q*``}QvN?Y+|R#JQ309J=)&=O0g>KYgWR2}3C4SV_y-=f1FdV4c4j=0_u*52{;6yFDKa?tOH0i|10rJKEvD)a@Nj z`Y&zqjt=_Wdgqn)`o+2CU4mFjGEC*pWaz1O{+I<}9rvaSjUEeIr zZTL(>VNRK!)==3w9B+@&hd{^`uvmbznNi?y9z)%?oj2lrop?Lw1WI_}?n0#>ETuqy zDFs^eediDHqs{+>V!r`k;oBu%`BI~u&M|S2yS?}=@xS4&IInfSSNC<4#rpz@Kt|kj z9`A~@VrHc{j?3$l@Ov75qjUkOlaAGeKKGrjA-;=qe&6*p4E0F&oylJUH{o}K@ZJDU zte7c{et;tgC$w3M-E4Xh2Pka_g+1jAZYc)r#`U1!gXxD}aq>MI`G+{?ot|gla^LIu zi-7NK9HYTGo1Eo->hOIh3RjWvdj!`IybJ(FMlrWyS~E4M%(9eX3U{k>Su0~_UeaXE zn9@{>{sO7Ka_84f#DtRiL6LQ9((Axa!H+duCzhS{i~reYaC_(xz!Mys27=VzK=1xW z+2-89I(qLAi9o_>KQ_ULm7See$S6WlZP|Uod%a#Uj z3l#^}t;MY8nT8h$<`nE4y*Pr&{Euy|mHrmT5-{ts-OwWRX%s=kIkzpgy%G80QuQfY zIUPb8GYl3Gu9u}X`u8aQ4+!vt5+dzzKHS!v{3BBS6@bf~bh#P&W9)x~{czbpmaXM1 zicyb#>}>2m9{)2$2M~OQa={e7(#W|fg}PjZW}SEXw}kE?8|_T*IBtIe+*lY8`7PYy zhJn*U0n4yNCEbktEeN&(a6M2RyB;{UQkYyq2HWWs#PCkKRT&eC3=Lc!GrmS6kSg3t zs4R8GvQiqBmvAmW0Z|K+$m69#&PII%j4?%e3AuNMtnj*as1OUaT8mgT9%_=NT7R(C6igbIF9&)m&wL{bnH zwLFz?&DHWrCyhU+G)>Z|I1MA7N@vF?@#D}53$rP;z`M5Fl!C2d0Gx4^dqvxj6WRWj z+OS}?VH>QdHP&)$Un%N@g~FInX!iS7ybk4ALEYAEV`m{oUUT$HIXGxV_T?4sx+?bI zBy9iA=ka#hPpKH)mRLFH9$d}rV!kHY$b^~Bso0f3LE}EFZSfToheEF;9HiXs5Viw6 z(Aq<9dk+<@Ilfp1@_w$}gS;;OY1Rt|I_r6%MH8zNQPL+XiB5PLC1x!+>B6*Ur2f-_Q?j{#L;_vZ0S`77vOu50d?< zVY-$`9!~>(849O$`pOVaoURYE)R@u$me!Pp&e9R35ohT=%3`ekajM}VCO(|bkvRUH sJ~tSXJi4X1DS70inU$u=&6dcNSxcI5-M;kRMdJ0TT$z%ixPIl^Kh8?SDgXcg delta 1026 zcmZ`$Ur19?7(ZwC-fef=+3vPG{|$Avq!o&;vIS+O4>s;xmPFGuXf#u1L2Uy??sXm9A=E+)OjD|T0Ah*U^aFj zjI^RF!6eu^G|V8RaEehe$;JnK6!W?0v7@z%%3GNYj+C?Zajcc+907> z3X*ifLW>(3!@N_mBv|#!8#E3jSj)7M2+5pm zmU)UG>L-8y8CK6B&Wg5GK2P7TF2x41exF~qty**%!<3Ff6#E|Js3xM*nlGzX5(Xvf z)d%^^vSEs9tk)FBf7*}O^ebw6g#jeH!(i9BU1oG+gk1%P%<2Urb%-e|OLaL(J!sJ0 zR~bt-?2}MJ*t`A+C7c?VNK@BiE?@4ZzM*6jHPsFGB%!Q7bx@QorI|H!rdzg@`URC! zJ59yPejfcJ2g*9K!jx)&lg{>+W;aG!@K2bIRIxcc*QX~wnpE@ANoEjD;Da7`5@|Ky z0Vqa>{HakoV2WfqeAh?f8DDxVeru$(y%|rBYg=gU?Y6p76ZB$gk_H&-ZG!pUq{l){ zOj*^S4oOX`fdGl#s2!mG&!8;<@%H=Ia>2#WLTKIW-5{>Q-DoF1v}HuZmGkFr&qX(E z&RpBlftA+f)^%IEwhOIvE_bfmg5QN|SUdNr=L$GHN-UU38ms#829g+e}5p$B&RJ3zajj8D6q-Ikwz)3+ FSDP: - model = FSDP( - model, - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=MixedPrecision( - param_dtype=param_dtype, - reduce_dtype=torch.float32, - buffer_dtype=torch.float32, - ), - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, - limit_all_gathers=True, - device_id=device_id, - sync_module_states=True, - use_orig_params=True, - ) - torch.cuda.synchronize() - return model - - def compute_packed_indices( N: int, text_mask: List[torch.Tensor], @@ -189,12 +112,6 @@ class T2VSynthMochiModel: t = Timer() self.device = torch.device(device_id) - #self.t5_tokenizer = T5_Tokenizer() - - # with t("load_text_encs"): - # t5_enc = T5EncoderModel.from_pretrained(T5_MODEL) - # self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu") - with t("construct_dit"): from .dit.joint_model.asymm_models_joint import ( AsymmDiTJoint, @@ -223,15 +140,15 @@ class T2VSynthMochiModel: model.load_state_dict(load_file(dit_checkpoint_path)) - with t("fsdp_dit"): - self.dit = model - self.dit.eval() - for name, param in self.dit.named_parameters(): - params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} - if not any(keyword in name for keyword in params_to_keep): - param.data = param.data.to(torch.float8_e4m3fn) - else: - param.data = param.data.to(torch.bfloat16) + #with t("fsdp_dit"): + self.dit = model + self.dit.eval() + for name, param in self.dit.named_parameters(): + params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} + if not any(keyword in name for keyword in params_to_keep): + param.data = param.data.to(torch.float8_e4m3fn) + else: + param.data = param.data.to(torch.bfloat16) vae_stats = json.load(open(vae_stats_path)) diff --git a/nodes.py b/nodes.py index e6c5459..ac77a53 100644 --- a/nodes.py +++ b/nodes.py @@ -1,17 +1,10 @@ import os import torch -import torch.nn as nn import folder_paths import comfy.model_management as mm from comfy.utils import ProgressBar, load_torch_file from einops import rearrange - -from contextlib import nullcontext - -from PIL import Image -import numpy as np -import json - +from tqdm import tqdm import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') log = logging.getLogger(__name__) @@ -295,11 +288,11 @@ class MochiDecode: # Split z into overlapping tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. rows = [] - for i in range(0, height, overlap_height): + for i in tqdm(range(0, height, overlap_height), desc="Processing rows"): row = [] - for j in range(0, width, overlap_width): + for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False): time = [] - for k in range(num_frames // frame_batch_size): + for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False): remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) end_frame = frame_batch_size * (k + 1) + remaining_frames @@ -316,9 +309,9 @@ class MochiDecode: rows.append(row) result_rows = [] - for i, row in enumerate(rows): + for i, row in enumerate(tqdm(rows, desc="Blending rows")): result_row = [] - for j, tile in enumerate(row): + for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: