This commit is contained in:
kijai 2024-10-25 20:04:01 +03:00
parent 2e22529c99
commit e66735527c
2 changed files with 1 additions and 2 deletions

View File

@ -123,7 +123,7 @@ class AsymmetricAttention(nn.Module):
qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
q_y, k_y, v_y = qkv_y.unbind(2)
return q_y, k_y, v_y
def prepare_qkv(
self,
x: torch.Tensor, # (B, N, dim_x)

View File

@ -18,6 +18,5 @@ class ModulatedRMSNorm(torch.autograd.Function):
return x_modulated.type_as(x)
@torch.compiler.disable()
def modulated_rmsnorm(x, scale, eps=1e-6):
return ModulatedRMSNorm.apply(x, scale, eps)