Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2026-05-10 05:02:23 +08:00)

commit 00d38f9a22 (parent 2be8f694b0)

    context schedule and all samplers for the non-Fun sampler as well
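The change brings windowed ("context schedule") denoising and the full `scheduler_mapping` scheduler list to the non-Fun `CogVideoSampler`, replacing its DDIM/DPM/DDIM_tiled-only options. As a rough, self-contained illustration of the windowing idea (a hypothetical helper; the schedules actually used live in `cogvideox_fun/context.py` and are richer):

```python
# Hypothetical helper, for illustration only: split latent frames into
# overlapping windows, the way a "uniform"-style context schedule would.
def uniform_windows(num_latent_frames: int, context_frames: int, context_overlap: int):
    stride = context_frames - context_overlap
    start = 0
    while start < num_latent_frames:
        end = min(start + context_frames, num_latent_frames)
        yield list(range(start, end))
        if end == num_latent_frames:
            break
        start += stride

print(list(uniform_windows(13, 4, 2)))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9], [8, 9, 10, 11], [10, 11, 12]]
```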
```diff
@@ -841,11 +841,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                     for c in context_queue:
                         partial_latent_model_input = latent_model_input[:, c, :, :, :]
                         partial_control_latents = current_control_latents[:, c, :, :, :]
-                        # image_rotary_emb = (
-                        #     self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
-                        #     if self.transformer.config.use_rotary_positional_embeddings
-                        #     else None
-                        # )

                         # predict noise model_output
                         noise_pred[:, c, :, :, :] += self.transformer(
@@ -857,7 +852,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                             control_latents=partial_control_latents,
                         )[0]

-                        counter[:, c, :, :, :] += 1
+                        # uncond
                         if do_classifier_free_guidance:
                             noise_uncond[:, c, :, :, :] += self.transformer(
                                 hidden_states=partial_latent_model_input,
@@ -867,7 +862,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                                 return_dict=False,
                                 control_latents=partial_control_latents,
                             )[0]

+                        counter[:, c, :, :, :] += 1
                     noise_pred = noise_pred.float()

                     noise_pred /= counter
```
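These hunks also fix the window counter: the increment now happens once per context window, after the unconditional pass, rather than between the two transformer calls. A toy version of the accumulate-then-average pattern (stand-in tensors and window sizes, not the pipeline's real shapes):

```python
# Toy check of the accumulate-then-average pattern: each overlapping window
# adds a full prediction, counter records how many windows touched each
# frame, and dividing recovers a blended per-frame prediction.
import torch

frames = 8
noise_pred = torch.zeros(1, frames, 2, 4, 4)
counter = torch.zeros_like(noise_pred)
for c in ([0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]):        # overlapping windows
    noise_pred[:, c, :, :, :] += torch.ones(1, len(c), 2, 4, 4)  # stand-in for transformer output
    counter[:, c, :, :, :] += 1
noise_pred /= counter        # frames 2-5 were predicted twice, now averaged
assert torch.allclose(noise_pred, torch.ones_like(noise_pred))
```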
nodes.py (87 changes)
```diff
@@ -27,6 +27,7 @@ from diffusers.schedulers import (
     HeunDiscreteScheduler,
     SASolverScheduler,
     DEISMultistepScheduler,
+    LCMScheduler
 )

 scheduler_mapping = {
@@ -41,8 +42,11 @@ scheduler_mapping = {
     "SASolverScheduler": SASolverScheduler,
     "UniPCMultistepScheduler": UniPCMultistepScheduler,
     "HeunDiscreteScheduler": HeunDiscreteScheduler,
-    "DEISMultistepScheduler": DEISMultistepScheduler
+    "DEISMultistepScheduler": DEISMultistepScheduler,
+    "LCMScheduler": LCMScheduler
 }

+available_schedulers = list(scheduler_mapping.keys())
+
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from .pipeline_cogvideox import CogVideoXPipeline
```
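With `LCMScheduler` added to the mapping, any listed scheduler can be built by name from the model's scheduler config; this is the lookup pattern the sampler uses further down. A runnable sketch with a subset of the mapping (`scheduler_config` stands in for the config loaded from the model repo):

```python
# Sketch of name-based scheduler construction via a mapping.
from diffusers.schedulers import CogVideoXDDIMScheduler, LCMScheduler

scheduler_mapping = {
    "CogVideoXDDIM": CogVideoXDDIMScheduler,
    "LCMScheduler": LCMScheduler,
}

def make_scheduler(name: str, scheduler_config: dict):
    """Instantiate a scheduler by name, mirroring the node's dispatch."""
    if name in scheduler_mapping:
        return scheduler_mapping[name].from_config(scheduler_config)
    raise ValueError(f"Unknown scheduler: {name}")
```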
```diff
@@ -833,14 +837,16 @@ class CogVideoSampler:
                 "steps": ("INT", {"default": 50, "min": 1}),
                 "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
-                "scheduler": (["DDIM", "DPM", "DDIM_tiled"], {"tooltip": "5B likes DPM, but it doesn't support temporal tiling"}),
-                "t_tile_length": ("INT", {"default": 16, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiling, use same value as num_frames to disable, disabled automatically for DPM"}),
-                "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
+                "scheduler": (available_schedulers,
+                    {
+                        "default": 'CogVideoXDDIM'
+                    }),
             },
             "optional": {
                 "samples": ("LATENT", ),
                 "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                 "image_cond_latents": ("LATENT", ),
+                "context_options": ("COGCONTEXT", ),
             }
         }

```
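The new `context_options` input carries a COGCONTEXT dict produced by a separate context-options node. Its keys can be inferred from how this node and the pipeline consume it below; the values here, including the schedule name, are illustrative only:

```python
# Illustrative COGCONTEXT payload (keys inferred from this diff; values are
# examples, and the schedule name is a placeholder, not a confirmed option).
context_options = {
    "context_schedule": "uniform_standard",
    "context_frames": 48,     # window length, in video frames
    "context_stride": 4,      # step between windows, in video frames
    "context_overlap": 8,     # overlap between windows, in video frames
    "freenoise": True,        # shuffle-repeat the initial noise across windows
}
```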
```diff
@@ -849,17 +855,13 @@ class CogVideoSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"

-    def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, t_tile_length, t_tile_overlap, samples=None,
-                denoise_strength=1.0, image_cond_latents=None):
+    def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None,
+                denoise_strength=1.0, image_cond_latents=None, context_options=None):
         mm.soft_empty_cache()

         base_path = pipeline["base_path"]

         assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
-        assert t_tile_length > t_tile_overlap, "t_tile_length must be greater than t_tile_overlap"
-        assert t_tile_length <= num_frames, "t_tile_length must be equal or less than num_frames"
-        t_tile_length = t_tile_length // 4
-        t_tile_overlap = t_tile_overlap // 4

         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -869,12 +871,20 @@ class CogVideoSampler:

         if not pipeline["cpu_offloading"]:
             pipe.transformer.to(device)
-        generator = torch.Generator(device=device).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

-        if scheduler == "DDIM" or scheduler == "DDIM_tiled":
-            pipe.scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
-        elif scheduler == "DPM":
-            pipe.scheduler = CogVideoXDPMScheduler.from_config(scheduler_config)
+        if scheduler in scheduler_mapping:
+            noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
+            pipe.scheduler = noise_scheduler
+        else:
+            raise ValueError(f"Unknown scheduler: {scheduler}")

+        if context_options is not None:
+            context_frames = context_options["context_frames"] // 4
+            context_stride = context_options["context_stride"] // 4
+            context_overlap = context_options["context_overlap"] // 4
+        else:
+            context_frames, context_stride, context_overlap = None, None, None
+
         if negative.shape[1] < positive.shape[1]:
             target_length = positive.shape[1]
```
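The divisions by 4 convert the context sizes from video frames to latent frames: CogVideoX's causal VAE compresses time by a factor of 4 (for example, a 49-frame video maps to 13 latent frames). With the illustrative values from the COGCONTEXT example above:

```python
# Video-frame settings -> latent-frame settings (4x temporal compression).
context_frames  = 48 // 4   # 12 latent frames per window
context_stride  = 4  // 4   # step of 1 latent frame
context_overlap = 8  // 4   # 2 latent frames of overlap
```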
```diff
@@ -889,8 +899,6 @@ class CogVideoSampler:
             height = height,
             width = width,
             num_frames = num_frames,
-            t_tile_length = t_tile_length,
-            t_tile_overlap = t_tile_overlap,
             guidance_scale=cfg,
             latents=samples["samples"] if samples is not None else None,
             image_cond_latents=image_cond_latents["samples"] if image_cond_latents is not None else None,
@@ -899,7 +907,12 @@ class CogVideoSampler:
             negative_prompt_embeds=negative.to(dtype).to(device),
             generator=generator,
             device=device,
-            scheduler_name=scheduler
+            scheduler_name=scheduler,
+            context_schedule=context_options["context_schedule"] if context_options is not None else None,
+            context_frames=context_frames,
+            context_stride=context_stride,
+            context_overlap=context_overlap,
+            freenoise=context_options["freenoise"] if context_options is not None else None
         )
         if not pipeline["cpu_offloading"]:
             pipe.transformer.to(offload_device)
@@ -979,24 +992,7 @@ class CogVideoXFunSampler:
                 "seed": ("INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}),
                 "steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}),
                 "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}),
-                "scheduler": (
-                    [
-                        "Euler",
-                        "Euler A",
-                        "DPM++",
-                        "PNDM",
-                        "DDIM",
-                        "SASolverScheduler",
-                        "UniPCMultistepScheduler",
-                        "HeunDiscreteScheduler",
-                        "DEISMultistepScheduler",
-                        "CogVideoXDDIM",
-                        "CogVideoXDPMScheduler",
-                    ],
-                    {
-                        "default": 'DDIM'
-                    }
-                )
+                "scheduler": (available_schedulers, {"default": 'DDIM'})
             },
             "optional":{
                 "start_img": ("IMAGE",),
```
```diff
@@ -1331,24 +1327,7 @@ class CogVideoXFunControlSampler:
                 "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
                 "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}),
                 "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}),
-                "scheduler": (
-                    [
-                        "Euler",
-                        "Euler A",
-                        "DPM++",
-                        "PNDM",
-                        "DDIM",
-                        "SASolverScheduler",
-                        "UniPCMultistepScheduler",
-                        "HeunDiscreteScheduler",
-                        "DEISMultistepScheduler",
-                        "CogVideoXDDIM",
-                        "CogVideoXDPMScheduler",
-                    ],
-                    {
-                        "default": 'DDIM'
-                    }
-                ),
+                "scheduler": (available_schedulers, {"default": 'DDIM'}),
                 "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                 "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                 "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
```

pipeline_cogvideox.py
```diff
@@ -21,7 +21,6 @@ import torch.nn.functional as F
 import math

 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
@@ -164,7 +163,8 @@ class CogVideoXPipeline(VideoSysPipeline):
             set_pab_manager(pab_config)

     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
+        num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
     ):
         shape = (
             batch_size,
@@ -178,9 +178,43 @@ class CogVideoXPipeline(VideoSysPipeline):
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
+        noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
+        if freenoise:
+            print("Applying FreeNoise")
+            # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
+            video_length = num_frames // 4
+            delta = context_size - context_overlap
+            for start_idx in range(0, video_length-context_size, delta):
+                # start_idx corresponds to the beginning of a context window
+                # goal: place shuffled in the delta region right after the end of the context window
+                # if space after context window is not enough to place the noise, adjust and finish
+                place_idx = start_idx + context_size
+                # if place_idx is outside the valid indexes, we are already finished
+                if place_idx >= video_length:
+                    break
+                end_idx = place_idx - 1
+                #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)

+                # if there is not enough room to copy delta amount of indexes, copy limited amount and finish
+                if end_idx + delta >= video_length:
+                    final_delta = video_length - place_idx
+                    # generate list of indexes in final delta region
+                    list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
+                    # shuffle list
+                    list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
+                    # apply shuffled indexes
+                    noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
+                    break
+                # otherwise, do normal behavior
+                # generate list of indexes in delta region
+                list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
+                # shuffle list
+                list_idx = list_idx[torch.randperm(delta, generator=generator)]
+                # apply shuffled indexes
+                #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
+                noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
         if latents is None:
-            latents = noise
+            latents = noise.to(device)
         else:
             latents = latents.to(device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
```
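As a self-contained check of the FreeNoise shuffle above, here is a condensed version of the same loop on a small CPU tensor (illustrative sizes; the original's two branches are folded into a single `min`): each delta-sized span after a context window is overwritten with a shuffled copy of indices from the start of that window, so noise repeats with the same statistics across windows instead of being freshly sampled per frame.

```python
import torch

generator = torch.Generator(device="cpu").manual_seed(0)
noise = torch.randn(1, 13, 4, 8, 8)            # (B, latent frames, C, H, W)
video_length, context_size, context_overlap = 13, 4, 2
delta = context_size - context_overlap

for start_idx in range(0, video_length - context_size, delta):
    place_idx = start_idx + context_size
    if place_idx >= video_length:
        break
    d = min(delta, video_length - place_idx)   # clamp at the sequence end
    list_idx = torch.arange(start_idx, start_idx + d, dtype=torch.long)
    list_idx = list_idx[torch.randperm(d, generator=generator)]
    noise[:, place_idx:place_idx + d, :, :, :] = noise[:, list_idx, :, :, :]
    if d < delta:
        break
```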
```diff
@@ -346,6 +380,11 @@ class CogVideoXPipeline(VideoSysPipeline):
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         device = torch.device("cuda"),
         scheduler_name: str = "DPM",
+        context_schedule: Optional[str] = None,
+        context_frames: Optional[int] = None,
+        context_stride: Optional[int] = None,
+        context_overlap: Optional[int] = None,
+        freenoise: Optional[bool] = True,
     ):
         """
         Function invoked when calling the pipeline for generation.
@@ -448,7 +487,10 @@ class CogVideoXPipeline(VideoSysPipeline):
             timesteps,
             denoise_strength,
             num_inference_steps,
-            latents
+            latents,
+            context_size=context_frames,
+            context_overlap=context_overlap,
+            freenoise=freenoise,
         )
         latents = latents.to(self.vae.dtype)
         #print("latents", latents.shape)
@@ -492,22 +534,37 @@ class CogVideoXPipeline(VideoSysPipeline):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         comfy_pbar = ProgressBar(num_inference_steps)

-        # 8. Temporal tiling prep
-        if "tiled" in scheduler_name:
+        # 8.5. Temporal tiling prep
+        if context_schedule is not None and context_schedule == "temporal_tiling":
+            t_tile_length = context_frames
+            t_tile_overlap = context_overlap
             t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
-            temporal_tiling = True
+            use_temporal_tiling = True
             print("Temporal tiling enabled")
+        elif context_schedule is not None:
+            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            use_temporal_tiling = False
+            use_context_schedule = True
+            from .cogvideox_fun.context import get_context_scheduler
+            context = get_context_scheduler(context_schedule)
+
         else:
-            temporal_tiling = False
-            print("Temporal tiling disabled")
-            #print("latents.shape", latents.shape)
+            use_temporal_tiling = False
+            use_context_schedule = False
+            print("Temporal tiling and context schedule disabled")
+            # 7. Create rotary embeds if required
+            image_rotary_emb = (
+                self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+                if self.transformer.config.use_rotary_positional_embeddings
+                else None
+            )

         with self.progress_bar(total=num_inference_steps) as progress_bar:
             old_pred_original_sample = None  # for DPM-solver++
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-                if temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
+                if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
                     #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
                     # =====================================================
                     grid_ts = 0
```
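The context-schedule branch pulls its window generator from `cogvideox_fun/context.py` via `get_context_scheduler(context_schedule)`. A minimal stand-in with the same calling convention used in the denoising loop below (an assumption for illustration; the shipped schedules are more elaborate and may use the step index and stride):

```python
# Minimal stand-in for a "uniform"-style context scheduler. The returned
# callable matches the signature used in the loop below; step, num_steps,
# and context_stride are accepted but ignored in this sketch.
def get_context_scheduler(name):
    def uniform(step, num_steps, num_frames, context_frames, context_stride, context_overlap):
        stride = max(context_frames - context_overlap, 1)
        for start in range(0, max(num_frames - context_overlap, 1), stride):
            # clamp indices at the end so every window has context_frames entries
            yield [min(start + j, num_frames - 1) for j in range(context_frames)]
    return uniform

queue = list(get_context_scheduler("uniform")(0, 50, 13, 4, 1, 2))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9], [8, 9, 10, 11], [10, 11, 12, 12]]
```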
```diff
@@ -533,7 +590,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                         #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                         image_rotary_emb = (
-                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
+                            self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
                             if self.transformer.config.use_rotary_positional_embeddings
                             else None
                         )
@@ -600,6 +657,79 @@ class CogVideoXPipeline(VideoSysPipeline):
                         progress_bar.update()
                         comfy_pbar.update(1)
                 # ==========================================
+                elif use_context_schedule:
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                    counter = torch.zeros_like(latent_model_input)
+                    noise_pred = torch.zeros_like(latent_model_input)
+                    if do_classifier_free_guidance:
+                        noise_uncond = torch.zeros_like(latent_model_input)
+
+                    if image_cond_latents is not None:
+                        latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
+                        latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])
+
+                    context_queue = list(context(
+                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
+                    ))
+
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
+
+                    for c in context_queue:
+                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
+
+                        # predict noise model_output
+                        noise_pred[:, c, :, :, :] += self.transformer(
+                            hidden_states=partial_latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=timestep,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                        )[0]
+
+                        # uncond
+                        if do_classifier_free_guidance:
+                            noise_uncond[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False,
+                            )[0]
+
+                        counter[:, c, :, :, :] += 1
+                    noise_pred = noise_pred.float()
+
+                    noise_pred /= counter
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                    else:
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+
                 else:
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
```
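For reference, the classifier-free-guidance combine used at the end of the new branch, in isolation (random stand-in tensors with illustrative latent shapes):

```python
# CFG combine as in the loop above: split the doubled batch into
# unconditional and conditional halves, then extrapolate.
import torch

guidance_scale = 6.0                           # matches the node's cfg default
noise_pred = torch.randn(2, 13, 16, 60, 90)    # [uncond, cond] stacked along batch
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```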