From 9273d2391679b25bc37abb29c8f148226eecc50b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 24 Feb 2025 18:46:20 +0200 Subject: [PATCH] Experimental ApplyRifleXRoPE_HunuyanVideo --- __init__.py | 1 + nodes/nodes.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/__init__.py b/__init__.py index 98049ac..a9188a5 100644 --- a/__init__.py +++ b/__init__.py @@ -180,6 +180,7 @@ NODE_CONFIG = { "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"}, "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"}, "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"}, + "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 70a9017..fa01fc8 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1,4 +1,5 @@ import torch +import torch.nn as nn import numpy as np from PIL import Image from typing import Union @@ -2488,4 +2489,79 @@ cfg input can be a list of floats matching step count, or a single float for all guider = Guider_ScheduledCFG(model) guider.set_conds(positive, negative) guider.set_cfg(cfg, start_percent, end_percent) - return (guider, ) \ No newline at end of file + return (guider, ) + + +class ApplyRifleXRoPE_HunuyanVideo: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "latent": ("LATENT", ), + "k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1}), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "KJNodes/experimental" + EXPERIMENTAL = True + DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx" + + def patch(self, model, latent, k): + model_class = model.model.diffusion_model + print(model_class.pe_embedder) + + model_clone = model.clone() + num_frames = latent["samples"].shape[2] + + pe_embedder = EmbedND_RifleX( + model_class.params.hidden_size // model_class.params.num_heads, + model_class.params.theta, + model_class.params.axes_dim, + num_frames, + k + ) + + model_clone.add_object_patch(f"diffusion_model.pe_embedder", pe_embedder) + + return (model_clone, ) + +def rope_riflex(pos, dim, theta, L_test, k): + from einops import rearrange + assert dim % 2 == 0 + if model_management.is_device_mps(pos.device) or model_management.is_intel_xpu() or model_management.is_directml_enabled(): + device = torch.device("cpu") + else: + device = pos.device + + scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device) + omega = 1.0 / (theta**scale) + + # RIFLEX modification - adjust last frequency component if L_test and k are provided + omega[k-1] = 0.9 * 2 * torch.pi / L_test + + out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + print("rope: ", out) + return out.to(dtype=torch.float32, device=pos.device) + +class EmbedND_RifleX(nn.Module): + def __init__(self, dim, theta, axes_dim, num_frames, k): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + self.num_frames = num_frames + self.k = k + + def forward(self, ids): + n_axes = ids.shape[-1] + emb = torch.cat( + [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)], + dim=-3, + ) + print("emb: ", emb) + return emb.unsqueeze(1) \ No newline at end of file