Update pipeline_cogvideox.py

This commit is contained in:
kijai 2024-10-06 01:50:31 +03:00
parent 00d38f9a22
commit 668ba792db

View File

@ -542,6 +542,8 @@ class CogVideoXPipeline(VideoSysPipeline):
use_temporal_tiling = True
print("Temporal tiling enabled")
elif context_schedule is not None:
if image_cond_latents is not None:
raise NotImplementedError("Context schedule not currently supported with image conditioning")
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
@ -684,7 +686,6 @@ class CogVideoXPipeline(VideoSysPipeline):
for c in context_queue:
partial_latent_model_input = latent_model_input[:, c, :, :, :]
# predict noise model_output
noise_pred[:, c, :, :, :] += self.transformer(
hidden_states=partial_latent_model_input,
@ -729,6 +730,7 @@ class CogVideoXPipeline(VideoSysPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@ -740,7 +742,6 @@ class CogVideoXPipeline(VideoSysPipeline):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,