diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 7c319bc..1de23ef 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -127,7 +127,6 @@ class AsymmetricAttention(nn.Module):
     ):
         # Pre-norm for visual features
         x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
-        #print("x in attn", x.dtype, x.device)
 
         # Process visual features
         qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
@@ -139,13 +138,6 @@ class AsymmetricAttention(nn.Module):
         # Process text features
         y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
         q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)
-        #print("y in attn", y.dtype, y.device)
-        #print(q_y.dtype, q_y.device)
-        #print(self.q_norm_y.weight.dtype, self.q_norm_y.weight.device)
-        # self.q_norm_y.weight = self.q_norm_y.weight.to(q_y.dtype)
-        # self.q_norm_y.bias = self.q_norm_y.bias.to(q_y.dtype)
-        # self.k_norm_y.weight = self.k_norm_y.weight.to(k_y.dtype)
-        # self.k_norm_y.bias = self.k_norm_y.bias.to(k_y.dtype)
 
         q_y = self.q_norm_y(q_y)
         k_y = self.k_norm_y(k_y)
diff --git a/nodes.py b/nodes.py
index f18221e..3558955 100644
--- a/nodes.py
+++ b/nodes.py
@@ -48,6 +48,8 @@ class DownloadAndLoadMochiModel:
                 "model": (
                     [
                         "mochi_preview_dit_fp8_e4m3fn.safetensors",
+                        "mochi_preview_dit_bf16.safetensors",
+
                     ],
                     {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },
                 ),