ApplyRifleXRoPE_WanVideo

experimental
kijai 2025-02-27 12:55:53 +02:00
parent 1a4259f052
commit 97d20e27e5
2 changed files with 54 additions and 1 deletion


@@ -182,6 +182,7 @@ NODE_CONFIG = {
     "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
     "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
     "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
+    "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},


@@ -2494,6 +2494,42 @@ cfg input can be a list of floats matching step count, or a single float for all
         return (guider, )
+
+class ApplyRifleXRoPE_WanVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
+                "k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "KJNodes/experimental"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Extends the potential frame count of WanVideo using this method: https://github.com/thu-ml/RIFLEx"
+
+    def patch(self, model, latent, k):
+        model_class = model.model.diffusion_model
+        model_clone = model.clone()
+
+        num_frames = latent["samples"].shape[2]
+        d = model_class.dim // model_class.num_heads
+        rope_embedder = EmbedND_RifleX(
+            d,
+            10000.0,
+            [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)],
+            num_frames,
+            k
+        )
+        model_clone.add_object_patch(f"diffusion_model.rope_embedder", rope_embedder)
+        return (model_clone, )
+
 class ApplyRifleXRoPE_HunuyanVideo:
     @classmethod
     def INPUT_TYPES(s):
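
The new node's patch() method derives everything from its inputs: the RoPE head dimension d from the model, the per-axis channel split, and the frame count from the LATENT's temporal axis. A hypothetical worked example below assumes a Wan 2.1 model with head dimension 128 and a 4x temporally compressed latent; both numbers are assumptions for illustration, not values taken from this commit.

# Hypothetical numbers for illustration; head dim 128 and 4x temporal
# compression are assumptions about Wan 2.1, not part of this diff.
d = 128                                                    # model dim // num_heads
axes_dim = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]  # temporal, height, width channels
print(axes_dim, sum(axes_dim))                             # [44, 42, 42] 128

pixel_frames = 81
latent_frames = (pixel_frames - 1) // 4 + 1                # latent["samples"].shape[2]
print(latent_frames)                                       # 21 -> num_frames passed to EmbedND_RifleX

With the default k = 4, the fourth temporal frequency component is treated as the intrinsic one that RIFLEx adjusts.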
@@ -2513,7 +2549,6 @@ class ApplyRifleXRoPE_HunuyanVideo:
     def patch(self, model, latent, k):
         model_class = model.model.diffusion_model
-        print(model_class.pe_embedder)
         model_clone = model.clone()
         num_frames = latent["samples"].shape[2]
@@ -2565,3 +2600,20 @@ class EmbedND_RifleX(nn.Module):
             dim=-3,
         )
         return emb.unsqueeze(1)
+
+class EmbedND_RifleX(nn.Module):
+    def __init__(self, dim, theta, axes_dim, num_frames, k):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.num_frames = num_frames
+        self.k = k
+
+    def forward(self, ids):
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
+            dim=-3,
+        )
+        return emb.unsqueeze(1)
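
The embedder added here mirrors the class already used by the HunyuanVideo node and delegates the per-axis work to rope_riflex, which is not part of this diff (it already exists in nodes.py). Below is a minimal sketch of such a function, following the RIFLEx method linked in the node description; treat it as an illustration rather than the repository's exact implementation, and note that the 0.9 safety factor is an assumption taken from the RIFLEx reference code.

import torch
from einops import rearrange

def rope_riflex(pos, dim, theta, num_frames, k):
    # Standard RoPE frequency ladder: omega_i = theta^(-2i/dim), i = 0 .. dim/2 - 1
    assert dim % 2 == 0
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=pos.device)
    omega = 1.0 / (theta ** scale)
    # RIFLEx: lower the k-th ("intrinsic") frequency so it completes less than one
    # full period over the extended frame count, suppressing temporal repetition
    # when sampling more frames than the model was trained on.
    # (0.9 factor assumed from the RIFLEx reference code.)
    if k is not None and num_frames is not None:
        omega[k - 1] = 0.9 * 2 * torch.pi / num_frames
    out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(torch.float32)

The only departure from standard RoPE is the single overwritten frequency; everything else is the usual theta-based ladder, which is why the node can swap in the embedder via add_object_patch without touching any model weights.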