tweak
This commit is contained in:
parent
2e22529c99
commit
e66735527c
@ -123,7 +123,7 @@ class AsymmetricAttention(nn.Module):
|
|||||||
qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
|
qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
|
||||||
q_y, k_y, v_y = qkv_y.unbind(2)
|
q_y, k_y, v_y = qkv_y.unbind(2)
|
||||||
return q_y, k_y, v_y
|
return q_y, k_y, v_y
|
||||||
|
|
||||||
def prepare_qkv(
|
def prepare_qkv(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor, # (B, N, dim_x)
|
x: torch.Tensor, # (B, N, dim_x)
|
||||||
|
|||||||
@ -18,6 +18,5 @@ class ModulatedRMSNorm(torch.autograd.Function):
|
|||||||
|
|
||||||
return x_modulated.type_as(x)
|
return x_modulated.type_as(x)
|
||||||
|
|
||||||
@torch.compiler.disable()
|
|
||||||
def modulated_rmsnorm(x, scale, eps=1e-6):
|
def modulated_rmsnorm(x, scale, eps=1e-6):
|
||||||
return ModulatedRMSNorm.apply(x, scale, eps)
|
return ModulatedRMSNorm.apply(x, scale, eps)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user