From e66735527c489391f1426f5b78e64a38ccd77a95 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 25 Oct 2024 20:04:01 +0300
Subject: [PATCH] tweak

---
 mochi_preview/dit/joint_model/asymm_models_joint.py | 2 +-
 mochi_preview/dit/joint_model/mod_rmsnorm.py        | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 1e4c207..68b49b0 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -123,7 +123,7 @@ class AsymmetricAttention(nn.Module):
         qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
         q_y, k_y, v_y = qkv_y.unbind(2)
         return q_y, k_y, v_y
-    
+
     def prepare_qkv(
         self,
         x: torch.Tensor,  # (B, N, dim_x)
diff --git a/mochi_preview/dit/joint_model/mod_rmsnorm.py b/mochi_preview/dit/joint_model/mod_rmsnorm.py
index b3e317c..d6ba726 100644
--- a/mochi_preview/dit/joint_model/mod_rmsnorm.py
+++ b/mochi_preview/dit/joint_model/mod_rmsnorm.py
@@ -18,6 +18,5 @@ class ModulatedRMSNorm(torch.autograd.Function):
         return x_modulated.type_as(x)
 
 
-@torch.compiler.disable()
 def modulated_rmsnorm(x, scale, eps=1e-6):
     return ModulatedRMSNorm.apply(x, scale, eps)
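
For reference, the patch shows only the tail of ModulatedRMSNorm (the final x_modulated.type_as(x) cast), not its forward body. Below is a minimal sketch of what a modulated RMSNorm of this shape typically computes; the function name modulated_rmsnorm_sketch and the (B, N, D) / (B, D) shapes are illustrative assumptions inferred from the diff context, not the repository's actual implementation.

import torch

def modulated_rmsnorm_sketch(x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Assumed shapes: x is (B, N, D), scale is a per-sample (B, D) modulation.
    x_fp32 = x.float()
    # RMS-normalize over the feature (last) dimension.
    x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    # Modulate with (1 + scale), broadcasting over the sequence dimension.
    x_modulated = x_normed * (1 + scale.float().unsqueeze(1))
    # Cast back to the input dtype, mirroring the type_as(x) line visible in the hunk.
    return x_modulated.type_as(x)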