diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index fefd0bc..ae1aa6d 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -536,11 +536,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): logger.info(f"latents: {latents.shape}") logger.info(f"mask: {mask.shape}") - # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - comfy_pbar = ProgressBar(num_inference_steps) - # 8. context schedule and temporal tiling + # 7. context schedule and temporal tiling if context_schedule is not None and context_schedule == "temporal_tiling": t_tile_length = context_frames t_tile_overlap = context_overlap @@ -560,7 +559,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): use_temporal_tiling = False use_context_schedule = False logger.info("Temporal tiling and context schedule disabled") - # 8.5. Create rotary embeds if required + # 7.5. Create rotary embeds if required image_rotary_emb = ( self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) if self.transformer.config.use_rotary_positional_embeddings @@ -569,7 +568,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): if tora is not None and do_classifier_free_guidance: video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous() - # 9. Controlnet + #8. Controlnet if controlnet is not None: self.controlnet = controlnet["control_model"].to(device) if self.transformer.dtype == torch.float8_e4m3fn: @@ -604,8 +603,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin): for param in module.parameters(): param.data = param.data.to(device) - # 10. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + # 9. Denoising loop + comfy_pbar = ProgressBar(len(timesteps)) + with self.progress_bar(total=len(timesteps)) as progress_bar: old_pred_original_sample = None # for DPM-solver++ for i, t in enumerate(timesteps): if self.interrupt: