Add DiffusionModelLoaderKJ

kijai 2024-11-25 16:47:21 +02:00
parent 5920419f44
commit 1bc5c99f5a
2 changed files with 68 additions and 46 deletions

View File

@@ -154,6 +154,7 @@ NODE_CONFIG = {
     "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
     "CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
     "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
+    "DiffusionModelLoaderKJ": {"class": DiffusionModelLoaderKJ, "name": "Diffusion Model Loader KJ"},
     "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
     "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},

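The new node needs only this one registration line because NODE_CONFIG drives both of ComfyUI's lookup tables. A sketch of how such a dict is typically flattened into the module-level mappings ComfyUI scans for (NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS are ComfyUI's expected names; the comprehensions are an assumption about this file, not copied from it):

    # Assumed expansion of NODE_CONFIG into ComfyUI's standard mappings.
    # The two target names are what ComfyUI looks for at import time;
    # the comprehensions themselves are illustrative.
    NODE_CLASS_MAPPINGS = {key: cfg["class"] for key, cfg in NODE_CONFIG.items()}
    NODE_DISPLAY_NAME_MAPPINGS = {key: cfg.get("name", key) for key, cfg in NODE_CONFIG.items()}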
View File

@@ -2145,26 +2145,17 @@ class ModelSaveKJ:
 from comfy.ldm.modules import attention as comfy_attention
 orig_attention = comfy_attention.optimized_attention

-class CheckpointLoaderKJ:
+class BaseLoaderKJ:
     original_linear = None
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
-            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
-            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
-            },
-        }
-    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
-    FUNCTION = "patch"
-    OUTPUT_NODE = True
-    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear: https://github.com/aredden/torch-cublas-hgemm"
-    CATEGORY = "KJNodes/experimental"
+    cublas_patched = False

-    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
+    def _patch_modules(self, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
-        from nodes import CheckpointLoaderSimple
+        import torch

-        global orig_attention
-        if 'orig_attention' not in globals():
-            orig_attention = comfy_attention.optimized_attention
         if sage_attention:
             from sageattention import sageattn
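This hunk moves the patching into a shared _patch_modules helper; note that orig_attention is captured once at import time (top of the hunk) so the patch can always be reverted. The underlying module-attribute swap pattern, reduced to a self-contained sketch with made-up names:

    import types

    # Stand-in for the comfy.ldm.modules.attention module.
    mod = types.SimpleNamespace(optimized_attention=lambda *args, **kwargs: "original")

    # Capture the original exactly once, like orig_attention above.
    orig = mod.optimized_attention

    def patched(*args, **kwargs):
        return "patched"

    mod.optimized_attention = patched   # enable: callers resolving the attribute
    assert mod.optimized_attention() == "patched"

    mod.optimized_attention = orig      # disable: restore the saved original
    assert mod.optimized_attention() == "original"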
@@ -2177,7 +2168,7 @@ class CheckpointLoaderKJ:
             b, _, dim_head = q.shape
             dim_head //= heads
             if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
-                return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+                return orig_attention(q, k, v, heads, mask, attn_precision, skip_reshape)
             if not skip_reshape:
                 q, k, v = map(
                     lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
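Besides switching the fallback call to positional arguments, this hunk keeps the guard that routes any head size sageattn does not support (anything but 64, 96, or 128) back to the stock attention. When sageattn is used, the reshape shown here is a lossless round trip, which a few lines verify (shapes are illustrative):

    import torch

    b, n, heads, dim_head = 2, 16, 8, 64
    q = torch.randn(b, n, heads * dim_head)

    # (b, n, heads * dim_head) -> (b, heads, n, dim_head), as in the hunk above.
    q4 = q.view(b, -1, heads, dim_head).transpose(1, 2)

    # Inverse transform used on the attention output: back to (b, n, heads * dim_head).
    restored = q4.transpose(1, 2).reshape(b, -1, heads * dim_head)
    assert torch.equal(restored, q)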
@@ -2189,30 +2180,21 @@ class CheckpointLoaderKJ:
                 .reshape(b, -1, heads * dim_head)
             )

-        # class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
-        #     def reset_parameters(self):
-        #         return None
-        #     def forward_comfy_cast_weights(self, input):
-        #         weight, bias = cast_bias_weight(self, input)
-        #         return torch.nn.functional.linear(input, weight, bias)
-        #     def forward(self, *args, **kwargs):
-        #         if self.comfy_cast_weights:
-        #             return self.forward_comfy_cast_weights(*args, **kwargs)
-        #         else:
-        #             return super().forward(*args, **kwargs)
+            comfy_attention.optimized_attention = attention_sage
+        else:
+            comfy_attention.optimized_attention = orig_attention

-        cublas_patched = False
         if patch_cublaslinear:
-            if not cublas_patched:
-                original_linear = disable_weight_init.Linear
+            if not BaseLoaderKJ.cublas_patched:
+                BaseLoaderKJ.original_linear = disable_weight_init.Linear
                 try:
                     from cublas_ops import CublasLinear
                 except ImportError:
                     raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")

                 class PatchedLinear(CublasLinear, CastWeightBiasOp):
                     def reset_parameters(self):
-                        return None
+                        pass

                     def forward_comfy_cast_weights(self, input):
                         weight, bias = cast_bias_weight(self, input)
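Storing original_linear and cublas_patched on BaseLoaderKJ (rather than in locals, as before) is what makes re-running either loader node idempotent: the original class is stashed once, and the flag prevents double-patching or unpatching something that was never patched. The bookkeeping as a standalone sketch, stand-in names only:

    class _Ops:
        """Stand-in for comfy.ops.disable_weight_init."""
        Linear = type("Linear", (), {})

    class _PatchedLinear:
        """Stand-in for the CublasLinear-backed replacement."""

    class _Loader:
        original_linear = None
        patched = False

        @classmethod
        def set_patch(cls, enable):
            if enable and not cls.patched:
                cls.original_linear = _Ops.Linear   # stash the real class once
                _Ops.Linear = _PatchedLinear
                cls.patched = True
            elif not enable and cls.patched:
                _Ops.Linear = cls.original_linear   # restore it
                cls.patched = False

    _Loader.set_patch(True)
    _Loader.set_patch(True)    # no-op: already patched
    assert _Ops.Linear is _PatchedLinear
    _Loader.set_patch(False)
    assert _Ops.Linear is _Loader.original_linear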
@@ -2223,21 +2205,60 @@ class CheckpointLoaderKJ:
                         return self.forward_comfy_cast_weights(*args, **kwargs)
                     else:
                         return super().forward(*args, **kwargs)

                 disable_weight_init.Linear = PatchedLinear
-                cublas_patched = True
+                BaseLoaderKJ.cublas_patched = True
         else:
-            disable_weight_init.Linear = original_linear
-            cublas_patched = False
+            if BaseLoaderKJ.cublas_patched:
+                disable_weight_init.Linear = BaseLoaderKJ.original_linear
+                BaseLoaderKJ.cublas_patched = False

-        if sage_attention:
-            comfy_attention.optimized_attention = attention_sage
-        else:
-            comfy_attention.optimized_attention = orig_attention
+class CheckpointLoaderKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "patch"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
+        self._patch_modules(patch_cublaslinear, sage_attention)
+        from nodes import CheckpointLoaderSimple
         model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
         return model, clip, vae

+class DiffusionModelLoaderKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch_and_load"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Node for patching torch.nn.Linear with CublasLinear."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
+        self._patch_modules(patch_cublaslinear, sage_attention)
+        from nodes import UNETLoader
+        model, = UNETLoader.load_unet(self, ckpt_name, weight_dtype)
+        return (model,)
+
 import comfy.model_patcher
 import comfy.utils
 import comfy.sd
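For reference, a hypothetical direct call to the new node outside the graph (the filename is an assumption, and this presumes a running ComfyUI environment where folder_paths resolves the diffusion_models directory; inside ComfyUI the node is normally executed by the graph):

    loader = DiffusionModelLoaderKJ()
    (model,) = loader.patch_and_load(
        ckpt_name="flux1-dev-fp8.safetensors",  # assumed filename
        weight_dtype="fp8_e4m3fn",
        patch_cublaslinear=False,               # leave Linear untouched
        sage_attention=True,                    # route attention through sageattn
    )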