From 137da34a53b94f7d06c4be0441ec650e72e577a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 20 Sep 2024 20:23:46 +0300 Subject: [PATCH 1/3] Add load_device selection for GGUF node --- nodes.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 0ba47bf..bd91a13 100644 --- a/nodes.py +++ b/nodes.py @@ -178,7 +178,7 @@ class DownloadAndLoadCogVideoGGUFModel: ), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}), - "compile": (["disabled","onediff","torch"], {"tooltip": "UNTESTED WITH GGUF"}), + "load_device": (["main_device", "offload_device"], {"default": "main_device"}), }, } @@ -187,7 +187,7 @@ class DownloadAndLoadCogVideoGGUFModel: FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" - def loadmodel(self, model, vae_precision, compile, fp8_fastmode): + def loadmodel(self, model, vae_precision, fp8_fastmode, load_device): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() @@ -227,7 +227,10 @@ class DownloadAndLoadCogVideoGGUFModel: transformer.to(torch.float8_e4m3fn) transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu") - transformer.to(device) + if load_device == "offload_device": + transformer.to(offload_device) + else: + transformer.to(device) # transformer # if fp8_transformer == "fastmode": From baabbf9a46a6e3722f30861809088b329258e2ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 20 Sep 2024 20:53:19 +0300 Subject: [PATCH 2/3] Update nodes.py --- nodes.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/nodes.py b/nodes.py index bd91a13..5969e24 100644 --- a/nodes.py +++ b/nodes.py @@ -232,19 +232,6 @@ class DownloadAndLoadCogVideoGGUFModel: else: transformer.to(device) - # transformer - # if fp8_transformer == "fastmode": - # if "2b" in model: - # for name, param in transformer.named_parameters(): - # if name != "pos_embedding": - # param.data = param.data.to(torch.float8_e4m3fn) - # elif "I2V" in model: - # for name, param in transformer.named_parameters(): - # if "patch_embed" not in name: - # param.data = param.data.to(torch.float8_e4m3fn) - # else: - # transformer.to(torch.float8_e4m3fn) - if fp8_fastmode: from .fp8_optimization import convert_fp8_linear convert_fp8_linear(transformer, vae_dtype) @@ -677,7 +664,7 @@ class CogVideoDecode: latents = latents.to(vae.dtype) latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = 1 / vae.config.scaling_factor * latents - + vae._clear_fake_context_parallel_cache() frames = vae.decode(latents).sample vae.disable_tiling() if not pipeline["cpu_offloading"]: From d22c0b866bc9ab52a07d3d054a230c0efd6f545b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 20 Sep 2024 21:36:42 +0300 Subject: [PATCH 3/3] Update nodes.py --- nodes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 5969e24..4c06493 100644 --- a/nodes.py +++ b/nodes.py @@ -576,16 +576,16 @@ class CogVideoSampler: offload_device = mm.unet_offload_device() pipe = pipeline["pipe"] dtype = pipeline["dtype"] - + scheduler_config = pipeline["scheduler_config"] if not pipeline["cpu_offloading"]: pipe.transformer.to(device) generator = torch.Generator(device=device).manual_seed(seed) if scheduler == "DDIM" or scheduler == "DDIM_tiled": - pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler") + pipe.scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) elif scheduler == "DPM": - pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler") + pipe.scheduler = CogVideoXDPMScheduler.from_config(scheduler_config) if negative.shape[1] < positive.shape[1]: target_length = positive.shape[1]