Update nodes.py

This commit is contained in:
kijai 2025-02-24 18:58:04 +02:00
parent 9273d23916
commit 69ec71d2bd

View File

@ -2498,8 +2498,8 @@ class ApplyRifleXRoPE_HunuyanVideo:
return {
"required": {
"model": ("MODEL",),
"latent": ("LATENT", ),
"k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1}),
"latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
"k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
}
}
@ -2545,7 +2545,6 @@ def rope_riflex(pos, dim, theta, L_test, k):
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
print("rope: ", out)
return out.to(dtype=torch.float32, device=pos.device)
class EmbedND_RifleX(nn.Module):
@ -2563,5 +2562,4 @@ class EmbedND_RifleX(nn.Module):
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
dim=-3,
)
print("emb: ", emb)
return emb.unsqueeze(1)