From dec1f8ba0fdf8d6494990323ebf28b14561475d5 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 28 Oct 2024 18:24:38 +0200 Subject: [PATCH] Update nodes.py --- nodes.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 4648470..5786678 100644 --- a/nodes.py +++ b/nodes.py @@ -1104,6 +1104,9 @@ class ToraEncodeTrajectory: "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, + "optional": { + "enable_tiling": ("BOOL", {"default": False}), + } } RETURN_TYPES = ("TORAFEATURES", "IMAGE", ) @@ -1111,7 +1114,7 @@ class ToraEncodeTrajectory: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model): + def encode(self, pipeline, width, height, num_frames, coordinates, strength, start_percent, end_percent, tora_model, enable_tiling=False): check_diffusers_version() device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -1122,7 +1125,11 @@ class ToraEncodeTrajectory: try: vae._clear_fake_context_parallel_cache() except: - pass + pass + + if enable_tiling: + from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling + enable_vae_encode_tiling(vae) if len(coordinates) < 10: coords_list = []