Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 21:04:23 +08:00)
Add context schedules for the control pipeline
parent 033ec61d86
commit 453ee1606e

cogvideox_fun/context.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import numpy as np
from typing import Callable, Optional, List


def ordered_halving(val):
    bin_str = f"{val:064b}"
    bin_flip = bin_str[::-1]
    as_int = int(bin_flip, 2)

    return as_int / (1 << 64)


def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    prev_val = -1
    for i, val in enumerate(window):
        val = val % num_frames
        if val < prev_val:
            return True, i
        prev_val = val
    return False, -1


def shift_window_to_start(window: list[int], num_frames: int):
    start_val = window[0]
    for i in range(len(window)):
        # 1) subtract each element by start_val to move vals relative to the start of all frames
        # 2) add num_frames and take modulus to get adjusted vals
        window[i] = ((window[i] - start_val) + num_frames) % num_frames


def shift_window_to_end(window: list[int], num_frames: int):
    # 1) shift window to start
    shift_window_to_start(window, num_frames)
    end_val = window[-1]
    end_delta = num_frames - end_val - 1
    for i in range(len(window)):
        # 2) add end_delta to each val to slide windows to end
        window[i] = window[i] + end_delta


def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    all_indexes = list(range(num_frames))
    for w in windows:
        for val in w:
            try:
                all_indexes.remove(val)
            except ValueError:
                pass
    return all_indexes


def uniform_looped(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)

    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(step)))
        for j in range(
            int(ordered_halving(step) * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            (context_size * context_step - context_overlap),
        ):
            yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]

# from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
def uniform_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    windows = []
    if num_frames <= context_size:
        windows.append(list(range(num_frames)))
        return windows

    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)

    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(step)))
        for j in range(
            int(ordered_halving(step) * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            (context_size * context_step - context_overlap),
        ):
            windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)])

    # now that windows are created, shift any windows that loop, and delete duplicate windows
    delete_idxs = []
    win_i = 0
    while win_i < len(windows):
        # if a window rolls over itself, need to shift it
        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
        if is_roll:
            roll_val = windows[win_i][roll_idx]  # roll_val might not be 0 for windows of higher strides
            shift_window_to_end(windows[win_i], num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            if roll_val not in windows[(win_i + 1) % len(windows)]:
                # need to insert new window here - just insert window starting at roll_val
                windows.insert(win_i + 1, list(range(roll_val, roll_val + context_size)))
        # delete window if it's not unique
        for pre_i in range(0, win_i):
            if windows[win_i] == windows[pre_i]:
                delete_idxs.append(win_i)
                break
        win_i += 1

    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
    delete_idxs.reverse()
    for i in delete_idxs:
        windows.pop(i)
    return windows

def static_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    windows = []
    if num_frames <= context_size:
        windows.append(list(range(num_frames)))
        return windows

    # always return the same set of windows
    delta = context_size - context_overlap
    for start_idx in range(0, num_frames, delta):
        # if past the end of frames, move start_idx back to allow same context_length
        ending = start_idx + context_size
        if ending >= num_frames:
            final_delta = ending - num_frames
            final_start_idx = start_idx - final_delta
            windows.append(list(range(final_start_idx, final_start_idx + context_size)))
            break
        windows.append(list(range(start_idx, start_idx + context_size)))
    return windows

def get_context_scheduler(name: str) -> Callable:
    if name == "uniform_looped":
        return uniform_looped
    elif name == "uniform_standard":
        return uniform_standard
    elif name == "static_standard":
        return static_standard
    else:
        raise ValueError(f"Unknown context_overlap policy {name}")


def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    return sum(
        len(
            list(
                scheduler(
                    i,
                    num_steps,
                    num_frames,
                    context_size,
                    context_stride,
                    context_overlap,
                )
            )
        )
        for i in range(len(timesteps))
    )
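For orientation, a minimal sketch of what these schedulers produce; the numbers are arbitrary example values and the import path assumes the repository root is on the Python path. Note that uniform_looped is a generator, while uniform_standard and static_standard return plain lists of windows.

from cogvideox_fun.context import get_context_scheduler

# 13 latent frames split into windows of 4 with 1 frame of overlap
scheduler = get_context_scheduler("static_standard")
windows = scheduler(step=0, num_steps=20, num_frames=13, context_size=4,
                    context_stride=1, context_overlap=1)
print(windows)
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9, 10, 11, 12]]

The last window is shifted back so every window keeps the same length, which is what the final_start_idx adjustment above is for.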
cogvideox_fun/pipeline_cogvideox_control.py

@@ -395,8 +395,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         width: int,
         num_frames: int,
         device: torch.device,
-        start_frame: int = None,
-        end_frame: int = None,
+        start_frame: Optional[int] = None,
+        end_frame: Optional[int] = None,
+        context_frames: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -414,12 +415,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             use_real=True,
         )

-        if start_frame is not None:
+        if start_frame is not None or context_frames is not None:
             freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
             freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
-            freqs_cos = freqs_cos[start_frame:end_frame]
-            freqs_sin = freqs_sin[start_frame:end_frame]
+            if context_frames is not None:
+                freqs_cos = freqs_cos[context_frames]
+                freqs_sin = freqs_sin[context_frames]
+            else:
+                freqs_cos = freqs_cos[start_frame:end_frame]
+                freqs_sin = freqs_sin[start_frame:end_frame]

         freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
         freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
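With this change the rotary frequency tables can be gathered per context window: when context_frames is a list of latent-frame indices, plain advanced indexing picks out exactly those frames. A rough standalone illustration (the shapes below are invented for the example):

import torch

num_frames, tokens_per_frame, dim = 13, 1350, 64   # example sizes only
freqs_cos = torch.randn(num_frames * tokens_per_frame, dim)

window = [4, 5, 6, 7]                               # latent-frame indices of one context window
per_frame = freqs_cos.view(num_frames, tokens_per_frame, -1)
window_cos = per_frame[window]                      # shape (4, tokens_per_frame, dim)
window_cos = window_cos.view(-1, window_cos.shape[-1])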
@@ -483,9 +487,11 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         control_strength: float = 1.0,
         control_start_percent: float = 0.0,
         control_end_percent: float = 1.0,
-        t_tile_length: int = 12,
-        t_tile_overlap: int = 4,
         scheduler_name: str = "DPM",
+        context_schedule: Optional[str] = None,
+        context_frames: Optional[int] = None,
+        context_stride: Optional[int] = None,
+        context_overlap: Optional[int] = None,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -652,23 +658,33 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

         # 8.5. Temporal tiling prep
-        if "tiled" in scheduler_name:
-            t_tile_length = t_tile_length // 4
-            t_tile_overlap = t_tile_overlap // 4
+        if context_schedule is not None and context_schedule == "temporal_tiling":
+            t_tile_length = context_frames // 4
+            t_tile_overlap = context_overlap // 4
             t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
-            temporal_tiling = True
+            use_temporal_tiling = True
             print("Temporal tiling enabled")
+        elif context_schedule is not None:
+            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            use_temporal_tiling = False
+            use_context_schedule = True
+            context_frames = context_frames // 4
+            context_stride = context_stride // 4
+            context_overlap = context_overlap // 4
+            from .context import get_context_scheduler
+            context = get_context_scheduler(context_schedule)
+
         else:
-            temporal_tiling = False
-            print("Temporal tiling disabled")
+            use_temporal_tiling = False
+            use_context_schedule = False
+            print("Temporal tiling and context schedule disabled")
         # 7. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
-        #print("latents.shape", latents.shape)

         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
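The // 4 conversions reflect that the context options are given in pixel frames while the schedule runs on latent frames; as the node tooltips below put it, the latent space has 4 frames in 1. A quick worked example:

# Example node settings (pixel frames) and the latent-frame values the pipeline actually uses
context_frames, context_stride, context_overlap = 12, 4, 4
print(context_frames // 4, context_stride // 4, context_overlap // 4)   # 3 1 1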
@@ -677,7 +693,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 if self.interrupt:
                     continue

-                if temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
+                if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
                     #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
                     # =====================================================
                     grid_ts = 0
@@ -757,6 +773,96 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                         progress_bar.update()
                         pbar.update(1)
                     # ==========================================
+                elif use_context_schedule:
+
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # Calculate the current step percentage
+                    current_step_percentage = i / num_inference_steps
+
+                    # Determine if control_latents should be applied
+                    apply_control = control_start_percent <= current_step_percentage <= control_end_percent
+                    current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])
+
+                    context_queue = list(context(
+                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
+                    ))
+                    counter = torch.zeros_like(latent_model_input)
+                    noise_pred = torch.zeros_like(latent_model_input)
+                    if do_classifier_free_guidance:
+                        noise_uncond = torch.zeros_like(latent_model_input)
+
+                    for c in context_queue:
+                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                        partial_control_latents = current_control_latents[:, c, :, :, :]
+                        image_rotary_emb = (
+                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
+                            if self.transformer.config.use_rotary_positional_embeddings
+                            else None
+                        )
+
+                        # predict noise model_output
+                        noise_pred[:, c, :, :, :] += self.transformer(
+                            hidden_states=partial_latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=timestep,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                            control_latents=partial_control_latents,
+                        )[0]
+
+                        counter[:, c, :, :, :] += 1
+                        if do_classifier_free_guidance:
+                            noise_uncond[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False,
+                                control_latents=partial_control_latents,
+                            )[0]
+
+                    noise_pred = noise_pred.float()
+
+                    noise_pred /= counter
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                    else:
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)
+
+                    # call the callback, if provided
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if comfyui_progressbar:
+                            pbar.update(1)
                 else:
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
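The branch above denoises each context window separately and then blends the overlapping predictions by dividing an accumulator by a per-frame counter. A self-contained sketch of that blending idea, with the transformer call replaced by a dummy tensor:

import torch

frames, channels = 10, 2
windows = [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]   # overlapping latent-frame windows

noise_pred = torch.zeros(1, frames, channels)
counter = torch.zeros(1, frames, channels)

for c in windows:
    partial = torch.ones(1, len(c), channels)           # stand-in for the transformer output
    noise_pred[:, c, :] += partial
    counter[:, c, :] += 1

noise_pred /= counter   # frames 3 and 6 were covered twice, so their sums are averaged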
@@ -682,7 +682,7 @@
       },
       "size": {
         "0": 367.79998779296875,
-        "1": 102
+        "1": 146
       },
       "flags": {},
       "order": 10,
@@ -706,8 +706,20 @@
         "links": [
           182
         ],
-        "shape": 3,
-        "slot_index": 0
+        "slot_index": 0,
+        "shape": 3
+      },
+      {
+        "name": "width",
+        "type": "INT",
+        "links": null,
+        "shape": 3
+      },
+      {
+        "name": "height",
+        "type": "INT",
+        "links": null,
+        "shape": 3
       }
     ],
     "properties": {
@@ -715,7 +727,8 @@
     },
     "widgets_values": [
       512,
-      false
+      false,
+      0
     ]
   },
   {
@@ -725,10 +738,10 @@
       "0": 1085,
       "1": 312
     },
-    "size": [
-      311.22059416191496,
-      286
-    ],
+    "size": {
+      "0": 311.2205810546875,
+      "1": 350
+    },
     "flags": {},
     "order": 11,
     "mode": 0,
@@ -752,6 +765,16 @@
         "name": "control_latents",
         "type": "COGCONTROL_LATENTS",
         "link": 182
+      },
+      {
+        "name": "samples",
+        "type": "LATENT",
+        "link": null
+      },
+      {
+        "name": "context_options",
+        "type": "COGCONTEXT",
+        "link": null
       }
     ],
     "outputs": [
@@ -783,6 +806,7 @@
       "CogVideoXDPMScheduler",
       0.7000000000000001,
       0,
+      1,
       1
     ]
   },
@@ -1034,10 +1058,10 @@
   "config": {},
   "extra": {
     "ds": {
-      "scale": 0.6934334949441758,
+      "scale": 0.6934334949442492,
       "offset": [
-        364.1021432696588,
-        27.28260472943026
+        39.55130702561554,
+        104.54407751572876
       ]
     }
   },
nodes.py (65 changed lines)
@@ -55,7 +55,7 @@ from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as Autoenco
 from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil
 from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
 from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
-from .cogvideox_fun.lora_utils import merge_lora, unmerge_lora
+from .cogvideox_fun.lora_utils import merge_lora
 from PIL import Image
 import numpy as np
 import json
@@ -342,12 +342,12 @@ class DownloadAndLoadCogVideoModel:
         transformer = transformer.to(dtype).to(offload_device)

         if lora is not None:
-            if lora['strength'] > 0:
-                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+            logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+            if "fun" in model.lower():
                 transformer = merge_lora(transformer, lora["path"], lora["strength"])
             else:
-                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
-                transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
+                raise NotImplementedError("LoRA merging is currently only supported for Fun models")
+

         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
@@ -381,9 +381,7 @@ class DownloadAndLoadCogVideoModel:
             pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
         else:
             vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
             pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)

-
-
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
@@ -1274,7 +1272,34 @@ class CogVideoControlImageEncode:
             }

         return (control_latents, width, height)

+
+class CogVideoContextOptions:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],),
+            "context_frames": ("INT", {"default": 12, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
+            "context_stride": ("INT", {"default": 4, "min": 4, "max": 10, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
+            "context_overlap": ("INT", {"default": 4, "min": 4, "max": 10, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
+            }
+        }
+
+    RETURN_TYPES = ("COGCONTEXT", )
+    RETURN_NAMES = ("context_options",)
+    FUNCTION = "process"
+    CATEGORY = "CogVideoWrapper"
+
+    def process(self, context_schedule, context_frames, context_stride, context_overlap):
+        context_options = {
+            "context_schedule":context_schedule,
+            "context_frames":context_frames,
+            "context_stride":context_stride,
+            "context_overlap":context_overlap
+        }
+
+        return (context_options,)
+
 class CogVideoXFunControlSampler:
     @classmethod
     def INPUT_TYPES(s):
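The new node is a thin wrapper: process just packs its widget values into a dict that CogVideoXFunControlSampler later unpacks into the pipeline's context_* keyword arguments. With the class above in scope:

options = CogVideoContextOptions().process(
    context_schedule="uniform_standard",
    context_frames=12,
    context_stride=4,
    context_overlap=4,
)[0]
# {'context_schedule': 'uniform_standard', 'context_frames': 12,
#  'context_stride': 4, 'context_overlap': 4}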
@@ -1300,7 +1325,6 @@ class CogVideoXFunControlSampler:
                     "DEISMultistepScheduler",
                     "CogVideoXDDIM",
                     "CogVideoXDPMScheduler",
-                    "DDIM_tiled",
                     ],
                     {
                         "default": 'DDIM'
@@ -1309,12 +1333,11 @@ class CogVideoXFunControlSampler:
                 "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                 "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                 "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                "t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
-                "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
             },
             "optional": {
                 "samples": ("LATENT", ),
                 "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "context_options": ("COGCONTEXT", ),
             },
         }

@@ -1325,7 +1348,7 @@ class CogVideoXFunControlSampler:

     def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
                 control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
-                samples=None, denoise_strength=1.0):
+                samples=None, denoise_strength=1.0, context_options=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1341,6 +1364,9 @@ class CogVideoXFunControlSampler:

         # Load Sampler
         scheduler_config = pipeline["scheduler_config"]
+        if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
+            logging.info("Temporal tiling enabled, changing scheduler to DDIM_tiled")
+            scheduler="DDIM_tiled"
         if scheduler in scheduler_mapping:
             noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
             pipe.scheduler = noise_scheduler
@@ -1371,11 +1397,14 @@ class CogVideoXFunControlSampler:
             control_strength=control_strength,
             control_start_percent=control_start_percent,
             control_end_percent=control_end_percent,
-            t_tile_length=t_tile_length,
-            t_tile_overlap=t_tile_overlap,
             scheduler_name=scheduler,
             latents=samples["samples"] if samples is not None else None,
             denoise_strength=denoise_strength,
+            context_schedule=context_options["context_schedule"] if context_options is not None else None,
+            context_frames=context_options["context_frames"] if context_options is not None else None,
+            context_stride=context_options["context_stride"] if context_options is not None else None,
+            context_overlap=context_options["context_overlap"] if context_options is not None else None
+
         )

         return (pipeline, {"samples": latents})
@@ -1395,7 +1424,8 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoPABConfig": CogVideoPABConfig,
     "CogVideoTransformerEdit": CogVideoTransformerEdit,
     "CogVideoControlImageEncode": CogVideoControlImageEncode,
-    "CogVideoLoraSelect": CogVideoLoraSelect
+    "CogVideoLoraSelect": CogVideoLoraSelect,
+    "CogVideoContextOptions": CogVideoContextOptions
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1412,5 +1442,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoPABConfig": "CogVideo PABConfig",
     "CogVideoTransformerEdit": "CogVideo TransformerEdit",
     "CogVideoControlImageEncode": "CogVideo Control ImageEncode",
-    "CogVideoLoraSelect": "CogVideo LoraSelect"
+    "CogVideoLoraSelect": "CogVideo LoraSelect",
+    "CogVideoContextOptions": "CogVideo Context Options"
     }