fix progress bars for vid2vid

This commit is contained in:
kijai 2024-11-11 01:31:13 +02:00
parent ca63f5dade
commit 43bc7fb4fc

View File

@ -536,11 +536,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
logger.info(f"latents: {latents.shape}")
logger.info(f"mask: {mask.shape}")
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
comfy_pbar = ProgressBar(num_inference_steps)
# 8. context schedule and temporal tiling
# 7. context schedule and temporal tiling
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames
t_tile_overlap = context_overlap
@ -560,7 +559,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
use_temporal_tiling = False
use_context_schedule = False
logger.info("Temporal tiling and context schedule disabled")
# 8.5. Create rotary embeds if required
# 7.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
@ -569,7 +568,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
# 9. Controlnet
#8. Controlnet
if controlnet is not None:
self.controlnet = controlnet["control_model"].to(device)
if self.transformer.dtype == torch.float8_e4m3fn:
@ -604,8 +603,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
for param in module.parameters():
param.data = param.data.to(device)
# 10. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
# 9. Denoising loop
comfy_pbar = ProgressBar(len(timesteps))
with self.progress_bar(total=len(timesteps)) as progress_bar:
old_pred_original_sample = None # for DPM-solver++
for i, t in enumerate(timesteps):
if self.interrupt: