diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 3fdf043..9ce8529 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -658,10 +658,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) counter = torch.zeros_like(latent_model_input) noise_pred = torch.zeros_like(latent_model_input) - + if image_cond_latents is not None: latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents - latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -724,7 +723,14 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): noise_pred = noise_pred.float() else: for c in context_queue: + print("c:", c) + partial_latent_model_input = latent_model_input[:, c, :, :, :] + if image_cond_latents is not None: + partial_latent_image_input = latent_image_input[:, :len(c), :, :, :] + partial_latent_model_input = torch.cat([partial_latent_model_input,partial_latent_image_input], dim=2) + + print(partial_latent_model_input.shape) if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]): if do_classifier_free_guidance: partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous() @@ -768,7 +774,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - comfy_pbar.update(1) + if callback is not None: + alpha_prod_t = self.scheduler.alphas_cumprod[t] + beta_prod_t = 1 - alpha_prod_t + callback_tensor = (alpha_prod_t**0.5) * latent_model_input[0][:, :16, :, :] - (beta_prod_t**0.5) * noise_pred.detach()[0] + callback(i, callback_tensor * 5, None, num_inference_steps) + else: + comfy_pbar.update(1) # region sampling else: