Update nodes.py

This commit is contained in:
kijai 2024-10-21 00:29:44 +03:00
parent e8bc2fd052
commit 47f028e0bd

View File

@ -1030,8 +1030,8 @@ class ToraEncodeTrajectory:
}, },
} }
RETURN_TYPES = ("TORAFEATURES",) RETURN_TYPES = ("TORAFEATURES", )
RETURN_NAMES = ("tora_trajectory",) RETURN_NAMES = ("tora_trajectory", )
FUNCTION = "encode" FUNCTION = "encode"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
@ -1044,7 +1044,7 @@ class ToraEncodeTrajectory:
vae = pipeline["pipe"].vae vae = pipeline["pipe"].vae
vae.enable_slicing() vae.enable_slicing()
canvas_width, canvas_height = 256, 256 canvas_width, canvas_height = width, height
coordinates = json.loads(coordinates.replace("'", '"')) coordinates = json.loads(coordinates.replace("'", '"'))
coordinates = [(coord['x'], coord['y']) for coord in coordinates] coordinates = [(coord['x'], coord['y']) for coord in coordinates]
@ -1081,7 +1081,7 @@ class ToraEncodeTrajectory:
video_flow_features = traj_extractor(video_flow.to(torch.float32)) video_flow_features = traj_extractor(video_flow.to(torch.float32))
video_flow_features = torch.stack(video_flow_features) video_flow_features = torch.stack(video_flow_features)
return (video_flow_features, ) return (video_flow_features,)