mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 12:54:40 +08:00
ApplyRifleXRoPE_WanVideo
experimental
This commit is contained in:
parent
1a4259f052
commit
97d20e27e5
@ -182,6 +182,7 @@ NODE_CONFIG = {
|
|||||||
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
|
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
|
||||||
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
||||||
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
||||||
|
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
||||||
|
|
||||||
#instance diffusion
|
#instance diffusion
|
||||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||||
|
|||||||
@ -2494,6 +2494,42 @@ cfg input can be a list of floats matching step count, or a single float for all
|
|||||||
return (guider, )
|
return (guider, )
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyRifleXRoPE_WanVideo:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL",),
|
||||||
|
"latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
|
||||||
|
"k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
CATEGORY = "KJNodes/experimental"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"
|
||||||
|
|
||||||
|
def patch(self, model, latent, k):
|
||||||
|
model_class = model.model.diffusion_model
|
||||||
|
|
||||||
|
model_clone = model.clone()
|
||||||
|
num_frames = latent["samples"].shape[2]
|
||||||
|
d = model_class.dim // model_class.num_heads
|
||||||
|
|
||||||
|
rope_embedder = EmbedND_RifleX(
|
||||||
|
d,
|
||||||
|
10000.0,
|
||||||
|
[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)],
|
||||||
|
num_frames,
|
||||||
|
k
|
||||||
|
)
|
||||||
|
|
||||||
|
model_clone.add_object_patch(f"diffusion_model.rope_embedder", rope_embedder)
|
||||||
|
|
||||||
|
return (model_clone, )
|
||||||
|
|
||||||
class ApplyRifleXRoPE_HunuyanVideo:
|
class ApplyRifleXRoPE_HunuyanVideo:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -2513,7 +2549,6 @@ class ApplyRifleXRoPE_HunuyanVideo:
|
|||||||
|
|
||||||
def patch(self, model, latent, k):
|
def patch(self, model, latent, k):
|
||||||
model_class = model.model.diffusion_model
|
model_class = model.model.diffusion_model
|
||||||
print(model_class.pe_embedder)
|
|
||||||
|
|
||||||
model_clone = model.clone()
|
model_clone = model.clone()
|
||||||
num_frames = latent["samples"].shape[2]
|
num_frames = latent["samples"].shape[2]
|
||||||
@ -2549,6 +2584,23 @@ def rope_riflex(pos, dim, theta, L_test, k):
|
|||||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||||
return out.to(dtype=torch.float32, device=pos.device)
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
class EmbedND_RifleX(nn.Module):
|
||||||
|
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
def forward(self, ids):
|
||||||
|
n_axes = ids.shape[-1]
|
||||||
|
emb = torch.cat(
|
||||||
|
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k) for i in range(n_axes)],
|
||||||
|
dim=-3,
|
||||||
|
)
|
||||||
|
return emb.unsqueeze(1)
|
||||||
|
|
||||||
class EmbedND_RifleX(nn.Module):
|
class EmbedND_RifleX(nn.Module):
|
||||||
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
def __init__(self, dim, theta, axes_dim, num_frames, k):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user