mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 21:04:41 +08:00
Update model_optimization_nodes.py
This commit is contained in:
parent
4d9c73ed46
commit
8dac94d9d9
@ -128,6 +128,7 @@ class BaseLoaderKJ:
|
||||
disable_weight_init.Linear = BaseLoaderKJ.original_linear
|
||||
BaseLoaderKJ.cublas_patched = False
|
||||
|
||||
from comfy.patcher_extension import CallbacksMP
|
||||
class PathchSageAttentionKJ(BaseLoaderKJ):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -143,7 +144,6 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, model, sage_attention):
|
||||
from comfy.patcher_extension import CallbacksMP
|
||||
model_clone = model.clone()
|
||||
def patch_attention(model):
|
||||
self._patch_modules(False, sage_attention)
|
||||
@ -168,9 +168,11 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
from nodes import CheckpointLoaderSimple
|
||||
model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
|
||||
def patch_attention(model):
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
|
||||
return model, clip, vae
|
||||
|
||||
class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
@ -224,7 +226,10 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
model.set_model_compute_dtype(dtype)
|
||||
model.force_cast_weights = False
|
||||
print(f"Setting {model_name} compute dtype to {dtype}")
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
|
||||
def patch_attention(model):
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user