From 97d20e27e589854451a9d1f091f6524e947d6229 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Thu, 27 Feb 2025 12:55:53 +0200
Subject: [PATCH] ApplyRifleXRoPE_WanVideo experimental

---
 __init__.py    |  1 +
 nodes/nodes.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/__init__.py b/__init__.py
index f78343a..b9a875f 100644
--- a/__init__.py
+++ b/__init__.py
@@ -182,6 +182,7 @@ NODE_CONFIG = {
     "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
     "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
     "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
+    "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},

     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/nodes.py b/nodes/nodes.py
index a90be48..c206393 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -2494,6 +2494,42 @@ cfg input can be a list of floats matching step count, or a single float for all

         return (guider, )

+class ApplyRifleXRoPE_WanVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "latent": ("LATENT", {"tooltip": "Only used to get the latent frame count"}),
+                "k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "KJNodes/experimental"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Extends the potential frame count of WanVideo using this method: https://github.com/thu-ml/RIFLEx"
+
+    def patch(self, model, latent, k):
+        model_class = model.model.diffusion_model
+
+        model_clone = model.clone()
+        num_frames = latent["samples"].shape[2]
+        d = model_class.dim // model_class.num_heads
+
+        rope_embedder = EmbedND_RifleX(
+            d,
+            10000.0,
+            [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)],
+            num_frames,
+            k
+        )
+
+        model_clone.add_object_patch("diffusion_model.rope_embedder", rope_embedder)
+
+        return (model_clone, )
+
 class ApplyRifleXRoPE_HunuyanVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -2513,7 +2549,6 @@ class ApplyRifleXRoPE_HunuyanVideo:

     def patch(self, model, latent, k):
         model_class = model.model.diffusion_model
-        print(model_class.pe_embedder)

         model_clone = model.clone()
         num_frames = latent["samples"].shape[2]
@@ -2549,6 +2584,23 @@ def rope_riflex(pos, dim, theta, L_test, k):
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
     return out.to(dtype=torch.float32, device=pos.device)

+class EmbedND_RifleX(nn.Module):
+    def __init__(self, dim, theta, axes_dim, num_frames, k):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.num_frames = num_frames
+        self.k = k
+
+    def forward(self, ids):
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
+            dim=-3,
+        )
+        return emb.unsqueeze(1)
+
 class EmbedND_RifleX(nn.Module):
     def __init__(self, dim, theta, axes_dim, num_frames, k):
         super().__init__()
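
Background on the two pieces the patch wires together:

The axes_dim list [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] splits the
per-head rotary channels across the temporal, height, and width axes, with
the remainder assigned to time. Assuming the published Wan 2.1 configs (head
dim d = 128, e.g. dim 5120 / 40 heads for the 14B model), d // 6 = 21, so the
split is [128 - 84, 42, 42] = [44, 42, 42], which sums back to 128.

Only the tail of rope_riflex is visible in the last hunk. Below is a minimal
sketch of the adjustment it is expected to apply, assuming the implementation
follows https://github.com/thu-ml/RIFLEx: the k-th temporal frequency (the
"intrinsic frequency") is capped so it completes less than one full period
over the test length, suppressing the repetition that plain RoPE
extrapolation produces. The function name, the 0.9 safety factor, and the
omitted device handling are illustrative, not the repo's exact code.

    # Hedged sketch of RIFLEx-modified RoPE; rope_riflex_sketch is a
    # hypothetical name, and the real rope_riflex also handles device
    # quirks (MPS/XPU/DirectML) and dtype casts.
    import torch
    from einops import rearrange

    def rope_riflex_sketch(pos, dim, theta, L_test, k):
        # Standard RoPE frequency ladder: omega[i] = theta ** -(2i / dim)
        assert dim % 2 == 0
        scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64)
        omega = 1.0 / (theta ** scale)

        # RIFLEx: pin the k-th (intrinsic) frequency below one full period
        # across L_test positions; the 0.9 factor (assumed here) leaves
        # headroom so the phase never wraps and the video does not loop.
        if k and L_test:
            omega[k - 1] = 0.9 * 2 * torch.pi / L_test

        # Build 2x2 rotation matrices [[cos, -sin], [sin, cos]] per
        # position/frequency pair, matching the rearrange in the hunk above.
        out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
        out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
        out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
        return out.float()

In a workflow the node sits between the model loader and the sampler; the
LATENT input is only read for its temporal size (samples.shape[2]), which
becomes the L_test that the capped frequency is tuned against.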