Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)

FreeNoise noise shuffling for context windows

This commit is contained in:
parent 6e82eb1618
commit 1801c65e97
@@ -214,7 +214,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
set_pab_manager(pab_config)

def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
latents=None, freenoise=True, context_size=None, context_overlap=None
):
shape = (
batch_size,
@@ -228,9 +229,43 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
if freenoise:
print("Applying FreeNoise")
# code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
video_length = num_frames // 4
delta = context_size - context_overlap
for start_idx in range(0, video_length-context_size, delta):
# start_idx corresponds to the beginning of a context window
# goal: place shuffled in the delta region right after the end of the context window
# if space after context window is not enough to place the noise, adjust and finish
place_idx = start_idx + context_size
# if place_idx is outside the valid indexes, we are already finished
if place_idx >= video_length:
break
end_idx = place_idx - 1
#print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)

# if there is not enough room to copy delta amount of indexes, copy limited amount and finish
if end_idx + delta >= video_length:
final_delta = video_length - place_idx
# generate list of indexes in final delta region
list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
# shuffle list
list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
# apply shuffled indexes
noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
break
# otherwise, do normal behavior
# generate list of indexes in delta region
list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
# shuffle list
list_idx = list_idx[torch.randperm(delta, generator=generator)]
# apply shuffled indexes
#print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
if latents is None:
latents = noise
latents = noise.to(device)
else:
latents = latents.to(device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
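For reference, a minimal standalone sketch of the FreeNoise shuffle added above: it walks the latent timeline in steps of context_size - context_overlap and copies a shuffled slice of each window's start into the region right after that window, so overlapping context windows see correlated noise. The helper name and tensor shape are illustrative, not part of the commit.

import torch

def freenoise_shuffle(noise, context_size, context_overlap, generator=None):
    # noise: [batch, frames, channels, height, width] latent-space noise
    video_length = noise.shape[1]
    delta = context_size - context_overlap
    for start_idx in range(0, video_length - context_size, delta):
        place_idx = start_idx + context_size
        if place_idx >= video_length:
            break
        # shrink the copied region if it would run past the end of the video
        region = min(delta, video_length - place_idx)
        src = torch.arange(start_idx, start_idx + region)
        src = src[torch.randperm(region, generator=generator)]
        noise[:, place_idx:place_idx + region] = noise[:, src]
        if region < delta:
            break
    return noise

# example: 13 latent frames (49 pixel frames // 4), windows of 3 latent frames overlapping by 1
noise = torch.randn(1, 13, 16, 60, 90)
noise = freenoise_shuffle(noise, context_size=3, context_overlap=1)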
@@ -492,6 +527,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
context_frames: Optional[int] = None,
context_stride: Optional[int] = None,
context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -634,6 +670,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
denoise_strength,
num_inference_steps,
latents,
context_size=context_frames,
context_overlap=context_overlap,
freenoise=freenoise,
)
if comfyui_progressbar:
pbar.update(1)
@@ -659,8 +698,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):

# 8.5. Temporal tiling prep
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames // 4
t_tile_overlap = context_overlap // 4
t_tile_length = context_frames
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
print("Temporal tiling enabled")
@@ -668,9 +707,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
context_frames = context_frames // 4
context_stride = context_stride // 4
context_overlap = context_overlap // 4
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)

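Note the // 4 conversions being dropped from the pipeline here: CogVideoX's VAE compresses 4 pixel frames into 1 latent frame, so context settings given in pixel frames must be divided by 4 exactly once, and after this commit that happens in the sampler node (see the nodes.py hunks further down) instead of inside the pipeline. A small illustrative helper, not part of the commit:

# Hypothetical helper mirroring what the sampler node now does with context_options.
VAE_TEMPORAL_COMPRESSION = 4  # CogVideoX packs 4 pixel frames into 1 latent frame

def to_latent_frames(context_options):
    return (
        context_options["context_frames"] // VAE_TEMPORAL_COMPRESSION,
        context_options["context_stride"] // VAE_TEMPORAL_COMPRESSION,
        context_options["context_overlap"] // VAE_TEMPORAL_COMPRESSION,
    )

# 12-frame window with 4-frame stride and overlap, in pixel frames
print(to_latent_frames({"context_frames": 12, "context_stride": 4, "context_overlap": 4}))  # (3, 1, 1)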
@@ -277,6 +277,9 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
is_strength_max=True,
return_noise=False,
return_video_latents=False,
context_size=None,
context_overlap=None,
freenoise=False,
):
shape = (
batch_size,
@@ -309,11 +312,47 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
video_latents = rearrange(video_latents, "b c f h w -> b f c h w")

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=dtype)
if freenoise:
print("Applying FreeNoise")
# code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
video_length = video_length // 4
delta = context_size - context_overlap
for start_idx in range(0, video_length-context_size, delta):
# start_idx corresponds to the beginning of a context window
# goal: place shuffled in the delta region right after the end of the context window
# if space after context window is not enough to place the noise, adjust and finish
place_idx = start_idx + context_size
# if place_idx is outside the valid indexes, we are already finished
if place_idx >= video_length:
break
end_idx = place_idx - 1
#print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)

# if there is not enough room to copy delta amount of indexes, copy limited amount and finish
if end_idx + delta >= video_length:
final_delta = video_length - place_idx
# generate list of indexes in final delta region
list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
# shuffle list
list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
# apply shuffled indexes
noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
break
# otherwise, do normal behavior
# generate list of indexes in delta region
list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
# shuffle list
list_idx = list_idx[torch.randperm(delta, generator=generator)]
# apply shuffled indexes
#print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]

# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
latents = latents.to(device)
else:
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
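The inpaint path keeps the existing strength logic around the (now possibly shuffled) noise: at full strength the latents are pure noise scaled by the scheduler's init sigma, otherwise the video latents are noised to the starting timestep. A condensed sketch of that branch, with placeholder names for the scheduler and inputs:

def init_latents(noise, video_latents, scheduler, timestep, is_strength_max, device):
    # Pure noise when denoising from scratch, image + noise when partially denoising.
    if is_strength_max:
        latents = noise * scheduler.init_noise_sigma
    else:
        latents = scheduler.add_noise(video_latents, noise, timestep)
    return latents.to(device)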
@@ -465,7 +504,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
context_frames: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -481,6 +523,19 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
temporal_size=num_frames,
use_real=True,
)

if start_frame is not None or context_frames is not None:
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
if context_frames is not None:
freqs_cos = freqs_cos[context_frames]
freqs_sin = freqs_sin[context_frames]
else:
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]

freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])

freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
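Because the rotary embeddings are precomputed for the whole clip, each context window or tile has to pick out only its own frames. The reshaping above does that by viewing the frequencies as (frames, tokens_per_frame, dim), indexing by frame, and flattening back. A rough standalone equivalent, with illustrative names and sizes:

import torch

def slice_rotary_freqs(freqs, num_frames, grid_height, grid_width, frame_indices):
    # freqs: [num_frames * grid_height * grid_width, dim] for the full clip
    freqs = freqs.view(num_frames, grid_height * grid_width, -1)
    freqs = freqs[frame_indices]            # keep only the frames in this context window
    return freqs.view(-1, freqs.shape[-1])  # back to [tokens, dim]

# e.g. a 13-frame latent clip on a 30x45 token grid, window covering frames [0, 1, 2]
freqs_cos = torch.randn(13 * 30 * 45, 64)
window_cos = slice_rotary_freqs(freqs_cos, 13, 30, 45, [0, 1, 2])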
@@ -540,6 +595,11 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
strength: float = 1,
noise_aug_strength: float = 0.0563,
comfyui_progressbar: bool = False,
context_schedule: Optional[str] = None,
context_frames: Optional[int] = None,
context_stride: Optional[int] = None,
context_overlap: Optional[int] = None,
freenoise: Optional[bool] = True,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -617,10 +677,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""

if num_frames > 49:
raise ValueError(
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
# if num_frames > 49:
# raise ValueError(
# "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
# )

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -704,6 +764,9 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
is_strength_max=is_strength_max,
return_noise=True,
return_video_latents=return_image_latents,
context_size=context_frames,
context_overlap=context_overlap,
freenoise=freenoise,
)
if return_image_latents:
latents, noise, image_latents = latents_outputs
@@ -794,11 +857,29 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
print("Temporal tiling enabled")
elif context_schedule is not None:
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)

else:
use_temporal_tiling = False
use_context_schedule = False
print("Temporal tiling and context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)

# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
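When temporal tiling is selected, each latent tile is denoised separately and the results are blended with per-frame weights from self._gaussian_weights, whose implementation is not part of this diff. A sketch of the typical form of such weights and of the accumulate-then-normalize blend used in the denoising loop below, under that assumption:

import math
import torch

def gaussian_tile_weights(t_tile_length):
    # Assumed shape: a 1D Gaussian over the tile's frames, broadcast over
    # batch/channel/spatial dims when applied (not the repo's exact formula).
    mid = (t_tile_length - 1) / 2
    var = (t_tile_length / 4) ** 2
    w = [math.exp(-((f - mid) ** 2) / (2 * var)) for f in range(t_tile_length)]
    return torch.tensor(w).view(1, t_tile_length, 1, 1, 1)

# accumulate weighted tiles, then divide by the summed weights
latents = torch.randn(1, 13, 16, 60, 90)
weights = gaussian_tile_weights(4)
blended = torch.zeros_like(latents)
contributors = torch.zeros_like(latents)
for start in (0, 3, 6, 9):                      # overlapping tile offsets
    tile = latents[:, start:start + 4]          # stand-in for a denoised tile
    blended[:, start:start + 4] += tile * weights
    contributors[:, start:start + 4] += weights
blended /= contributors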
@@ -809,63 +890,232 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1

latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
all_t = latents.shape[1]
latents_all_list = []
# =====================================================

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length

# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
inpaint_latents=inpaint_latents,
)[0]
noise_pred = noise_pred.float()
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length

# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
if self.transformer.config.use_rotary_positional_embeddings
else None
)

# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]

latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)

#t_input = t[None].to(device)
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds,
timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False,
inpaint_latents=inpaint_latents_tile,
)[0]
noise_pred = noise_pred.float()

if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)

# ==========================================
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length

input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length

latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights

latents_all /= contributors

latents = latents_all

if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
pbar.update(1)
# ==========================================
elif use_context_schedule:

latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# Calculate the current step percentage
current_step_percentage = i / num_inference_steps

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])

context_queue = list(context(
i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
))
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)
if do_classifier_free_guidance:
noise_uncond = torch.zeros_like(latent_model_input)

for c in context_queue:
partial_latent_model_input = latent_model_input[:, c, :, :, :]
partial_inpaint_latents = inpaint_latents[:, c, :, :, :]
partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :]

image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
if self.transformer.config.use_rotary_positional_embeddings
else None
)

# predict noise model_output
noise_pred[:, c, :, :, :] += self.transformer(
hidden_states=partial_latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
inpaint_latents=partial_inpaint_latents,
)[0]

counter[:, c, :, :, :] += 1
if do_classifier_free_guidance:
noise_uncond[:, c, :, :, :] += self.transformer(
hidden_states=partial_latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
inpaint_latents=partial_inpaint_latents,
)[0]

noise_pred = noise_pred.float()

noise_pred /= counter
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)

# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if comfyui_progressbar:
pbar.update(1)

else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])

# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
inpaint_latents=inpaint_latents,
)[0]
noise_pred = noise_pred.float()

# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)

if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if comfyui_progressbar:
pbar.update(1)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if comfyui_progressbar:
pbar.update(1)

# if output_type == "numpy":
# video = self.decode_latents(latents)
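The context-schedule branch above runs the transformer once per window of latent frames and averages overlapping predictions. Stripped of the pipeline plumbing, the accumulation pattern looks like this (the model call is a stand-in for the transformer):

import torch

def predict_with_context_windows(latents, windows, model):
    # latents: [batch, frames, channels, height, width]
    noise_pred = torch.zeros_like(latents)
    counter = torch.zeros_like(latents)
    for window in windows:                       # e.g. [[0, 1, 2], [2, 3, 4], ...] frame indices
        partial = latents[:, window]
        noise_pred[:, window] += model(partial)  # per-window prediction
        counter[:, window] += 1
    return noise_pred / counter                  # average where windows overlap

# toy example: identity "model", two overlapping windows over 5 frames
latents = torch.randn(1, 5, 4, 8, 8)
out = predict_with_context_windows(latents, [[0, 1, 2], [2, 3, 4]], lambda x: x)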
nodes.py (45 changed lines)
@@ -974,7 +974,7 @@ class CogVideoXFunSampler:
"pipeline": ("COGVIDEOPIPE",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"video_length": ("INT", {"default": 49, "min": 5, "max": 49, "step": 4}),
"video_length": ("INT", {"default": 49, "min": 5, "max": 2048, "step": 4}),
"base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
"seed": ("INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}),
@@ -1003,6 +1003,7 @@ class CogVideoXFunSampler:
"end_img": ("IMAGE",),
"opt_empty_latent": ("LATENT",),
"noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
"context_options": ("COGCONTEXT", ),
},
}

@@ -1012,7 +1013,7 @@ class CogVideoXFunSampler:
CATEGORY = "CogVideoWrapper"

def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563):
start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"]
@@ -1041,6 +1042,9 @@ class CogVideoXFunSampler:
log.info(f"Closest bucket size: {width}x{height}")

# Load Sampler
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
logging.info("Temporal tiling enabled, changing scheduler to DDIM_tiled")
scheduler="DDIM_tiled"
scheduler_config = pipeline["scheduler_config"]
if scheduler in scheduler_mapping:
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
@@ -1050,7 +1054,15 @@ class CogVideoXFunSampler:

#if not pipeline["cpu_offloading"]:
# pipe.transformer.to(device)
generator= torch.Generator(device=device).manual_seed(seed)

if context_options is not None:
context_frames = context_options["context_frames"] // 4
context_stride = context_options["context_stride"] // 4
context_overlap = context_options["context_overlap"] // 4
else:
context_frames, context_stride, context_overlap = None, None, None

generator= torch.Generator(device="cpu").manual_seed(seed)

autocastcondition = not pipeline["onediff"]
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
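The seed generator now lives on the CPU to match prepare_latents, which also draws and shuffles its noise on the CPU before the latents are moved to the sampling device. A minimal illustration of that pattern (sizes are placeholders):

import torch

seed = 43
generator = torch.Generator(device="cpu").manual_seed(seed)

# noise is drawn and shuffled on the CPU so results stay reproducible across devices,
# then the latents are moved to the GPU for sampling
noise = torch.randn((1, 13, 16, 60, 90), generator=generator, device="cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"
latents = noise.to(device)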
@@ -1072,6 +1084,11 @@ class CogVideoXFunSampler:
mask_video = input_video_mask,
comfyui_progressbar = True,
noise_aug_strength = noise_aug_strength,
context_schedule=context_options["context_schedule"] if context_options is not None else None,
context_frames=context_frames,
context_stride= context_stride,
context_overlap= context_overlap,
freenoise=context_options["freenoise"] if context_options is not None else None
)
#if not pipeline["cpu_offloading"]:
# pipe.transformer.to(offload_device)
@@ -1282,6 +1299,7 @@ class CogVideoContextOptions:
"context_frames": ("INT", {"default": 12, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
"context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
"context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
"freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}),
}
}

@@ -1290,12 +1308,13 @@ class CogVideoContextOptions:
FUNCTION = "process"
CATEGORY = "CogVideoWrapper"

def process(self, context_schedule, context_frames, context_stride, context_overlap):
def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise):
context_options = {
"context_schedule":context_schedule,
"context_frames":context_frames,
"context_stride":context_stride,
"context_overlap":context_overlap
"context_overlap":context_overlap,
"freenoise":freenoise
}

return (context_options,)
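The CogVideoContextOptions node now simply forwards a dict with the extra freenoise flag, and the sampler nodes read it and pass the fields on to the pipeline. Roughly, using the node's default values (the schedule name here is a placeholder; "temporal_tiling" is the only value named in this diff):

# What the context options node emits (pixel-frame units; the sampler divides by 4).
context_options = {
    "context_schedule": "temporal_tiling",  # or another schedule supported by get_context_scheduler
    "context_frames": 12,
    "context_stride": 4,
    "context_overlap": 4,
    "freenoise": True,
}

context_frames = context_options["context_frames"] // 4
context_stride = context_options["context_stride"] // 4
context_overlap = context_options["context_overlap"] // 4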
@@ -1362,6 +1381,13 @@ class CogVideoXFunControlSampler:

mm.soft_empty_cache()

if context_options is not None:
context_frames = context_options["context_frames"] // 4
context_stride = context_options["context_stride"] // 4
context_overlap = context_options["context_overlap"] // 4
else:
context_frames, context_stride, context_overlap = None, None, None

# Load Sampler
scheduler_config = pipeline["scheduler_config"]
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
@@ -1373,7 +1399,7 @@ class CogVideoXFunControlSampler:
else:
raise ValueError(f"Unknown scheduler: {scheduler}")

generator= torch.Generator(device).manual_seed(seed)
generator=torch.Generator(torch.device("cpu")).manual_seed(seed)

autocastcondition = not pipeline["onediff"]
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -1401,9 +1427,10 @@ class CogVideoXFunControlSampler:
latents=samples["samples"] if samples is not None else None,
denoise_strength=denoise_strength,
context_schedule=context_options["context_schedule"] if context_options is not None else None,
context_frames=context_options["context_frames"] if context_options is not None else None,
context_stride=context_options["context_stride"] if context_options is not None else None,
context_overlap=context_options["context_overlap"] if context_options is not None else None
context_frames=context_frames,
context_stride= context_stride,
context_overlap= context_overlap,
freenoise=context_options["freenoise"] if context_options is not None else None

)