mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-27 06:57:52 +08:00
Merge pull request #201 from blepping/fix_riflex
Only apply RifleX to the temporal dimension
This commit is contained in:
commit
82272ef448
@ -2577,7 +2577,8 @@ def rope_riflex(pos, dim, theta, L_test, k):
|
|||||||
omega = 1.0 / (theta**scale)
|
omega = 1.0 / (theta**scale)
|
||||||
|
|
||||||
# RIFLEX modification - adjust last frequency component if L_test and k are provided
|
# RIFLEX modification - adjust last frequency component if L_test and k are provided
|
||||||
omega[k-1] = 0.9 * 2 * torch.pi / L_test
|
if k and L_test:
|
||||||
|
omega[k-1] = 0.9 * 2 * torch.pi / L_test
|
||||||
|
|
||||||
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
|
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
|
||||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||||
@ -2596,24 +2597,7 @@ class EmbedND_RifleX(nn.Module):
|
|||||||
def forward(self, ids):
|
def forward(self, ids):
|
||||||
n_axes = ids.shape[-1]
|
n_axes = ids.shape[-1]
|
||||||
emb = torch.cat(
|
emb = torch.cat(
|
||||||
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k if i == 0 else 0) for i in range(n_axes)],
|
||||||
dim=-3,
|
dim=-3,
|
||||||
)
|
)
|
||||||
return emb.unsqueeze(1)
|
return emb.unsqueeze(1)
|
||||||
|
|
||||||
class EmbedND_RifleX(nn.Module):
|
|
||||||
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.theta = theta
|
|
||||||
self.axes_dim = axes_dim
|
|
||||||
self.num_frames = num_frames
|
|
||||||
self.k = k
|
|
||||||
|
|
||||||
def forward(self, ids):
|
|
||||||
n_axes = ids.shape[-1]
|
|
||||||
emb = torch.cat(
|
|
||||||
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
|
||||||
dim=-3,
|
|
||||||
)
|
|
||||||
return emb.unsqueeze(1)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user