diff --git a/nodes.py b/nodes.py
index 6314475..eefa861 100644
--- a/nodes.py
+++ b/nodes.py
@@ -71,6 +71,11 @@ class DownloadAndLoadCogVideoModel:
         pipe = CogVideoXPipeline(vae, transformer, scheduler)
 
         if torch_compile:
+            torch._inductor.config.conv_1x1_as_mm = True
+            torch._inductor.config.coordinate_descent_tuning = True
+            torch._inductor.config.epilogue_fusion = False
+            torch._inductor.config.coordinate_descent_check_all_directions = True
+            torch._inductor.config.num_stages = 1
             pipe.transformer.to(memory_format=torch.channels_last)
             pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
 
@@ -281,6 +286,12 @@ class CogVideoDecode:
                 "pipeline": ("COGVIDEOPIPE",),
                 "samples": ("LATENT", ),
                 "enable_vae_tiling": ("BOOLEAN", {"default": False}),
+            },
+            "optional": {
+                "tile_sample_min_height": ("INT", {"default": 96, "min": 16, "max": 2048, "step": 8}),
+                "tile_sample_min_width": ("INT", {"default": 96, "min": 16, "max": 2048, "step": 8}),
+                "tile_overlap_factor_height": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "tile_overlap_factor_width": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
             }
         }
 
@@ -289,7 +300,7 @@ class CogVideoDecode:
     FUNCTION = "decode"
     CATEGORY = "CogVideoWrapper"
 
-    def decode(self, pipeline, samples, enable_vae_tiling):
+    def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         latents = samples["samples"]
@@ -297,10 +308,10 @@ class CogVideoDecode:
         vae.to(device)
         if enable_vae_tiling:
             vae.enable_tiling(
-                tile_sample_min_height=96,
-                tile_sample_min_width=96,
-                tile_overlap_factor_height=1 / 12,
-                tile_overlap_factor_width=1 / 12,
+                tile_sample_min_height=tile_sample_min_height,
+                tile_sample_min_width=tile_sample_min_width,
+                tile_overlap_factor_height=tile_overlap_factor_height,
+                tile_overlap_factor_width=tile_overlap_factor_width,
             )
         latents = latents.to(vae.dtype)
         latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
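
For context, a minimal standalone sketch of the compile recipe the first hunk enables. The module, tensor shapes, and CUDA device here are stand-ins (the real code compiles pipe.transformer); the four torch._inductor.config flags and the torch.compile call are standard PyTorch inductor options, while num_stages is left out since it is specific to this patch:

    import torch
    import torch.nn as nn

    # Inductor tuning, as set in the patch: lower 1x1 convs to matmuls and
    # spend more autotuning effort on GEMM configs instead of epilogue fusion.
    torch._inductor.config.conv_1x1_as_mm = True
    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.epilogue_fusion = False
    torch._inductor.config.coordinate_descent_check_all_directions = True

    model = nn.Conv2d(16, 16, kernel_size=1).cuda()      # hypothetical stand-in for pipe.transformer
    model = model.to(memory_format=torch.channels_last)  # NHWC layout, pairs with conv_1x1_as_mm
    model = torch.compile(model, mode="max-autotune", fullgraph=True)

    x = torch.randn(1, 16, 64, 64, device="cuda").to(memory_format=torch.channels_last)
    out = model(x)  # first call triggers (slow) compilation; later calls reuse the kernels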
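
Likewise, a sketch of what the new optional inputs feed into, assuming the diffusers AutoencoderKLCogVideoX API (whose enable_tiling accepts these four keywords) and the THUDM/CogVideoX-2b checkpoint as an example. The 0.083 default approximates the 1/12 overlap that was previously hardcoded; smaller tile sizes lower peak VRAM during decode, while larger overlap factors reduce visible seams between tiles:

    import torch
    from diffusers import AutoencoderKLCogVideoX

    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.bfloat16
    )

    # Mirrors the node's widget defaults; all four values are now user-tunable.
    vae.enable_tiling(
        tile_sample_min_height=96,         # tile height in pixel space
        tile_sample_min_width=96,          # tile width in pixel space
        tile_overlap_factor_height=0.083,  # ~1/12, the previously hardcoded overlap
        tile_overlap_factor_width=0.083,
    )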