mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-06-01 17:07:07 +08:00
Experimental ApplyRifleXRoPE_HunuyanVideo
This commit is contained in:
parent
8f3cc622a8
commit
9273d23916
@ -180,6 +180,7 @@ NODE_CONFIG = {
|
|||||||
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
|
"LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
|
||||||
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
|
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
|
||||||
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
||||||
|
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
||||||
|
|
||||||
#instance diffusion
|
#instance diffusion
|
||||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@ -2488,4 +2489,79 @@ cfg input can be a list of floats matching step count, or a single float for all
|
|||||||
guider = Guider_ScheduledCFG(model)
|
guider = Guider_ScheduledCFG(model)
|
||||||
guider.set_conds(positive, negative)
|
guider.set_conds(positive, negative)
|
||||||
guider.set_cfg(cfg, start_percent, end_percent)
|
guider.set_cfg(cfg, start_percent, end_percent)
|
||||||
return (guider, )
|
return (guider, )
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyRifleXRoPE_HunuyanVideo:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL",),
|
||||||
|
"latent": ("LATENT", ),
|
||||||
|
"k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
CATEGORY = "KJNodes/experimental"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"
|
||||||
|
|
||||||
|
def patch(self, model, latent, k):
|
||||||
|
model_class = model.model.diffusion_model
|
||||||
|
print(model_class.pe_embedder)
|
||||||
|
|
||||||
|
model_clone = model.clone()
|
||||||
|
num_frames = latent["samples"].shape[2]
|
||||||
|
|
||||||
|
pe_embedder = EmbedND_RifleX(
|
||||||
|
model_class.params.hidden_size // model_class.params.num_heads,
|
||||||
|
model_class.params.theta,
|
||||||
|
model_class.params.axes_dim,
|
||||||
|
num_frames,
|
||||||
|
k
|
||||||
|
)
|
||||||
|
|
||||||
|
model_clone.add_object_patch(f"diffusion_model.pe_embedder", pe_embedder)
|
||||||
|
|
||||||
|
return (model_clone, )
|
||||||
|
|
||||||
|
def rope_riflex(pos, dim, theta, L_test, k):
|
||||||
|
from einops import rearrange
|
||||||
|
assert dim % 2 == 0
|
||||||
|
if model_management.is_device_mps(pos.device) or model_management.is_intel_xpu() or model_management.is_directml_enabled():
|
||||||
|
device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
device = pos.device
|
||||||
|
|
||||||
|
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
|
||||||
|
omega = 1.0 / (theta**scale)
|
||||||
|
|
||||||
|
# RIFLEX modification - adjust last frequency component if L_test and k are provided
|
||||||
|
omega[k-1] = 0.9 * 2 * torch.pi / L_test
|
||||||
|
|
||||||
|
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
|
||||||
|
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||||
|
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||||
|
print("rope: ", out)
|
||||||
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
class EmbedND_RifleX(nn.Module):
|
||||||
|
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
def forward(self, ids):
|
||||||
|
n_axes = ids.shape[-1]
|
||||||
|
emb = torch.cat(
|
||||||
|
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
||||||
|
dim=-3,
|
||||||
|
)
|
||||||
|
print("emb: ", emb)
|
||||||
|
return emb.unsqueeze(1)
|
||||||
Loading…
x
Reference in New Issue
Block a user