From 70bd02445619cc21daf7ddc02ec66cc7478e755a Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:40:35 +0200 Subject: [PATCH] Update asymm_models_joint.py --- .../dit/joint_model/asymm_models_joint.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index 65a4017..fde6041 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -169,16 +169,14 @@ class AsymmetricAttention(nn.Module): pass else: from .layers import RMSNorm - if rms_norm_func == "apex": - self.q_norm_x = RMSNorm(self.head_dim) - self.k_norm_x = RMSNorm(self.head_dim) - self.q_norm_y = RMSNorm(self.head_dim) - self.k_norm_y = RMSNorm(self.head_dim) - else: - self.q_norm_x = RMSNorm(self.head_dim, device=device) - self.k_norm_x = RMSNorm(self.head_dim, device=device) - self.q_norm_y = RMSNorm(self.head_dim, device=device) - self.k_norm_y = RMSNorm(self.head_dim, device=device) + norm_kwargs = {} + if rms_norm_func != "apex": + norm_kwargs['device'] = device + + self.q_norm_x = RMSNorm(self.head_dim, **norm_kwargs) + self.k_norm_x = RMSNorm(self.head_dim, **norm_kwargs) + self.q_norm_y = RMSNorm(self.head_dim, **norm_kwargs) + self.k_norm_y = RMSNorm(self.head_dim, **norm_kwargs) # Output layers. y features go back down from dim_x -> dim_y. self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device)