mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 12:54:22 +08:00
Update nodes.py
This commit is contained in:
parent
d3a753e8e5
commit
45c3f06d0a
29
nodes.py
29
nodes.py
@ -1107,11 +1107,28 @@ class ToraEncodeTrajectory:
|
||||
vae._clear_fake_context_parallel_cache()
|
||||
|
||||
#get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
|
||||
coordinates = json.loads(coordinates.replace("'", '"'))
|
||||
coordinates = [(coord['x'], coord['y']) for coord in coordinates]
|
||||
traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
|
||||
# coordinates = json.loads(coordinates.replace("'", '"'))
|
||||
# coordinates = [(coord['x'], coord['y']) for coord in coordinates]
|
||||
# traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
|
||||
print(f"Type of coordinates: {type(coordinates)}")
|
||||
print(f"Structure of coordinates: {coordinates}")
|
||||
print(len(coordinates))
|
||||
|
||||
|
||||
video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
|
||||
if len(coordinates) < 10:
|
||||
coords_list = []
|
||||
for coords in coordinates:
|
||||
coords = json.loads(coords.replace("'", '"'))
|
||||
coords = [(coord['x'], coord['y']) for coord in coords]
|
||||
traj_list_range_256 = scale_traj_list_to_256(coords, width, height)
|
||||
coords_list.append(traj_list_range_256)
|
||||
else:
|
||||
coords = json.loads(coordinates.replace("'", '"'))
|
||||
coords = [(coord['x'], coord['y']) for coord in coords]
|
||||
coords_list = scale_traj_list_to_256(coords, width, height)
|
||||
|
||||
|
||||
video_flow, points = process_traj(coords_list, num_frames, (height,width), device=device)
|
||||
video_flow = rearrange(video_flow, "T H W C -> T C H W")
|
||||
video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
|
||||
|
||||
@ -1128,7 +1145,9 @@ class ToraEncodeTrajectory:
|
||||
vae.to(device)
|
||||
|
||||
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
||||
vae.to(offload_device)
|
||||
|
||||
if not pipeline["cpu_offloading"]:
|
||||
vae.to(offload_device)
|
||||
|
||||
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
|
||||
video_flow_features = torch.stack(video_flow_features)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user