mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 21:04:23 +08:00
Update nodes.py
This commit is contained in:
parent
d3a753e8e5
commit
45c3f06d0a
27
nodes.py
27
nodes.py
@ -1107,11 +1107,28 @@ class ToraEncodeTrajectory:
|
|||||||
vae._clear_fake_context_parallel_cache()
|
vae._clear_fake_context_parallel_cache()
|
||||||
|
|
||||||
#get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
|
#get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
|
||||||
coordinates = json.loads(coordinates.replace("'", '"'))
|
# coordinates = json.loads(coordinates.replace("'", '"'))
|
||||||
coordinates = [(coord['x'], coord['y']) for coord in coordinates]
|
# coordinates = [(coord['x'], coord['y']) for coord in coordinates]
|
||||||
traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
|
# traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
|
||||||
|
print(f"Type of coordinates: {type(coordinates)}")
|
||||||
|
print(f"Structure of coordinates: {coordinates}")
|
||||||
|
print(len(coordinates))
|
||||||
|
|
||||||
video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
|
|
||||||
|
if len(coordinates) < 10:
|
||||||
|
coords_list = []
|
||||||
|
for coords in coordinates:
|
||||||
|
coords = json.loads(coords.replace("'", '"'))
|
||||||
|
coords = [(coord['x'], coord['y']) for coord in coords]
|
||||||
|
traj_list_range_256 = scale_traj_list_to_256(coords, width, height)
|
||||||
|
coords_list.append(traj_list_range_256)
|
||||||
|
else:
|
||||||
|
coords = json.loads(coordinates.replace("'", '"'))
|
||||||
|
coords = [(coord['x'], coord['y']) for coord in coords]
|
||||||
|
coords_list = scale_traj_list_to_256(coords, width, height)
|
||||||
|
|
||||||
|
|
||||||
|
video_flow, points = process_traj(coords_list, num_frames, (height,width), device=device)
|
||||||
video_flow = rearrange(video_flow, "T H W C -> T C H W")
|
video_flow = rearrange(video_flow, "T H W C -> T C H W")
|
||||||
video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
|
video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
|
||||||
|
|
||||||
@ -1128,6 +1145,8 @@ class ToraEncodeTrajectory:
|
|||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
|
||||||
|
|
||||||
|
if not pipeline["cpu_offloading"]:
|
||||||
vae.to(offload_device)
|
vae.to(offload_device)
|
||||||
|
|
||||||
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
|
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user