diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 15772d0..277bea5 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -5,6 +5,7 @@ import comfy.sd import torch import folder_paths orig_attention = comfy_attention.optimized_attention +import comfy.model_management as mm class BaseLoaderKJ: original_linear = None @@ -155,7 +156,9 @@ def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weigh full_load = True else: full_load = False - + + device_to = mm.get_torch_device() + load_weights = True if load_weights: self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) for k in self.object_patches: