Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git, synced 2025-12-09 21:04:41 +08:00
Add DiffusionModelLoaderKJ

parent 5920419f44 · commit 1bc5c99f5a
@@ -154,6 +154,7 @@ NODE_CONFIG = {
     "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
     "CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
     "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
+    "DiffusionModelLoaderKJ": {"class": DiffusionModelLoaderKJ, "name": "Diffusion Model Loader KJ"},
     "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
     "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
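The new node only needs this one NODE_CONFIG entry to show up in the frontend. Below is a minimal sketch of how such a config dict is typically expanded into ComfyUI's standard NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dictionaries; the helper and the placeholder class are illustrative assumptions, not code taken from this repository:

# Illustrative sketch only: the real node class and expansion code live elsewhere in the repo.
class DiffusionModelLoaderKJ:  # placeholder standing in for the real node class
    pass

NODE_CONFIG = {
    "DiffusionModelLoaderKJ": {"class": DiffusionModelLoaderKJ, "name": "Diffusion Model Loader KJ"},
}

def generate_node_mappings(node_config):
    # Build the two dicts ComfyUI looks for when it imports a custom node package.
    node_class_mappings = {}
    node_display_name_mappings = {}
    for node_name, node_info in node_config.items():
        node_class_mappings[node_name] = node_info["class"]
        node_display_name_mappings[node_name] = node_info.get("name", node_name)
    return node_class_mappings, node_display_name_mappings

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)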
nodes/nodes.py (113 lines changed)
@@ -2145,30 +2145,21 @@ class ModelSaveKJ:
 from comfy.ldm.modules import attention as comfy_attention
 orig_attention = comfy_attention.optimized_attention
 
-class CheckpointLoaderKJ:
+class BaseLoaderKJ:
     original_linear = None
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
-            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
-            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
-            },
-        }
-    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
-    FUNCTION = "patch"
-    OUTPUT_NODE = True
-    DESCRIPTION = "Exemplar node for patching torch.nn.Linear with CublasLinear: https://github.com/aredden/torch-cublas-hgemm"
-
-    CATEGORY = "KJNodes/experimental"
-
-    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
+    cublas_patched = False
+
+    def _patch_modules(self, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
-        from nodes import CheckpointLoaderSimple
+        import torch
 
+        global orig_attention
+        if 'orig_attention' not in globals():
+            orig_attention = comfy_attention.optimized_attention
+
         if sage_attention:
             from sageattention import sageattn
 
             @torch.compiler.disable()
             def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
                 if skip_reshape:
@@ -2177,7 +2168,7 @@ class CheckpointLoaderKJ:
                     b, _, dim_head = q.shape
                     dim_head //= heads
                 if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
-                    return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+                    return orig_attention(q, k, v, heads, mask, attn_precision, skip_reshape)
                 if not skip_reshape:
                     q, k, v = map(
                         lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
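sageattn only handles head dimensions of 64, 96, or 128, so the wrapper above falls back to the saved orig_attention for anything else; when it does run, q, k and v are reshaped from (batch, seq, heads * dim_head) to (batch, heads, seq, dim_head) first. A standalone sketch of that reshape, assuming only that torch is installed (not repo code):

import torch

# Illustrative tensor matching the non-skip_reshape path in attention_sage above.
b, seq, heads, dim_head = 2, 16, 8, 64
q = torch.randn(b, seq, heads * dim_head)

# Same transform as the lambda in the diff: split out the heads, then move them to dim 1.
q_sage = q.view(b, -1, heads, dim_head).transpose(1, 2)
print(q_sage.shape)  # torch.Size([2, 8, 16, 64]) -> (batch, heads, seq, dim_head)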
@@ -2187,32 +2178,23 @@ class CheckpointLoaderKJ:
                     sageattn(q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0, smooth_k=True)
                     .transpose(1, 2)
                     .reshape(b, -1, heads * dim_head)
                 )
 
-        # class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
-        #     def reset_parameters(self):
-        #         return None
-        #     def forward_comfy_cast_weights(self, input):
-        #         weight, bias = cast_bias_weight(self, input)
-        #         return torch.nn.functional.linear(input, weight, bias)
+            comfy_attention.optimized_attention = attention_sage
+        else:
+            comfy_attention.optimized_attention = orig_attention
 
-        #     def forward(self, *args, **kwargs):
-        #         if self.comfy_cast_weights:
-        #             return self.forward_comfy_cast_weights(*args, **kwargs)
-        #         else:
-        #             return super().forward(*args, **kwargs)
-        cublas_patched = False
         if patch_cublaslinear:
-            if not cublas_patched:
-                original_linear = disable_weight_init.Linear
+            if not BaseLoaderKJ.cublas_patched:
+                BaseLoaderKJ.original_linear = disable_weight_init.Linear
                 try:
                     from cublas_ops import CublasLinear
                 except ImportError:
                     raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
 
                 class PatchedLinear(CublasLinear, CastWeightBiasOp):
                     def reset_parameters(self):
-                        return None
+                        pass
 
                     def forward_comfy_cast_weights(self, input):
                         weight, bias = cast_bias_weight(self, input)
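The commented-out OriginalLinear draft is dropped, and the attention swap now lives inside _patch_modules: the original comfy_attention.optimized_attention is captured once, the sageattn-backed wrapper is assigned while sage_attention is enabled, and the saved reference is put back otherwise. A generic sketch of that save/swap/restore pattern with a stand-in module; the names below are illustrative and not the comfy API:

import types

# Stand-in for comfy.ldm.modules.attention; only the attribute name mirrors the real module.
fake_attention = types.SimpleNamespace(optimized_attention=lambda *a, **kw: "original")

# Saved once, before any patching, just like orig_attention in the diff.
orig_attention = fake_attention.optimized_attention

def attention_patched(*args, **kwargs):
    return "patched"

def set_sage_attention(enabled: bool):
    # Swap the module-level attribute; code that looks the function up through
    # the module on every call sees whichever version is currently installed.
    fake_attention.optimized_attention = attention_patched if enabled else orig_attention

set_sage_attention(True)
print(fake_attention.optimized_attention())   # patched
set_sage_attention(False)
print(fake_attention.optimized_attention())   # original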
@@ -2223,20 +2205,59 @@ class CheckpointLoaderKJ:
                             return self.forward_comfy_cast_weights(*args, **kwargs)
                         else:
                             return super().forward(*args, **kwargs)
 
                 disable_weight_init.Linear = PatchedLinear
-                cublas_patched = True
-        else:
-            disable_weight_init.Linear = original_linear
-            cublas_patched = False
-
-        if sage_attention:
-            comfy_attention.optimized_attention = attention_sage
+                BaseLoaderKJ.cublas_patched = True
         else:
-            comfy_attention.optimized_attention = orig_attention
+            if BaseLoaderKJ.cublas_patched:
+                disable_weight_init.Linear = BaseLoaderKJ.original_linear
+                BaseLoaderKJ.cublas_patched = False
+
+class CheckpointLoaderKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "patch"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
+        self._patch_modules(patch_cublaslinear, sage_attention)
+        from nodes import CheckpointLoaderSimple
         model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
 
         return model, clip, vae
+
+class DiffusionModelLoaderKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch_and_load"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Node for patching torch.nn.Linear with CublasLinear."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
+        self._patch_modules(patch_cublaslinear, sage_attention)
+        from nodes import UNETLoader
+        model, = UNETLoader.load_unet(self, ckpt_name, weight_dtype)
+        return (model,)
 
 import comfy.model_patcher
 import comfy.utils
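With the patching factored into BaseLoaderKJ._patch_modules, both nodes apply the same cublas/sage patches and then delegate to the stock loaders (CheckpointLoaderSimple and UNETLoader). A hedged usage sketch follows; it only runs inside a ComfyUI process where nodes and folder_paths are importable, and the file names are placeholders, not files shipped with this repo:

# Placeholder file names; pick real entries from folder_paths.get_filename_list(...).
diff_loader = DiffusionModelLoaderKJ()
(model,) = diff_loader.patch_and_load(
    ckpt_name="some_model.safetensors",      # from the diffusion_models folder
    weight_dtype="fp8_e4m3fn",
    patch_cublaslinear=True,                 # swaps disable_weight_init.Linear for the CublasLinear subclass
    sage_attention=False,                    # leaves comfy_attention.optimized_attention as-is
)

ckpt_loader = CheckpointLoaderKJ()
model, clip, vae = ckpt_loader.patch(
    ckpt_name="some_checkpoint.safetensors", # from the checkpoints folder
    patch_cublaslinear=False,                # restores the original Linear if it was patched earlier
    sage_attention=True,                     # routes optimized_attention through sageattn
)

The patch state lives in the BaseLoaderKJ.original_linear and BaseLoaderKJ.cublas_patched class attributes, so toggling either node flips the global patches for everything loaded afterwards.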