diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 54e01ac..55e1665 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -128,6 +128,7 @@ class BaseLoaderKJ:
         disable_weight_init.Linear = BaseLoaderKJ.original_linear
         BaseLoaderKJ.cublas_patched = False
 
+from comfy.patcher_extension import CallbacksMP
 class PathchSageAttentionKJ(BaseLoaderKJ):
     @classmethod
     def INPUT_TYPES(s):
@@ -143,7 +144,6 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     CATEGORY = "KJNodes/experimental"
 
     def patch(self, model, sage_attention):
-        from comfy.patcher_extension import CallbacksMP
         model_clone = model.clone()
         def patch_attention(model):
             self._patch_modules(False, sage_attention)
@@ -168,9 +168,11 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
     CATEGORY = "KJNodes/experimental"
 
     def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
-        self._patch_modules(patch_cublaslinear, sage_attention)
         from nodes import CheckpointLoaderSimple
         model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
+        def patch_attention(model):
+            self._patch_modules(patch_cublaslinear, sage_attention)
+        model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
         return model, clip, vae
 
 class DiffusionModelLoaderKJ(BaseLoaderKJ):
@@ -224,7 +226,10 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             model.set_model_compute_dtype(dtype)
             model.force_cast_weights = False
             print(f"Setting {model_name} compute dtype to {dtype}")
-        self._patch_modules(patch_cublaslinear, sage_attention)
+
+        def patch_attention(model):
+            self._patch_modules(patch_cublaslinear, sage_attention)
+        model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
         return (model,)
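
For reference, a minimal sketch of the pattern this diff applies to `CheckpointLoaderKJ` and `DiffusionModelLoaderKJ` (it already existed in `PathchSageAttentionKJ.patch`): instead of monkey-patching the global module classes at load time, the loader registers a closure that ComfyUI fires via `CallbacksMP.ON_PRE_RUN` just before the model executes. `defer_patching` and its `patch_modules` parameter below are hypothetical stand-ins for illustration, not names from this repo; `CallbacksMP.ON_PRE_RUN` and `model.add_callback` are taken from the diff itself.

```python
# Hedged sketch only; defer_patching and patch_modules are hypothetical
# stand-ins for the loader nodes and BaseLoaderKJ._patch_modules.
from comfy.patcher_extension import CallbacksMP

def defer_patching(model, patch_cublaslinear, sage_attention, patch_modules):
    # Rather than patching shared module classes immediately at load time,
    # register a callback that runs right before this model is executed.
    def patch_attention(model):
        patch_modules(patch_cublaslinear, sage_attention)
    model.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention)
    return model
```

The apparent motivation is scoping: since `_patch_modules` swaps in patched `Linear`/attention classes globally, deferring it to pre-run means each model gets the settings of the node that loaded it at the moment it runs, rather than whichever loader node happened to execute last.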