Only apply RifleX to the temporal dimension

This commit is contained in:
blepping 2025-02-28 03:44:03 -07:00 committed by GitHub
parent 97d20e27e5
commit dc482957d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2577,7 +2577,8 @@ def rope_riflex(pos, dim, theta, L_test, k):
omega = 1.0 / (theta**scale)
# RIFLEX modification - adjust last frequency component if L_test and k are provided
omega[k-1] = 0.9 * 2 * torch.pi / L_test
if k and L_test:
omega[k-1] = 0.9 * 2 * torch.pi / L_test
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
@ -2596,24 +2597,7 @@ class EmbedND_RifleX(nn.Module):
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat(
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k if i == 0 else 0) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
class EmbedND_RifleX(nn.Module):
    """Multi-axis rotary position embedding with the RIFLEx frequency tweak.

    One rotary embedding is computed per positional axis with
    ``rope_riflex`` and the per-axis results are concatenated along
    dimension ``-3``.

    The RIFLEx adjustment is requested for axis 0 only (the temporal axis);
    every other axis is given ``k=0``, which ``rope_riflex`` treats as
    "no adjustment" through its ``if k and L_test`` guard.

    Args:
        dim: Stored on the module; not read by ``forward``.
        theta: Base passed to ``rope_riflex`` for the frequency schedule.
        axes_dim: Per-axis embedding sizes, indexed by axis number.
        num_frames: Forwarded to ``rope_riflex`` as ``L_test``.
        k: Index (1-based, see ``omega[k-1]`` in ``rope_riflex``) of the
            frequency component to adjust on axis 0.  A falsy value
            disables the adjustment entirely.
    """

    def __init__(self, dim, theta, axes_dim, num_frames, k):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim
        self.num_frames = num_frames
        self.k = k

    def forward(self, ids):
        """Return the concatenated per-axis embeddings for ``ids``.

        Args:
            ids: Position ids whose last dimension enumerates the axes.

        Returns:
            The per-axis ``rope_riflex`` outputs concatenated on dim ``-3``,
            with a singleton dimension inserted at index 1.
        """
        n_axes = ids.shape[-1]
        per_axis = [
            rope_riflex(
                ids[..., axis],
                self.axes_dim[axis],
                self.theta,
                self.num_frames,
                # Only axis 0 (temporal) gets the RIFLEx adjustment; passing
                # 0 keeps the other axes on the unmodified frequencies.
                self.k if axis == 0 else 0,
            )
            for axis in range(n_axes)
        ]
        emb = torch.cat(per_axis, dim=-3)
        return emb.unsqueeze(1)