mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Experimental ApplyRifleXRoPE_HunuyanVideo
This commit is contained in:
parent
8f3cc622a8
commit
9273d23916
@ -180,6 +180,7 @@ NODE_CONFIG = {
|
||||
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
|
||||
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
|
||||
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
||||
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
||||
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
@ -2488,4 +2489,79 @@ cfg input can be a list of floats matching step count, or a single float for all
|
||||
guider = Guider_ScheduledCFG(model)
|
||||
guider.set_conds(positive, negative)
|
||||
guider.set_cfg(cfg, start_percent, end_percent)
|
||||
return (guider, )
|
||||
return (guider, )
|
||||
|
||||
|
||||
class ApplyRifleXRoPE_HunuyanVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"latent": ("LATENT", ),
|
||||
"k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"
|
||||
|
||||
def patch(self, model, latent, k):
|
||||
model_class = model.model.diffusion_model
|
||||
print(model_class.pe_embedder)
|
||||
|
||||
model_clone = model.clone()
|
||||
num_frames = latent["samples"].shape[2]
|
||||
|
||||
pe_embedder = EmbedND_RifleX(
|
||||
model_class.params.hidden_size // model_class.params.num_heads,
|
||||
model_class.params.theta,
|
||||
model_class.params.axes_dim,
|
||||
num_frames,
|
||||
k
|
||||
)
|
||||
|
||||
model_clone.add_object_patch(f"diffusion_model.pe_embedder", pe_embedder)
|
||||
|
||||
return (model_clone, )
|
||||
|
||||
def rope_riflex(pos, dim, theta, L_test, k):
|
||||
from einops import rearrange
|
||||
assert dim % 2 == 0
|
||||
if model_management.is_device_mps(pos.device) or model_management.is_intel_xpu() or model_management.is_directml_enabled():
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = pos.device
|
||||
|
||||
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
# RIFLEX modification - adjust last frequency component if L_test and k are provided
|
||||
omega[k-1] = 0.9 * 2 * torch.pi / L_test
|
||||
|
||||
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
print("rope: ", out)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
class EmbedND_RifleX(nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.num_frames = num_frames
|
||||
self.k = k
|
||||
|
||||
def forward(self, ids):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
print("emb: ", emb)
|
||||
return emb.unsqueeze(1)
|
||||
Loading…
x
Reference in New Issue
Block a user