check that Tora trajectory length matches

This commit is contained in:
kijai 2024-10-30 23:11:49 +02:00
parent 133b42eb4f
commit d314ddbe05
3 changed files with 9 additions and 0 deletions

View File

@ -894,6 +894,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
if tora is not None:
trajectory_length = tora["video_flow_features"].shape[1]
logger.info(f"Tora trajectory length: {trajectory_length}")
if trajectory_length != latents.shape[1]:
raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)

View File

@ -915,6 +915,7 @@ class CogVideoXFunSampler:
base_path = pipeline["base_path"]
assert "fun" in base_path.lower(), "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'"
assert "pose" not in base_path.lower(), "'Pose' models not supported in 'CogVideoXFunSampler', use the 'CogVideoXFunControlSampler'"
if not pipeline["cpu_offloading"]:
pipe.enable_model_cpu_offload(device=device)

View File

@ -613,6 +613,10 @@ class CogVideoXPipeline(VideoSysPipeline):
control_weights= None
if tora is not None:
trajectory_length = tora["video_flow_features"].shape[1]
logger.info(f"Tora trajectory length: {trajectory_length}")
if trajectory_length != latents.shape[1]:
raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)