diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 5eca377..65a4017 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -169,10 +169,16 @@ class AsymmetricAttention(nn.Module):
             pass
         else:
             from .layers import RMSNorm
-            self.q_norm_x = RMSNorm(self.head_dim, device=device)
-            self.k_norm_x = RMSNorm(self.head_dim, device=device)
-            self.q_norm_y = RMSNorm(self.head_dim, device=device)
-            self.k_norm_y = RMSNorm(self.head_dim, device=device)
+            if rms_norm_func == "apex":
+                self.q_norm_x = RMSNorm(self.head_dim)
+                self.k_norm_x = RMSNorm(self.head_dim)
+                self.q_norm_y = RMSNorm(self.head_dim)
+                self.k_norm_y = RMSNorm(self.head_dim)
+            else:
+                self.q_norm_x = RMSNorm(self.head_dim, device=device)
+                self.k_norm_x = RMSNorm(self.head_dim, device=device)
+                self.q_norm_y = RMSNorm(self.head_dim, device=device)
+                self.k_norm_y = RMSNorm(self.head_dim, device=device)
 
         # Output layers. y features go back down from dim_x -> dim_y.
         self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device)