Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2025-12-14 07:14:29 +08:00)
Update model_optimization_nodes.py

commit 8dac94d9d9
parent 4d9c73ed46
model_optimization_nodes.py

@@ -128,6 +128,7 @@ class BaseLoaderKJ:
                 disable_weight_init.Linear = BaseLoaderKJ.original_linear
                 BaseLoaderKJ.cublas_patched = False
 
+from comfy.patcher_extension import CallbacksMP
 class PathchSageAttentionKJ(BaseLoaderKJ):
     @classmethod
     def INPUT_TYPES(s):

@@ -143,7 +144,6 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     CATEGORY = "KJNodes/experimental"
 
     def patch(self, model, sage_attention):
-        from comfy.patcher_extension import CallbacksMP
         model_clone = model.clone()
         def patch_attention(model):
             self._patch_modules(False, sage_attention)

@@ -168,9 +168,11 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
     CATEGORY = "KJNodes/experimental"
 
     def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
-        self._patch_modules(patch_cublaslinear, sage_attention)
         from nodes import CheckpointLoaderSimple
         model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
+        def patch_attention(model):
+            self._patch_modules(patch_cublaslinear, sage_attention)
+        model.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention)
         return model, clip, vae
 
 class DiffusionModelLoaderKJ(BaseLoaderKJ):

@@ -224,7 +226,10 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             model.set_model_compute_dtype(dtype)
             model.force_cast_weights = False
             print(f"Setting {model_name} compute dtype to {dtype}")
-        self._patch_modules(patch_cublaslinear, sage_attention)
+        def patch_attention(model):
+            self._patch_modules(patch_cublaslinear, sage_attention)
+        model.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention)
 
         return (model,)
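
Summary of the change: instead of calling _patch_modules at load time, the checkpoint and diffusion-model loaders now register the patch as an ON_PRE_RUN callback, so the cuBLAS Linear / sage attention swap is applied right before the model actually runs, and the CallbacksMP import is moved to module level. Below is a minimal sketch of that pattern, not the repo's code: it assumes a ComfyUI ModelPatcher-style object exposing clone() and add_callback() (both appear in the diff above), and apply_patches is a hypothetical stand-in for BaseLoaderKJ._patch_modules.

    # Minimal sketch: defer module patching to just before the model runs by
    # registering it as an ON_PRE_RUN callback, mirroring the diff above.
    # apply_patches() is a hypothetical stand-in for BaseLoaderKJ._patch_modules().
    from comfy.patcher_extension import CallbacksMP

    def apply_patches(patch_cublaslinear, sage_attention):
        # Placeholder for the real module swapping (cuBLAS Linear / sage attention).
        print(f"patching: cublas={patch_cublaslinear}, sage_attention={sage_attention}")

    def load_with_deferred_patching(model, patch_cublaslinear, sage_attention):
        model_clone = model.clone()  # work on a clone so the cached model stays untouched

        def patch_attention(_model):
            # Runs once, right before execution, instead of at load time.
            apply_patches(patch_cublaslinear, sage_attention)

        model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention)
        return model_clone

Deferring the patch this way presumably keeps the loader nodes cheap and applies the swap to whatever model object is actually about to execute, which matches how the sage attention node already worked on a clone.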