mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 21:04:41 +08:00
ApplyRifleXRoPE_WanVideo
experimental
This commit is contained in:
parent
1a4259f052
commit
97d20e27e5
@ -182,6 +182,7 @@ NODE_CONFIG = {
|
||||
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
|
||||
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
||||
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
||||
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
||||
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
|
||||
@ -2494,6 +2494,42 @@ cfg input can be a list of floats matching step count, or a single float for all
|
||||
return (guider, )
|
||||
|
||||
|
||||
class ApplyRifleXRoPE_WanVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
|
||||
"k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"
|
||||
|
||||
def patch(self, model, latent, k):
|
||||
model_class = model.model.diffusion_model
|
||||
|
||||
model_clone = model.clone()
|
||||
num_frames = latent["samples"].shape[2]
|
||||
d = model_class.dim // model_class.num_heads
|
||||
|
||||
rope_embedder = EmbedND_RifleX(
|
||||
d,
|
||||
10000.0,
|
||||
[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)],
|
||||
num_frames,
|
||||
k
|
||||
)
|
||||
|
||||
model_clone.add_object_patch(f"diffusion_model.rope_embedder", rope_embedder)
|
||||
|
||||
return (model_clone, )
|
||||
|
||||
class ApplyRifleXRoPE_HunuyanVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -2513,7 +2549,6 @@ class ApplyRifleXRoPE_HunuyanVideo:
|
||||
|
||||
def patch(self, model, latent, k):
|
||||
model_class = model.model.diffusion_model
|
||||
print(model_class.pe_embedder)
|
||||
|
||||
model_clone = model.clone()
|
||||
num_frames = latent["samples"].shape[2]
|
||||
@ -2549,6 +2584,23 @@ def rope_riflex(pos, dim, theta, L_test, k):
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
class EmbedND_RifleX(nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.num_frames = num_frames
|
||||
self.k = k
|
||||
|
||||
def forward(self, ids):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
class EmbedND_RifleX(nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
||||
super().__init__()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user