Tora for fun context scheduling

This commit is contained in:
kijai 2024-10-29 19:51:36 +02:00
parent 3f97f07275
commit dccc8bdcb7

View File

@ -988,7 +988,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
pbar.update(1)
# ==========================================
elif use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@ -1001,6 +1000,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)
current_step_percentage = i / num_inference_steps
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
if self.transformer.config.use_rotary_positional_embeddings
@ -1011,6 +1012,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
partial_latent_model_input = latent_model_input[:, c, :, :, :]
partial_inpaint_latents = inpaint_latents[:, c, :, :, :]
partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :]
if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]):
if do_classifier_free_guidance:
partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
else:
partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :]
else:
partial_video_flow_features = None
# predict noise model_output
noise_pred[:, c, :, :, :] += self.transformer(
@ -1020,6 +1028,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
image_rotary_emb=image_rotary_emb,
return_dict=False,
inpaint_latents=partial_inpaint_latents,
video_flow_features=partial_video_flow_features
)[0]
counter[:, c, :, :, :] += 1