Experimental ApplyRifleXRoPE_HunuyanVideo

This commit is contained in:
kijai 2025-02-24 18:46:20 +02:00
parent 8f3cc622a8
commit 9273d23916
2 changed files with 78 additions and 1 deletions

View File

@ -180,6 +180,7 @@ NODE_CONFIG = {
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"}, "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"}, "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"}, "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
#instance diffusion #instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -1,4 +1,5 @@
import torch import torch
import torch.nn as nn
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from typing import Union from typing import Union
@ -2489,3 +2490,78 @@ cfg input can be a list of floats matching step count, or a single float for all
guider.set_conds(positive, negative) guider.set_conds(positive, negative)
guider.set_cfg(cfg, start_percent, end_percent) guider.set_cfg(cfg, start_percent, end_percent)
return (guider, ) return (guider, )
class ApplyRifleXRoPE_HunuyanVideo:
    """ComfyUI node that replaces a model's positional embedder with a RIFLEx one.

    The node clones the incoming model and registers an object patch for
    ``diffusion_model.pe_embedder`` that points at an ``EmbedND_RifleX``
    instance built from the model's own ``params`` (hidden size, head count,
    theta, axes_dim), the frame count read from the latent, and ``k``.
    The input model itself is not modified; only the clone carries the patch.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "latent": ("LATENT", ),
                "k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True
    DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"

    def patch(self, model, latent, k):
        """Return a 1-tuple with a clone of ``model`` using the RIFLEx embedder.

        Args:
            model: Model wrapper exposing ``model.diffusion_model`` (with a
                ``params`` object), ``clone()`` and ``add_object_patch()``.
            latent: Dict whose ``"samples"`` entry has at least three
                dimensions; only ``shape[2]`` is read and used as the frame
                count (presumably the temporal length of a video latent --
                confirm against the caller).
            k: 1-based index of the rotary frequency that RIFLEx overrides.

        Raises:
            ValueError: If ``k`` is larger than the number of rotary
                frequencies available on the first axis.
        """
        diffusion_model = model.model.diffusion_model
        params = diffusion_model.params

        # rope_riflex builds axes_dim[i] // 2 frequencies per axis and writes
        # to index k - 1, and axis 0 is always processed. A larger k can never
        # work, so fail here with a readable message instead of an IndexError
        # in the middle of sampling.
        max_k = params.axes_dim[0] // 2
        if k > max_k:
            raise ValueError(
                f"k={k} is out of range: the first rotary axis only has "
                f"{max_k} frequencies (axes_dim[0]={params.axes_dim[0]})"
            )

        model_clone = model.clone()
        num_frames = latent["samples"].shape[2]

        pe_embedder = EmbedND_RifleX(
            params.hidden_size // params.num_heads,
            params.theta,
            params.axes_dim,
            num_frames,
            k,
        )

        model_clone.add_object_patch("diffusion_model.pe_embedder", pe_embedder)
        return (model_clone, )
def rope_riflex(pos, dim, theta, L_test, k):
    """Build rotary position embedding matrices with an optional RIFLEx override.

    The frequency vector is ``1 / theta ** linspace(0, (dim - 2) / dim, dim // 2)``.
    When both ``k`` and ``L_test`` are positive, the ``k``-th frequency
    (1-based) is replaced by ``0.9 * 2 * pi / L_test``, i.e. its period becomes
    ``L_test / 0.9`` positions. The 0.9 factor comes from the RIFLEx method
    referenced in the node description.

    Args:
        pos: Position tensor of shape ``(b, n)`` (the final rearrange pattern
            requires exactly these two dimensions).
        dim: Rotary dimension for this axis; must be even.
        theta: Base of the frequency spectrum.
        L_test: Sequence length used for the overridden frequency. A falsy
            value (``0`` / ``None``) disables the override.
        k: 1-based index of the frequency to override. ``0``, ``None`` or a
            negative value disables the override.

    Returns:
        float32 tensor of shape ``(b, n, dim // 2, 2, 2)`` on ``pos.device``
        holding ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.

    Raises:
        ValueError: If ``dim`` is odd, or if the override is enabled and ``k``
            exceeds ``dim // 2``.
    """
    # Validate with real exceptions: ``assert`` disappears under ``python -O``.
    if dim % 2 != 0:
        raise ValueError(f"rope_riflex requires an even dim, got {dim}")

    num_freqs = dim // 2
    apply_riflex = bool(L_test) and k is not None and k > 0
    if apply_riflex and k > num_freqs:
        raise ValueError(
            f"k={k} is out of range for dim={dim}: only {num_freqs} frequencies exist"
        )

    from einops import rearrange

    # On these backends the frequencies are computed on the CPU and the result
    # is moved back to pos.device at the end (float64 is used below and is
    # presumably unsupported there -- confirm).
    if model_management.is_device_mps(pos.device) or model_management.is_intel_xpu() or model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    scale = torch.linspace(0, (dim - 2) / dim, steps=num_freqs, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)

    # RIFLEx modification: override the k-th frequency, only when requested.
    if apply_riflex:
        omega[k - 1] = 0.9 * 2 * torch.pi / L_test

    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)


class EmbedND_RifleX(nn.Module):
    """N-dimensional rotary embedder that applies RIFLEx to the first axis only.

    Intended as a replacement for a model's ``pe_embedder``: ``forward``
    takes position ids whose last dimension enumerates the axes and returns
    the per-axis rotary matrices concatenated along the frequency dimension.
    """

    def __init__(self, dim, theta, axes_dim, num_frames, k):
        """Store the embedding configuration.

        Args:
            dim: Per-head dimension; stored but not read by ``forward``.
            theta: Frequency base forwarded to ``rope_riflex``.
            axes_dim: Sequence with the rotary dimension of each axis.
            num_frames: Frame count, forwarded as ``L_test`` for axis 0.
            k: 1-based frequency index overridden on axis 0.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim
        self.num_frames = num_frames
        self.k = k

    def forward(self, ids):
        """Return rotary embeddings for ``ids`` of shape ``(b, n, n_axes)``.

        The result has shape ``(b, 1, n, sum(axes_dim) // 2, 2, 2)``.
        """
        n_axes = ids.shape[-1]
        # RIFLEx is a temporal extrapolation fix, so only axis 0 gets the
        # frequency override (assumes axis 0 of the ids is time, as in the
        # HunyuanVideo layout -- confirm). Passing k=0 for the other axes
        # leaves their frequencies as plain RoPE.
        emb = torch.cat(
            [
                rope_riflex(
                    ids[..., i],
                    self.axes_dim[i],
                    self.theta,
                    self.num_frames,
                    self.k if i == 0 else 0,
                )
                for i in range(n_axes)
            ],
            dim=-3,
        )
        return emb.unsqueeze(1)