Update nodes.py

This commit is contained in:
kijai 2024-08-28 00:52:57 +03:00
parent 84d950d852
commit 4a75855740

View File

@ -71,6 +71,11 @@ class DownloadAndLoadCogVideoModel:
pipe = CogVideoXPipeline(vae, transformer, scheduler)
if torch_compile:
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.num_stages = 1
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
@ -281,6 +286,12 @@ class CogVideoDecode:
"pipeline": ("COGVIDEOPIPE",),
"samples": ("LATENT", ),
"enable_vae_tiling": ("BOOLEAN", {"default": False}),
},
"optional": {
"tile_sample_min_height": ("INT", {"default": 96, "min": 16, "max": 2048, "step": 8}),
"tile_sample_min_width": ("INT", {"default": 96, "min": 16, "max": 2048, "step": 8}),
"tile_overlap_factor_height": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
"tile_overlap_factor_width": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
}
}
@ -289,7 +300,7 @@ class CogVideoDecode:
FUNCTION = "decode"
CATEGORY = "CogVideoWrapper"
def decode(self, pipeline, samples, enable_vae_tiling):
def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
latents = samples["samples"]
@ -297,10 +308,10 @@ class CogVideoDecode:
vae.to(device)
if enable_vae_tiling:
vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
tile_sample_min_height=tile_sample_min_height,
tile_sample_min_width=tile_sample_min_width,
tile_overlap_factor_height=tile_overlap_factor_height,
tile_overlap_factor_width=tile_overlap_factor_width,
)
latents = latents.to(vae.dtype)
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]