Add load_device selection for GGUF node

This commit is contained in:
Jukka Seppänen 2024-09-20 20:23:46 +03:00
parent 1e26971695
commit 137da34a53

View File

@ -178,7 +178,7 @@ class DownloadAndLoadCogVideoGGUFModel:
), ),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
"fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}), "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "UNTESTED WITH GGUF"}), "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
}, },
} }
@ -187,7 +187,7 @@ class DownloadAndLoadCogVideoGGUFModel:
FUNCTION = "loadmodel" FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, vae_precision, compile, fp8_fastmode): def loadmodel(self, model, vae_precision, fp8_fastmode, load_device):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
mm.soft_empty_cache() mm.soft_empty_cache()
@ -227,7 +227,10 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer.to(torch.float8_e4m3fn) transformer.to(torch.float8_e4m3fn)
transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu") transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")
transformer.to(device) if load_device == "offload_device":
transformer.to(offload_device)
else:
transformer.to(device)
# transformer # transformer
# if fp8_transformer == "fastmode": # if fp8_transformer == "fastmode":