From 45c3f06d0aac5a5a2d41c847970accddf404e6e5 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:52:27 +0300 Subject: [PATCH] Update nodes.py --- nodes.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/nodes.py b/nodes.py index 644e6d6..8fdc579 100644 --- a/nodes.py +++ b/nodes.py @@ -1107,11 +1107,28 @@ class ToraEncodeTrajectory: vae._clear_fake_context_parallel_cache() #get coordinates from string and convert to compatible range/format (has to be 256x256 for the model) - coordinates = json.loads(coordinates.replace("'", '"')) - coordinates = [(coord['x'], coord['y']) for coord in coordinates] - traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height) + # coordinates = json.loads(coordinates.replace("'", '"')) + # coordinates = [(coord['x'], coord['y']) for coord in coordinates] + # traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height) + print(f"Type of coordinates: {type(coordinates)}") + print(f"Structure of coordinates: {coordinates}") + print(len(coordinates)) + - video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device) + if len(coordinates) < 10: + coords_list = [] + for coords in coordinates: + coords = json.loads(coords.replace("'", '"')) + coords = [(coord['x'], coord['y']) for coord in coords] + traj_list_range_256 = scale_traj_list_to_256(coords, width, height) + coords_list.append(traj_list_range_256) + else: + coords = json.loads(coordinates.replace("'", '"')) + coords = [(coord['x'], coord['y']) for coord in coords] + coords_list = scale_traj_list_to_256(coords, width, height) + + + video_flow, points = process_traj(coords_list, num_frames, (height,width), device=device) video_flow = rearrange(video_flow, "T H W C -> T C H W") video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W] @@ -1128,7 +1145,9 @@ class ToraEncodeTrajectory: vae.to(device) video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor - vae.to(offload_device) + + if not pipeline["cpu_offloading"]: + vae.to(offload_device) video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32)) video_flow_features = torch.stack(video_flow_features)