This commit is contained in:
kijai 2024-10-22 00:52:58 +03:00
parent c2299a6c79
commit 35cb4fde94
2 changed files with 4 additions and 5 deletions

View File

@ -1298,7 +1298,6 @@ class CogVideoSampler:
negative_prompt_embeds=negative.to(dtype).to(device),
generator=generator,
device=device,
scheduler_name=scheduler,
context_schedule=context_options["context_schedule"] if context_options is not None else None,
context_frames=context_frames,
context_stride= context_stride,

View File

@ -382,7 +382,6 @@ class CogVideoXPipeline(VideoSysPipeline):
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
device = torch.device("cuda"),
scheduler_name: str = "DPM",
context_schedule: Optional[str] = None,
context_frames: Optional[int] = None,
context_stride: Optional[int] = None,
@ -582,7 +581,7 @@ class CogVideoXPipeline(VideoSysPipeline):
else None
)
if tora is not None and do_classifier_free_guidance:
tora["video_flow_features"] = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
# 9. Controlnet
if controlnet is not None:
@ -783,7 +782,7 @@ class CogVideoXPipeline(VideoSysPipeline):
else:
for c in context_queue:
partial_latent_model_input = latent_model_input[:, c, :, :, :]
if tora is not None:
if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]):
if do_classifier_free_guidance:
partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
else:
@ -868,7 +867,8 @@ class CogVideoXPipeline(VideoSysPipeline):
return_dict=False,
controlnet_states=controlnet_states,
controlnet_weights=control_weights,
video_flow_features=tora["video_flow_features"] if (tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0]
noise_pred = noise_pred.float()