Update nodes.py

This commit is contained in:
kijai 2024-10-22 13:52:27 +03:00
parent d3a753e8e5
commit 45c3f06d0a

View File

@@ -1107,11 +1107,28 @@ class ToraEncodeTrajectory:
vae._clear_fake_context_parallel_cache() vae._clear_fake_context_parallel_cache()
#get coordinates from string and convert to compatible range/format (has to be 256x256 for the model) #get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
coordinates = json.loads(coordinates.replace("'", '"')) # coordinates = json.loads(coordinates.replace("'", '"'))
coordinates = [(coord['x'], coord['y']) for coord in coordinates] # coordinates = [(coord['x'], coord['y']) for coord in coordinates]
traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height) # traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
print(f"Type of coordinates: {type(coordinates)}")
print(f"Structure of coordinates: {coordinates}")
print(len(coordinates))
video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
if len(coordinates) < 10:
coords_list = []
for coords in coordinates:
coords = json.loads(coords.replace("'", '"'))
coords = [(coord['x'], coord['y']) for coord in coords]
traj_list_range_256 = scale_traj_list_to_256(coords, width, height)
coords_list.append(traj_list_range_256)
else:
coords = json.loads(coordinates.replace("'", '"'))
coords = [(coord['x'], coord['y']) for coord in coords]
coords_list = scale_traj_list_to_256(coords, width, height)
video_flow, points = process_traj(coords_list, num_frames, (height,width), device=device)
video_flow = rearrange(video_flow, "T H W C -> T C H W") video_flow = rearrange(video_flow, "T H W C -> T C H W")
video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W] video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
@@ -1128,7 +1145,9 @@ class ToraEncodeTrajectory:
vae.to(device) vae.to(device)
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
vae.to(offload_device)
if not pipeline["cpu_offloading"]:
vae.to(offload_device)
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32)) video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
video_flow_features = torch.stack(video_flow_features) video_flow_features = torch.stack(video_flow_features)