Add control strength, start and end percent adjustments

Jukka Seppänen 2024-09-30 17:50:48 +03:00
parent dbeea9bfb8
commit f3a1ff933e
2 changed files with 22 additions and 3 deletions


@@ -444,6 +444,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 226,
         comfyui_progressbar: bool = False,
+        control_strength: float = 1.0,
+        control_start_percent: float = 0.0,
+        control_end_percent: float = 1.0,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -610,6 +613,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             )
             control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
+            control_latents = control_latents * control_strength
+
         if comfyui_progressbar:
             pbar.update(1)
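
The strength adjustment is a single scalar multiply applied to the encoded control latents before the denoising loop, so it uniformly scales how strongly the control video steers generation. A minimal standalone sketch of the effect (the latent shape here is illustrative, not the pipeline's actual one):

```python
import torch

# Illustrative latent shape only: (batch, frames, channels, height, width)
control_latents = torch.randn(1, 13, 16, 60, 90)

# 0.0 silences the control signal, 1.0 passes it through unchanged,
# and values up to the node's max of 10.0 amplify it.
control_strength = 0.5
control_latents = control_latents * control_strength
```
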
@@ -636,6 +641,13 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                # Calculate the current step percentage
+                current_step_percentage = i / num_inference_steps
+
+                # Determine if control_latents should be applied
+                apply_control = control_start_percent <= current_step_percentage <= control_end_percent
+                current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
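
The start/end window is evaluated once per denoising step from the loop index: inside the window the scaled control latents are used, outside it a zero tensor of the same shape is substituted, so the transformer always receives a `control_latents` input of consistent shape. A self-contained sketch of the same gating, with a hypothetical step count and window:

```python
import torch

num_inference_steps = 50
control_start_percent, control_end_percent = 0.0, 0.6  # hypothetical window
control_latents = torch.randn(1, 13, 16, 60, 90)       # illustrative shape

for i in range(num_inference_steps):
    # Fraction of the schedule completed so far (0.0 on the first step).
    current_step_percentage = i / num_inference_steps

    # Inject control only while the current step falls inside the window;
    # the comparison is inclusive at both ends.
    apply_control = control_start_percent <= current_step_percentage <= control_end_percent
    current_control_latents = (
        control_latents if apply_control else torch.zeros_like(control_latents)
    )
    # current_control_latents is what reaches the transformer at this step.
```

Since the percentage is `i / num_inference_steps`, the first step always evaluates to 0.0 and the last to just under 1.0, so the defaults of 0.0 and 1.0 keep control active for the whole schedule.
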
@@ -646,7 +658,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
-                    control_latents=control_latents,
+                    control_latents=current_control_latents,
                 )[0]
                 noise_pred = noise_pred.float()
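
With these hooks in place, the three new keyword arguments flow straight into the pipeline call. A hedged sketch of an invocation, assuming a constructed `pipe` (an instance of `CogVideoX_Fun_Pipeline_Control`) and a preprocessed `control_video` tensor prepared elsewhere; the prompt and values below are placeholders:

```python
# Hypothetical call; model loading and video preprocessing are omitted.
latents = pipe(
    prompt="a rotating porcelain teapot",
    num_inference_steps=50,
    control_video=control_video,
    control_strength=0.8,        # scale the encoded control latents
    control_start_percent=0.0,   # inject control from the first step...
    control_end_percent=0.5,     # ...through the first half of sampling
)
```
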


@@ -1105,6 +1105,9 @@ class CogVideoXFunVid2VidSampler:
             "optional":{
                 "validation_video": ("IMAGE",),
                 "control_video": ("IMAGE",),
+                "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
         }
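
Each `("FLOAT", {...})` entry follows ComfyUI's standard widget spec: `default` is the initial value, and `min`/`max`/`step` bound the slider. Because the inputs live under `"optional"`, the matching `process()` parameters need Python defaults mirroring the widget defaults, which the next hunk adds. The same specs, shown as a standalone dict and commented field by field:

```python
# ComfyUI float-widget specs (same values as the entries above).
optional_inputs = {
    "control_strength": ("FLOAT", {
        "default": 1.0,  # value used when the widget is left untouched
        "min": 0.0,      # 0.0 effectively disables the control signal
        "max": 10.0,     # values above 1.0 over-emphasize the control video
        "step": 0.01,    # slider increment in the UI
    }),
    "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
    "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
```
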
@@ -1113,7 +1116,8 @@ class CogVideoXFunVid2VidSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"
 
-    def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler, validation_video=None, control_video=None):
+    def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler,
+                validation_video=None, control_video=None, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1177,7 +1181,10 @@ class CogVideoXFunVid2VidSampler:
             if control_video is not None:
                 latents = pipe(
                     **common_params,
-                    control_video=input_video
+                    control_video=input_video,
+                    control_strength=control_strength,
+                    control_start_percent=control_start_percent,
+                    control_end_percent=control_end_percent
                 )
             else:
                 latents = pipe(
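
Only the control branch forwards the new arguments; the `else` branch calls the pipeline without them, so existing non-control workflows behave exactly as before. A hypothetical direct invocation of the node, mirroring what ComfyUI's executor does with the widget values; every input here is assumed to come from upstream nodes:

```python
# All arguments are assumed to be produced by upstream nodes/widgets.
node = CogVideoXFunVid2VidSampler()
result = node.process(
    pipeline, positive, negative, video_length, base_resolution,
    seed, steps, cfg, denoise_strength, scheduler,
    control_video=control_video,     # non-None, so the control branch runs
    control_strength=1.0,
    control_start_percent=0.0,
    control_end_percent=1.0,
)
```
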