diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index 7130d50..5dcda4a 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -444,6 +444,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 226,
         comfyui_progressbar: bool = False,
+        control_strength: float = 1.0,
+        control_start_percent: float = 0.0,
+        control_end_percent: float = 1.0,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -610,6 +613,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             )
             control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")

+        control_latents = control_latents * control_strength
+
         if comfyui_progressbar:
             pbar.update(1)

@@ -636,6 +641,13 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+                # Calculate the current step percentage
+                current_step_percentage = i / num_inference_steps
+
+                # Determine if control_latents should be applied
+                apply_control = control_start_percent <= current_step_percentage <= control_end_percent
+                current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

@@ -646,7 +658,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
-                    control_latents=control_latents,
+                    control_latents=current_control_latents,
                 )[0]
                 noise_pred = noise_pred.float()
diff --git a/nodes.py b/nodes.py
index ed74095..62529ec 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1105,6 +1105,9 @@ class CogVideoXFunVid2VidSampler:
             "optional":{
                 "validation_video": ("IMAGE",),
                 "control_video": ("IMAGE",),
+                "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
         }

@@ -1113,7 +1116,8 @@ class CogVideoXFunVid2VidSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"

-    def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler, validation_video=None, control_video=None):
+    def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler,
+                validation_video=None, control_video=None, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1177,7 +1181,10 @@ class CogVideoXFunVid2VidSampler:
         if control_video is not None:
             latents = pipe(
                 **common_params,
-                control_video=input_video
+                control_video=input_video,
+                control_strength=control_strength,
+                control_start_percent=control_start_percent,
+                control_end_percent=control_end_percent
             )
         else:
             latents = pipe(
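
The gating the patch adds to the denoising loop can be illustrated in isolation. Below is a minimal, self-contained sketch of that behavior; gate_control_latents is a hypothetical helper written for illustration, not a function in this repository, and the step-to-percentage mapping simply mirrors the i / num_inference_steps expression in the diff.

import torch

def gate_control_latents(control_latents: torch.Tensor,
                         step: int,
                         num_inference_steps: int,
                         control_strength: float = 1.0,
                         control_start_percent: float = 0.0,
                         control_end_percent: float = 1.0) -> torch.Tensor:
    # Scale the control signal, mirroring `control_latents * control_strength`.
    scaled = control_latents * control_strength
    # Map the current step index to a fraction of the schedule, as in the pipeline.
    progress = step / num_inference_steps
    # Inside the [start, end] window the scaled latents pass through;
    # outside it, a zero tensor of the same shape is used instead.
    if control_start_percent <= progress <= control_end_percent:
        return scaled
    return torch.zeros_like(scaled)

For example, with control_start_percent=0.0 and control_end_percent=0.5, only the first half of the denoising steps is conditioned on the control latents; the remaining steps receive zeros, so the control video shapes coarse structure while later steps refine freely.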