From 1801c65e978ae095f2050cb96e27d4dc21d572a4 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 5 Oct 2024 19:29:09 +0300
Subject: [PATCH] FreeNoise noise shuffling for context windows

---
 cogvideox_fun/pipeline_cogvideox_control.py |  52 ++-
 cogvideox_fun/pipeline_cogvideox_inpaint.py | 366 ++++++++++++++++----
 nodes.py                                    |  45 ++-
 3 files changed, 388 insertions(+), 75 deletions(-)

diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index cad1fad..1ec87e1 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -214,7 +214,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         set_pab_manager(pab_config)
 
     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps,
+        latents=None, freenoise=True, context_size=None, context_overlap=None
     ):
         shape = (
             batch_size,
@@ -228,9 +229,43 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
+        noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
+        if freenoise:
+            print("Applying FreeNoise")
+            # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
+            video_length = num_frames // 4
+            delta = context_size - context_overlap
+            for start_idx in range(0, video_length-context_size, delta):
+                # start_idx corresponds to the beginning of a context window
+                # goal: place shuffled noise in the delta region right after the end of the context window
+                # if space after context window is not enough to place the noise, adjust and finish
+                place_idx = start_idx + context_size
+                # if place_idx is outside the valid indexes, we are already finished
+                if place_idx >= video_length:
+                    break
+                end_idx = place_idx - 1
+                #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)
+
+                # if there is not enough room to copy delta amount of indexes, copy limited amount and finish
+                if end_idx + delta >= video_length:
+                    final_delta = video_length - place_idx
+                    # generate list of indexes in final delta region
+                    list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
+                    # shuffle list
+                    list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
+                    # apply shuffled indexes
+                    noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
+                    break
+                # otherwise, do normal behavior
+                # generate list of indexes in delta region
+                list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
+                # shuffle list
+                list_idx = list_idx[torch.randperm(delta, generator=generator)]
+                # apply shuffled indexes
+                #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
+                noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
         if latents is None:
-            latents = noise
+            latents = noise.to(device)
         else:
             latents = latents.to(device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
@@ -492,6 +527,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         context_frames: Optional[int] = None,
         context_stride: Optional[int] = None,
         context_overlap: Optional[int] = None,
+        freenoise: Optional[bool] = True,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -634,6 +670,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             denoise_strength,
             num_inference_steps,
             latents,
+            context_size=context_frames,
+            context_overlap=context_overlap,
+            freenoise=freenoise,
         )
         if comfyui_progressbar:
             pbar.update(1)
@@ -659,8 +698,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
 
         # 8.5. Temporal tiling prep
         if context_schedule is not None and context_schedule == "temporal_tiling":
-            t_tile_length = context_frames // 4
-            t_tile_overlap = context_overlap // 4
+            t_tile_length = context_frames
+            t_tile_overlap = context_overlap
             t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
             use_temporal_tiling = True
             print("Temporal tiling enabled")
@@ -668,9 +707,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        elif context_schedule is not None:
             print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
             use_temporal_tiling = False
             use_context_schedule = True
-            context_frames = context_frames // 4
-            context_stride = context_stride // 4
-            context_overlap = context_overlap // 4
             from .context import get_context_scheduler
             context = get_context_scheduler(context_schedule)
diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py
index 5e56432..440bd3e 100644
--- a/cogvideox_fun/pipeline_cogvideox_inpaint.py
+++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py
@@ -277,6 +277,9 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         is_strength_max=True,
         return_noise=False,
         return_video_latents=False,
+        context_size=None,
+        context_overlap=None,
+        freenoise=False,
     ):
         shape = (
             batch_size,
@@ -309,11 +312,47 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
 
         if latents is None:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=dtype)
+            if freenoise:
+                print("Applying FreeNoise")
+                # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
+                video_length = video_length // 4
+                delta = context_size - context_overlap
+                for start_idx in range(0, video_length-context_size, delta):
+                    # start_idx corresponds to the beginning of a context window
+                    # goal: place shuffled noise in the delta region right after the end of the context window
+                    # if space after context window is not enough to place the noise, adjust and finish
+                    place_idx = start_idx + context_size
+                    # if place_idx is outside the valid indexes, we are already finished
+                    if place_idx >= video_length:
+                        break
+                    end_idx = place_idx - 1
+                    #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)
+
+                    # if there is not enough room to copy delta amount of indexes, copy limited amount and finish
+                    if end_idx + delta >= video_length:
+                        final_delta = video_length - place_idx
+                        # generate list of indexes in final delta region
+                        list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
+                        # shuffle list
+                        list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
+                        # apply shuffled indexes
+                        noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
+                        break
+                    # otherwise, do normal behavior
+                    # generate list of indexes in delta region
+                    list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
+                    # shuffle list
+                    list_idx = list_idx[torch.randperm(delta, generator=generator)]
+                    # apply shuffled indexes
+                    #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
+                    noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
+
             # if strength is 1. then initialise the latents to noise, else initial to image + noise
             latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
             # if pure noise then scale the initial latents by the Scheduler's init sigma
             latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+            latents = latents.to(device)
         else:
             noise = latents.to(device)
             latents = noise * self.scheduler.init_noise_sigma
@@ -465,7 +504,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         width: int,
         num_frames: int,
         device: torch.device,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        start_frame: Optional[int] = None,
+        end_frame: Optional[int] = None,
+        context_frames: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -481,6 +523,19 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             temporal_size=num_frames,
             use_real=True,
         )
+
+        if start_frame is not None or context_frames is not None:
+            freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
+            freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
+            if context_frames is not None:
+                freqs_cos = freqs_cos[context_frames]
+                freqs_sin = freqs_sin[context_frames]
+            else:
+                freqs_cos = freqs_cos[start_frame:end_frame]
+                freqs_sin = freqs_sin[start_frame:end_frame]
+
+            freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
+            freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
 
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
@@ -540,6 +595,11 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         strength: float = 1,
         noise_aug_strength: float = 0.0563,
         comfyui_progressbar: bool = False,
+        context_schedule: Optional[str] = None,
+        context_frames: Optional[int] = None,
+        context_stride: Optional[int] = None,
+        context_overlap: Optional[int] = None,
+        freenoise: Optional[bool] = True,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -617,10 +677,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
 
-        if num_frames > 49:
-            raise ValueError(
-                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
-            )
+        # if num_frames > 49:
+        #     raise ValueError(
+        #         "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+        #     )
 
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -704,6 +764,9 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             is_strength_max=is_strength_max,
             return_noise=True,
             return_video_latents=return_image_latents,
+            context_size=context_frames,
+            context_overlap=context_overlap,
+            freenoise=freenoise,
         )
         if return_image_latents:
             latents, noise, image_latents = latents_outputs
@@ -794,11 +857,29 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Create rotary embeds if required
-        image_rotary_emb = (
-            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
-            if self.transformer.config.use_rotary_positional_embeddings
-            else None
-        )
+        if context_schedule is not None and context_schedule == "temporal_tiling":
+            t_tile_length = context_frames
+            t_tile_overlap = context_overlap
+            t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
+            use_temporal_tiling = True
+            print("Temporal tiling enabled")
+        elif context_schedule is not None:
+            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            use_temporal_tiling = False
+            use_context_schedule = True
+            from .context import get_context_scheduler
+            context = get_context_scheduler(context_schedule)
+
+        else:
+            use_temporal_tiling = False
+            use_context_schedule = False
+            print("Temporal tiling and context schedule disabled")
+            # 7. Create rotary embeds if required
+            image_rotary_emb = (
+                self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+                if self.transformer.config.use_rotary_positional_embeddings
+                else None
+            )
 
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -809,63 +890,232 @@
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+                if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
+                    #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
+                    # =====================================================
+                    grid_ts = 0
+                    cur_t = 0
+                    while cur_t < latents.shape[1]:
+                        cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
+                        grid_ts += 1
 
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                    all_t = latents.shape[1]
+                    latents_all_list = []
+                    # =====================================================
 
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0])
+                    for t_i in range(grid_ts):
+                        if t_i < grid_ts - 1:
+                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
+                        if t_i == grid_ts - 1:
+                            ofs_t = all_t - t_tile_length
 
-                # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    return_dict=False,
-                    inpaint_latents=inpaint_latents,
-                )[0]
-                noise_pred = noise_pred.float()
+                        input_start_t = ofs_t
+                        input_end_t = ofs_t + t_tile_length
 
-                # perform guidance
-                if use_dynamic_cfg:
-                    self._guidance_scale = 1 + guidance_scale * (
-                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
-                    )
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        image_rotary_emb = (
+                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
+                            if self.transformer.config.use_rotary_positional_embeddings
+                            else None
+                        )
 
-                # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                        latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
+                        inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]
+
+                        latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
+                        latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
+
+                        #t_input = t[None].to(device)
+                        t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+
+                        # predict noise model_output
+                        noise_pred = self.transformer(
+                            hidden_states=latent_model_input_tile,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=t_input,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                            inpaint_latents=inpaint_latents_tile,
+                        )[0]
+                        noise_pred = noise_pred.float()
+
+                        if do_classifier_free_guidance:
+                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                        # compute the previous noisy sample x_t -> x_t-1
+                        latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
+                        latents_all_list.append(latents_tile)
+
+                    # ==========================================
+                    latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
+                    contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
+                    # Add each tile contribution to overall latents
+                    for t_i in range(grid_ts):
+                        if t_i < grid_ts - 1:
+                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
+                        if t_i == grid_ts - 1:
+                            ofs_t = all_t - t_tile_length
+
+                        input_start_t = ofs_t
+                        input_end_t = ofs_t + t_tile_length
+
+                        latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
+                        contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
+
+                    latents_all /= contributors
+
+                    latents = latents_all
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        pbar.update(1)
+                    # ==========================================
+                elif use_context_schedule:
+
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # Calculate the current step percentage
+                    current_step_percentage = i / num_inference_steps
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])
+
+                    context_queue = list(context(
+                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
+                    ))
+                    counter = torch.zeros_like(latent_model_input)
+                    noise_pred = torch.zeros_like(latent_model_input)
+                    if do_classifier_free_guidance:
+                        noise_uncond = torch.zeros_like(latent_model_input)
+
+                    for c in context_queue:
+                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                        partial_inpaint_latents = inpaint_latents[:, c, :, :, :]
+                        partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :]
+
+                        image_rotary_emb = (
+                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
+                            if self.transformer.config.use_rotary_positional_embeddings
+                            else None
+                        )
+
+                        # predict noise model_output
+                        noise_pred[:, c, :, :, :] += self.transformer(
+                            hidden_states=partial_latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=timestep,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                            inpaint_latents=partial_inpaint_latents,
+                        )[0]
+
+                        counter[:, c, :, :, :] += 1
+                        if do_classifier_free_guidance:
+                            noise_uncond[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False,
+                                inpaint_latents=partial_inpaint_latents,
+                            )[0]
+
+                    noise_pred = noise_pred.float()
+
+                    noise_pred /= counter
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                    else:
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)
+
+                    # call the callback, if provided
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if comfyui_progressbar:
+                            pbar.update(1)
+
                 else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        noise_pred,
-                        old_pred_original_sample,
-                        t,
-                        timesteps[i - 1] if i > 0 else None,
-                        latents,
-                        **extra_step_kwargs,
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])
+
+                    # predict noise model_output
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
                         return_dict=False,
-                    )
-                    latents = latents.to(prompt_embeds.dtype)
+                        inpaint_latents=inpaint_latents,
+                    )[0]
+                    noise_pred = noise_pred.float()
 
-                # call the callback, if provided
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    # perform guidance
+                    if use_dynamic_cfg:
+                        self._guidance_scale = 1 + guidance_scale * (
+                            (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                        )
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    # compute the previous noisy sample x_t -> x_t-1
+                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                    else:
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)
 
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if comfyui_progressbar:
-                        pbar.update(1)
+                    # call the callback, if provided
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if comfyui_progressbar:
+                            pbar.update(1)
 
         # if output_type == "numpy":
         #     video = self.decode_latents(latents)
diff --git a/nodes.py b/nodes.py
index e60ae18..24b8ce0 100644
--- a/nodes.py
+++ b/nodes.py
@@ -974,7 +974,7 @@ class CogVideoXFunSampler:
                 "pipeline": ("COGVIDEOPIPE",),
                 "positive": ("CONDITIONING", ),
                 "negative": ("CONDITIONING", ),
-                "video_length": ("INT", {"default": 49, "min": 5, "max": 49, "step": 4}),
+                "video_length": ("INT", {"default": 49, "min": 5, "max": 2048, "step": 4}),
                 "base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
                 "seed": ("INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}),
                 "steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}),
@@ -1003,6 +1003,7 @@ class CogVideoXFunSampler:
                 "end_img": ("IMAGE",),
                 "opt_empty_latent": ("LATENT",),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "context_options": ("COGCONTEXT", ),
             },
         }
 
@@ -1012,7 +1013,7 @@ class CogVideoXFunSampler:
     CATEGORY = "CogVideoWrapper"
 
     def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
-                start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563):
+                start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1041,6 +1042,9 @@ class CogVideoXFunSampler:
         log.info(f"Closest bucket size: {width}x{height}")
 
         # Load Sampler
+        if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
+            logging.info("Temporal tiling enabled, changing scheduler to DDIM_tiled")
+            scheduler="DDIM_tiled"
         scheduler_config = pipeline["scheduler_config"]
         if scheduler in scheduler_mapping:
            noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
@@ -1050,7 +1054,15 @@ class CogVideoXFunSampler:
 
         #if not pipeline["cpu_offloading"]:
         #    pipe.transformer.to(device)
-        generator= torch.Generator(device=device).manual_seed(seed)
+
+        if context_options is not None:
+            context_frames = context_options["context_frames"] // 4
+            context_stride = context_options["context_stride"] // 4
+            context_overlap = context_options["context_overlap"] // 4
+        else:
+            context_frames, context_stride, context_overlap = None, None, None
+
+        generator= torch.Generator(device="cpu").manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -1072,6 +1084,11 @@ class CogVideoXFunSampler:
                 mask_video = input_video_mask,
                 comfyui_progressbar = True,
                 noise_aug_strength = noise_aug_strength,
+                context_schedule=context_options["context_schedule"] if context_options is not None else None,
+                context_frames=context_frames,
+                context_stride= context_stride,
+                context_overlap= context_overlap,
+                freenoise=context_options["freenoise"] if context_options is not None else None
             )
         #if not pipeline["cpu_offloading"]:
         #    pipe.transformer.to(offload_device)
@@ -1282,6 +1299,7 @@ class CogVideoContextOptions:
                 "context_frames": ("INT", {"default": 12, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
1"} ), "context_overlap": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), + "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}), } } @@ -1290,12 +1308,13 @@ class CogVideoContextOptions: FUNCTION = "process" CATEGORY = "CogVideoWrapper" - def process(self, context_schedule, context_frames, context_stride, context_overlap): + def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise): context_options = { "context_schedule":context_schedule, "context_frames":context_frames, "context_stride":context_stride, - "context_overlap":context_overlap + "context_overlap":context_overlap, + "freenoise":freenoise } return (context_options,) @@ -1362,6 +1381,13 @@ class CogVideoXFunControlSampler: mm.soft_empty_cache() + if context_options is not None: + context_frames = context_options["context_frames"] // 4 + context_stride = context_options["context_stride"] // 4 + context_overlap = context_options["context_overlap"] // 4 + else: + context_frames, context_stride, context_overlap = None, None, None + # Load Sampler scheduler_config = pipeline["scheduler_config"] if context_options is not None and context_options["context_schedule"] == "temporal_tiling": @@ -1373,7 +1399,7 @@ class CogVideoXFunControlSampler: else: raise ValueError(f"Unknown scheduler: {scheduler}") - generator= torch.Generator(device).manual_seed(seed) + generator=torch.Generator(torch.device("cpu")).manual_seed(seed) autocastcondition = not pipeline["onediff"] autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() @@ -1401,9 +1427,10 @@ class CogVideoXFunControlSampler: latents=samples["samples"] if samples is not None else None, denoise_strength=denoise_strength, context_schedule=context_options["context_schedule"] if context_options is not None else None, - context_frames=context_options["context_frames"] if context_options is not None else None, - context_stride=context_options["context_stride"] if context_options is not None else None, - context_overlap=context_options["context_overlap"] if context_options is not None else None + context_frames=context_frames, + context_stride= context_stride, + context_overlap= context_overlap, + freenoise=context_options["freenoise"] if context_options is not None else None )