Experimental ApplyRifleXRoPE_HunuyanVideo

This commit is contained in:
kijai 2025-02-24 18:46:20 +02:00
parent 8f3cc622a8
commit 9273d23916
2 changed files with 78 additions and 1 deletions

View File

@ -180,6 +180,7 @@ NODE_CONFIG = {
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -1,4 +1,5 @@
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from typing import Union
@ -2488,4 +2489,79 @@ cfg input can be a list of floats matching step count, or a single float for all
guider = Guider_ScheduledCFG(model)
guider.set_conds(positive, negative)
guider.set_cfg(cfg, start_percent, end_percent)
return (guider, )
return (guider, )
class ApplyRifleXRoPE_HunuyanVideo:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"latent": ("LATENT", ),
"k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"
def patch(self, model, latent, k):
model_class = model.model.diffusion_model
print(model_class.pe_embedder)
model_clone = model.clone()
num_frames = latent["samples"].shape[2]
pe_embedder = EmbedND_RifleX(
model_class.params.hidden_size // model_class.params.num_heads,
model_class.params.theta,
model_class.params.axes_dim,
num_frames,
k
)
model_clone.add_object_patch(f"diffusion_model.pe_embedder", pe_embedder)
return (model_clone, )
def rope_riflex(pos, dim, theta, L_test, k):
from einops import rearrange
assert dim % 2 == 0
if model_management.is_device_mps(pos.device) or model_management.is_intel_xpu() or model_management.is_directml_enabled():
device = torch.device("cpu")
else:
device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
# RIFLEX modification - adjust last frequency component if L_test and k are provided
omega[k-1] = 0.9 * 2 * torch.pi / L_test
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
print("rope: ", out)
return out.to(dtype=torch.float32, device=pos.device)
class EmbedND_RifleX(nn.Module):
def __init__(self, dim, theta, axes_dim, num_frames, k):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
self.num_frames = num_frames
self.k = k
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat(
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
dim=-3,
)
print("emb: ", emb)
return emb.unsqueeze(1)