mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-16 20:04:22 +08:00
Only apply RifleX to the temporal dimension
This commit is contained in:
parent
97d20e27e5
commit
dc482957d8
@ -2577,7 +2577,8 @@ def rope_riflex(pos, dim, theta, L_test, k):
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
# RIFLEX modification - adjust last frequency component if L_test and k are provided
|
||||
omega[k-1] = 0.9 * 2 * torch.pi / L_test
|
||||
if k and L_test:
|
||||
omega[k-1] = 0.9 * 2 * torch.pi / L_test
|
||||
|
||||
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
@ -2596,24 +2597,7 @@ class EmbedND_RifleX(nn.Module):
|
||||
def forward(self, ids):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
||||
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k if i == 0 else 0) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
class EmbedND_RifleX(nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.num_frames = num_frames
|
||||
self.k = k
|
||||
|
||||
def forward(self, ids):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(1)
|
||||
Loading…
x
Reference in New Issue
Block a user