Update model_optimization_nodes.py

This commit is contained in:
kijai 2025-04-23 20:08:10 +03:00
parent 4d9c73ed46
commit 8dac94d9d9

View File

@ -128,6 +128,7 @@ class BaseLoaderKJ:
disable_weight_init.Linear = BaseLoaderKJ.original_linear
BaseLoaderKJ.cublas_patched = False
from comfy.patcher_extension import CallbacksMP
class PathchSageAttentionKJ(BaseLoaderKJ):
@classmethod
def INPUT_TYPES(s):
@ -143,7 +144,6 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
CATEGORY = "KJNodes/experimental"
def patch(self, model, sage_attention):
from comfy.patcher_extension import CallbacksMP
model_clone = model.clone()
def patch_attention(model):
self._patch_modules(False, sage_attention)
@ -168,9 +168,11 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
CATEGORY = "KJNodes/experimental"
def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
self._patch_modules(patch_cublaslinear, sage_attention)
from nodes import CheckpointLoaderSimple
model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
def patch_attention(model):
self._patch_modules(patch_cublaslinear, sage_attention)
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
return model, clip, vae
class DiffusionModelLoaderKJ(BaseLoaderKJ):
@ -224,7 +226,10 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
model.set_model_compute_dtype(dtype)
model.force_cast_weights = False
print(f"Setting {model_name} compute dtype to {dtype}")
self._patch_modules(patch_cublaslinear, sage_attention)
def patch_attention(model):
self._patch_modules(patch_cublaslinear, sage_attention)
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
return (model,)