allow 2nd GPU as FasterCache device

experimental
kijai 2024-11-07 19:42:59 +02:00
parent f7f999adbc
commit f52c45800c

@@ -755,7 +755,7 @@ class CogVideoXFasterCache:
                 "start_step": ("INT", {"default": 15, "min": 0, "max": 1024, "step": 1}),
                 "hf_step": ("INT", {"default": 30, "min": 0, "max": 1024, "step": 1}),
                 "lf_step": ("INT", {"default": 40, "min": 0, "max": 1024, "step": 1}),
-                "cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
+                "cache_device": (["main_device", "offload_device", "cuda:1"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
             },
         }
@@ -767,11 +767,13 @@ class CogVideoXFasterCache:
     def args(self, start_step, hf_step, lf_step, cache_device):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
+        if cache_device == "cuda:1":
+            device = torch.device("cuda:1")
         fastercache = {
             "start_step" : start_step,
             "hf_step" : hf_step,
             "lf_step" : lf_step,
-            "cache_device" : device if cache_device == "main_device" else offload_device
+            "cache_device" : device if cache_device != "offload_device" else offload_device
         }
         return (fastercache,)