some tweaks to test I2V with context windows, add context window preview

This commit is contained in:
kijai 2025-01-28 22:40:58 +02:00
parent fed499e971
commit dbc63f622d

View File

@ -658,10 +658,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)
if image_cond_latents is not None:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
@ -724,7 +723,14 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
noise_pred = noise_pred.float()
else:
for c in context_queue:
print("c:", c)
partial_latent_model_input = latent_model_input[:, c, :, :, :]
if image_cond_latents is not None:
partial_latent_image_input = latent_image_input[:, :len(c), :, :, :]
partial_latent_model_input = torch.cat([partial_latent_model_input,partial_latent_image_input], dim=2)
print(partial_latent_model_input.shape)
if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]):
if do_classifier_free_guidance:
partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
@ -768,7 +774,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
if callback is not None:
alpha_prod_t = self.scheduler.alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
callback_tensor = (alpha_prod_t**0.5) * latent_model_input[0][:, :16, :, :] - (beta_prod_t**0.5) * noise_pred.detach()[0]
callback(i, callback_tensor * 5, None, num_inference_steps)
else:
comfy_pbar.update(1)
# region sampling
else: