diff --git a/nodes.py b/nodes.py
index 8fdc579..2d90617 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1104,16 +1104,7 @@ class ToraEncodeTrajectory:
         vae = pipeline["pipe"].vae
         vae.enable_slicing()
-        vae._clear_fake_context_parallel_cache()
-
-        #get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
-        # coordinates = json.loads(coordinates.replace("'", '"'))
-        # coordinates = [(coord['x'], coord['y']) for coord in coordinates]
-        # traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
-        print(f"Type of coordinates: {type(coordinates)}")
-        print(f"Structure of coordinates: {coordinates}")
-        print(len(coordinates))
-
+        vae._clear_fake_context_parallel_cache()


         if len(coordinates) < 10:
             coords_list = []
@@ -1301,7 +1292,7 @@ class CogVideoSampler:
             padding = torch.zeros((negative.shape[0], target_length - negative.shape[1], negative.shape[2]), device=negative.device)
             negative = torch.cat((negative, padding), dim=1)

-        autocastcondition = not pipeline["onediff"]
+        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
         with autocast_context:
             latents = pipeline["pipe"](
@@ -1471,7 +1462,7 @@ class CogVideoXFunSampler:

         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

-        autocastcondition = not pipeline["onediff"]
+        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
         with autocast_context:
             video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -1566,7 +1557,7 @@ class CogVideoXFunVid2VidSampler:

         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

-        autocastcondition = not pipeline["onediff"]
+        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
         with autocast_context:
             video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -1813,7 +1804,7 @@ class CogVideoXFunControlSampler:

         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

-        autocastcondition = not pipeline["onediff"]
+        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
         with autocast_context:
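Note on the recurring change: the same one-line edit appears in all four sampler classes, widening the autocast gate so that it is bypassed only when onediff is active and the pipeline dtype is fp32; any fp16/bf16 run now executes under torch.autocast even with onediff enabled. Below is a minimal, self-contained sketch of the new gating logic, with hypothetical stand-in values for pipeline, dtype, and the device type (in nodes.py the device type comes from mm.get_autocast_device(device)):

```python
import torch
from contextlib import nullcontext

# Hypothetical stand-ins for what nodes.py has in scope at this point.
pipeline = {"onediff": True}
dtype = torch.float16
device_type = "cuda" if torch.cuda.is_available() else "cpu"

# The condition introduced by this diff. By De Morgan, autocast is
# skipped only when pipeline["onediff"] and dtype == torch.float32
# both hold; every other combination enters torch.autocast.
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
autocast_context = torch.autocast(device_type) if autocastcondition else nullcontext()

with autocast_context:
    pass  # the sampling call runs here in nodes.py
```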