From 34e029bacc1a045d9db0217a86c746f3c916359a Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:49:37 +0300 Subject: [PATCH] fix dtype selection --- .../t2v_synth_mochi.cpython-312.pyc | Bin 15531 -> 15604 bytes mochi_preview/t2v_synth_mochi.py | 3 ++- nodes.py | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc index b7e593da6e72aadaa72c2bf773deb7a8095e2f15..1be05b30ed4b7e25cc183eddc6e4491325fa05cc 100644 GIT binary patch delta 1039 zcmZ8fO>7fK6rQ&~-nE@|?AT$QpOU1BX_69B;vfSFP-sewL=*v44;EmB%q~vsI@Nkn zo8YvGYKat8mC6%}T3V^42tk98z$Xq|xBx;3MMXmF1qq1*7jS_>X)BeP!1nJ+^UeF_ zz3-dZomrTCGO1nDG#_EHP>Jq8d`>G=H+VT^@%R*jLu8&9Bn6bD1VhBlN@OR6Z-EMi zOhv;(B{pP391ko>{sX2uM@*oeQXr*L#gI*bswFAqrQQ-r`KXVfUm#2y#b1Q*p#MzO z0=zV!T8S*Y@=9rQJpLynN$0wIU}(`ts$@Q*OrKOTa>WqqqPW;&9%C_{{If};a}l*w5VGz%Hwftsj0`Yhz?UY8>xfc z_I;dr z)9KONNTHBU7c|D3|B1I}tYute;3v_T_ZYj&=QHF<5YI%P$*1~%er+wDtvjv{Cd5nC z4M|D7oRE{CT#@8tKyku2yF6T;sy{EmAa*rTF!4y!AMiu@%jWL|7{Kn#tp+TX zZ?;YWs}IM9;CA`D*gchz&sIHGx3bMUM9VCfy6=HM%a6j}kY%F1`XuyYLyyCpp7x|z z|BiSTvUo9m35Ic|_XON2$JUf2sK%exwe?!Knsh?{vw9Lw8|Sgwkv~?&<6HJmA(3@krl#_!F=9-KZXrKz~FyTKuYC05Z|` A3jhEB delta 987 zcmZ8fU1$?o6h3EuGx;%TGB$}(Tj_SStF_u#tF1;$RU*=L!DWTGY^c2xYfM|2jABcQ z6#`cAK{$#pg7pVsH${*<3CpthAY~U;7a!csg7~6uN_KVe@5Ou7QSl7-JKvde&Ufye zxzENc1D|1%)`+mfo0@P`d5l{8?)XDzEuV0D?{ZSg^X_+j$p>W3)=j( z@VA}c?HV6-h6l!mhW8JQF{?m(LaSB`y>A4Wcg^QZ_@^la;4B{Zo&)E-7{Cc1{TymR zg&u^u;|fD}WEUVXiOB_ZM4MD;Rl{~vX`-RMRbwEYiu6tbL4uOP!k_jX)##XR(u0N$ z)F>1_v~dF9dtQIiJE;!?i(JM@kJIFCVcN?-idpuQJX<~-(p4*Vlyp)eJA-;Q6AiU(z>bV)v(+k^)YfTbKXJWgorM?(&l4&T##t;plEQzmk^d5j5zRNI#93hX3=F zY_LZ9HPRP|yDcRp-$hf21$4%Y;xx^gam-h5ncE~hLxE@~CTMdsjeT??`op|Ht}r&} ztjk(!Mv}QvdKBH7@Yj^0Rmj;JxSqXHK#D^R9lMWq#M+k*ipa=jhYEwmVm4d!^SJqg zZJx3lMGvEySgUc7|NWO!@N6Fa5qs{N?Yg%-_2=@`-9?+|xA-M9t;#opiF8<=Q`K}( zp9`z$h1IDAa|&*w?M?i`(zT}7I8*(pc}Bv1+R>82G@Wd*@J_YfGJ!mQpmhNMs@`aQ z=;6VIl`lP8`Rq2?wyHB757CzwI}w-8w&?WgtJq83oi(l$PG)%hNb&&Y>F?xC{F=_M zxq|nreO+Y*19WTs7tJ+!|EN=wi;itGA#h2+j{>d;SVoI`(ku3I`l1YVX{{QoYhXs} zx?moNFRUAad13LA|9yfTq)N{#+(%b?JG93_;R*fI+k>yEal`Fkzk<>KK%gDtxi>zY B;{E^t diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 7ceb8d9..15fb287 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -107,6 +107,7 @@ class T2VSynthMochiModel: device_id: int, vae_stats_path: str, dit_checkpoint_path: str, + weight_dtype: torch.dtype = torch.float8_e4m3fn, ): super().__init__() t = Timer() @@ -146,7 +147,7 @@ class T2VSynthMochiModel: for name, param in self.dit.named_parameters(): params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} if not any(keyword in name for keyword in params_to_keep): - param.data = param.data.to(torch.float8_e4m3fn) + param.data = param.data.to(weight_dtype) else: param.data = param.data.to(torch.bfloat16) diff --git a/nodes.py b/nodes.py index ac77a53..04ae75f 100644 --- a/nodes.py +++ b/nodes.py @@ -47,7 +47,7 @@ class DownloadAndLoadMochiModel: ], ), "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"], - {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} + {"default": "fp8_e4m3fn", } ), }, } @@ -99,6 +99,7 @@ class DownloadAndLoadMochiModel: device_id=0, vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"), dit_checkpoint_path=model_path, + weight_dtype=dtype, ) vae = Decoder( out_channels=3,