diff --git a/nodes.py b/nodes.py index 0ba47bf..bd91a13 100644 --- a/nodes.py +++ b/nodes.py @@ -178,7 +178,7 @@ class DownloadAndLoadCogVideoGGUFModel: ), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}), - "compile": (["disabled","onediff","torch"], {"tooltip": "UNTESTED WITH GGUF"}), + "load_device": (["main_device", "offload_device"], {"default": "main_device"}), }, } @@ -187,7 +187,7 @@ class DownloadAndLoadCogVideoGGUFModel: FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" - def loadmodel(self, model, vae_precision, compile, fp8_fastmode): + def loadmodel(self, model, vae_precision, fp8_fastmode, load_device): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() @@ -227,7 +227,10 @@ class DownloadAndLoadCogVideoGGUFModel: transformer.to(torch.float8_e4m3fn) transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu") - transformer.to(device) + if load_device == "offload_device": + transformer.to(offload_device) + else: + transformer.to(device) # transformer # if fp8_transformer == "fastmode":