From dc482957d814a5a78000a3452b6c623a48fbd992 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Fri, 28 Feb 2025 03:44:03 -0700 Subject: [PATCH] Only apply RifleX to the temporal dimension --- nodes/nodes.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/nodes/nodes.py b/nodes/nodes.py index c206393..38cd54f 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2577,7 +2577,8 @@ def rope_riflex(pos, dim, theta, L_test, k): omega = 1.0 / (theta**scale) # RIFLEX modification - adjust last frequency component if L_test and k are provided - omega[k-1] = 0.9 * 2 * torch.pi / L_test + if k and L_test: + omega[k-1] = 0.9 * 2 * torch.pi / L_test out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega) out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) @@ -2596,24 +2597,7 @@ class EmbedND_RifleX(nn.Module): def forward(self, ids): n_axes = ids.shape[-1] emb = torch.cat( - [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)], + [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k if i == 0 else 0) for i in range(n_axes)], dim=-3, ) return emb.unsqueeze(1) - -class EmbedND_RifleX(nn.Module): - def __init__(self, dim, theta, axes_dim, num_frames, k): - super().__init__() - self.dim = dim - self.theta = theta - self.axes_dim = axes_dim - self.num_frames = num_frames - self.k = k - - def forward(self, ids): - n_axes = ids.shape[-1] - emb = torch.cat( - [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)], - dim=-3, - ) - return emb.unsqueeze(1) \ No newline at end of file