mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
- unify all pipelines into one - unify transformer model into one - separate VAE - add single file model loading
185 lines
6.3 KiB
Python
185 lines
6.3 KiB
Python
import numpy as np
|
|
from typing import Callable, Optional, List
|
|
|
|
|
|
def ordered_halving(val):
    """Return the bit-reversed form of ``val`` scaled down by 2**64.

    ``val`` is rendered as a zero-padded, 64-character binary string; the
    string is reversed and read back as an integer, which is then divided by
    2**64.  Consecutive inputs therefore land far apart from each other:
    0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, ...
    """
    reversed_bits = format(val, "064b")[::-1]
    return int(reversed_bits, 2) / 2 ** 64
|
|
|
|
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Tell whether the frame indexes in ``window`` wrap around.

    Every value is reduced modulo ``num_frames`` and compared with the
    reduced value that came before it.

    Returns:
        ``(True, i)`` where ``i`` is the first position whose reduced value is
        smaller than its predecessor's, or ``(False, -1)`` if no such
        position exists (including for an empty window).
    """
    previous = -1
    for position, frame in enumerate(window):
        wrapped = frame % num_frames
        if wrapped < previous:
            return True, position
        previous = wrapped
    return False, -1
|
|
|
|
def shift_window_to_start(window: list[int], num_frames: int):
    """Re-base ``window`` in place so that its first entry becomes 0.

    The window is mutated; nothing is returned.  An empty window raises
    ``IndexError`` because its first entry is read.
    """
    origin = window[0]
    # Subtracting the origin expresses each value relative to the first one;
    # adding num_frames before the modulus wraps the result back into
    # the 0..num_frames-1 range.
    window[:] = [((frame - origin) + num_frames) % num_frames for frame in window]
|
|
|
|
def shift_window_to_end(window: list[int], num_frames: int):
    """Slide ``window`` in place so that its last entry is ``num_frames - 1``.

    The window is mutated; nothing is returned.
    """
    # 1) re-base the window so that it begins at 0
    shift_window_to_start(window, num_frames)
    # 2) push every value up by the gap between the current last entry and
    #    the final frame index
    offset = (num_frames - 1) - window[-1]
    window[:] = [frame + offset for frame in window]
|
|
|
|
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return the frame indexes that no window contains.

    Args:
        windows: Lists of frame indexes.  Values outside
            ``range(num_frames)`` and repeated values are ignored.
        num_frames: Total number of frames; candidates are
            ``0 .. num_frames - 1``.

    Returns:
        The uncovered indexes in ascending order.
    """
    # Collect the covered indexes once so each candidate is a constant-time
    # membership test, rather than calling list.remove() per value and
    # swallowing ValueError for values that are absent or already removed.
    covered = {val for window in windows for val in window}
    return [idx for idx in range(num_frames) if idx not in covered]
|
|
|
|
def uniform_looped(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Lazily yield context windows whose indexes wrap around modulo ``num_frames``.

    Args:
        step: Current step index; fed to ``ordered_halving`` to offset where
            the windows start.
        num_steps: Accepted for signature parity with the other schedulers;
            not used here.
        num_frames: Total number of frames.
        context_size: Number of frame indexes per window.
        context_stride: Upper bound on how many power-of-two frame spacings
            (1, 2, 4, ...) are used; it is further capped from
            ``num_frames / context_size``.
        context_overlap: Frames shared between consecutive windows of the
            same spacing.
        closed_loop: When false, the range of window starts is shortened by
            ``context_overlap``.

    Yields:
        Lists of frame indexes.  When ``num_frames <= context_size`` a single
        window covering every frame is yielded.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    stride_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, stride_cap)

    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(step)
        pad = int(round(num_frames * fraction))

        first_start = int(fraction * context_step) + pad
        stop = num_frames + pad
        if not closed_loop:
            stop -= context_overlap
        span = context_size * context_step
        hop = span - context_overlap

        for start in range(first_start, stop, hop):
            yield [frame % num_frames for frame in range(start, start + span, context_step)]
|
|
|
|
#from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
|
|
def uniform_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Build uniform context windows, then un-wrap any that roll past the last frame.

    Windows are first generated with indexes wrapped modulo ``num_frames``.
    Each window that rolls over is then slid so it ends on the final frame, a
    filler window may be inserted after it, and windows equal to an earlier
    one are dropped.

    Args:
        step: Current step index; fed to ``ordered_halving`` to offset where
            the windows start.
        num_steps: Accepted for signature parity with the other schedulers;
            not used here.
        num_frames: Total number of frames.
        context_size: Number of frame indexes per window.
        context_stride: Upper bound on how many power-of-two frame spacings
            (1, 2, 4, ...) are used; it is further capped from
            ``num_frames / context_size``.
        context_overlap: Frames shared between consecutive windows of the
            same spacing.
        closed_loop: When false, the range of window starts is shortened by
            ``context_overlap``.

    Returns:
        A list of windows (lists of frame indexes).  When
        ``num_frames <= context_size`` it holds one window covering every
        frame.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    stride_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, stride_cap)

    windows = []
    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(step)
        pad = int(round(num_frames * fraction))

        first_start = int(fraction * context_step) + pad
        stop = num_frames + pad
        if not closed_loop:
            stop -= context_overlap
        span = context_size * context_step
        hop = span - context_overlap

        for start in range(first_start, stop, hop):
            windows.append([frame % num_frames for frame in range(start, start + span, context_step)])

    # With all windows created, shift the ones that loop and remember which
    # windows turn out to be duplicates.  The list can grow while it is being
    # walked, hence the explicit index.
    duplicate_positions = []
    idx = 0
    while idx < len(windows):
        current = windows[idx]
        # a window that rolls over itself has to be shifted
        rolled, roll_pos = does_window_roll_over(current, num_frames)
        if rolled:
            # roll_val might not be 0 for windows of higher strides
            roll_val = current[roll_pos]
            shift_window_to_end(current, num_frames=num_frames)
            # check whether the next window (cyclically) is missing roll_val
            neighbour = windows[(idx + 1) % len(windows)]
            if roll_val not in neighbour:
                # insert a new window that simply starts at roll_val
                windows.insert(idx + 1, list(range(roll_val, roll_val + context_size)))
        # mark the window for deletion if an earlier window equals it
        if current in windows[:idx]:
            duplicate_positions.append(idx)
        idx += 1

    # delete from the back so the remaining recorded positions stay valid
    for position in reversed(duplicate_positions):
        del windows[position]
    return windows
|
|
|
|
def static_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return the same fixed set of contiguous context windows on every call.

    Windows of ``context_size`` consecutive frames are laid out every
    ``context_size - context_overlap`` frames.  The first window that would
    reach or pass the end is moved back so it finishes exactly on the last
    frame, and it is the final window produced.

    ``step``, ``num_steps``, ``context_stride`` and ``closed_loop`` are
    accepted for signature parity with the other schedulers but are not used.

    Returns:
        A list of windows (lists of frame indexes).  When
        ``num_frames <= context_size`` it holds one window covering every
        frame.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    hop = context_size - context_overlap
    windows = []
    for first in range(0, num_frames, hop):
        end = first + context_size
        if end < num_frames:
            windows.append(list(range(first, end)))
            continue
        # This window would touch or pass the end of the frames: pull its
        # start back so it keeps the full context length and ends on the
        # last frame, then stop.
        windows.append(list(range(num_frames - context_size, num_frames)))
        break
    return windows
|
|
|
|
def get_context_scheduler(name: str) -> Callable:
    """Look up a context scheduler function by its name.

    Recognised names are ``"uniform_looped"``, ``"uniform_standard"`` and
    ``"static_standard"``.

    Raises:
        ValueError: If ``name`` is not one of the recognised names.
    """
    if name == "uniform_looped":
        return uniform_looped
    if name == "uniform_standard":
        return uniform_standard
    if name == "static_standard":
        return static_standard
    raise ValueError(f"Unknown context_overlap policy {name}")
|
|
|
|
|
|
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces over all timesteps.

    The scheduler is called once per entry of ``timesteps`` with the entry's
    position (0, 1, 2, ...) as its step argument, and the number of windows
    it returns or yields is summed.

    Args:
        scheduler: Callable taking ``(step, num_steps, num_frames,
            context_size, context_stride, context_overlap, closed_loop)``
            positionally and returning an iterable of windows.
        timesteps: Only its length is used.
        num_steps, num_frames, context_size, context_stride, context_overlap,
        closed_loop: Passed through to ``scheduler`` unchanged.

    Returns:
        The total number of windows across all steps.
    """
    total = 0
    for step in range(len(timesteps)):
        # closed_loop is forwarded too; previously it was dropped, so the
        # count always reflected the scheduler's own default.
        step_windows = scheduler(
            step,
            num_steps,
            num_frames,
            context_size,
            context_stride,
            context_overlap,
            closed_loop,
        )
        total += sum(1 for _ in step_windows)
    return total
|