diff --git a/__init__.py b/__init__.py
index 64d8cb4..ff4e8c9 100644
--- a/__init__.py
+++ b/__init__.py
@@ -154,6 +154,7 @@ NODE_CONFIG = {
     "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
     "CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
     "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
+    "DiffusionModelLoaderKJ": {"class": DiffusionModelLoaderKJ, "name": "Diffusion Model Loader KJ"},
     "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
     "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
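For context: the one added line above is the only registration step needed, because KJNodes derives the tables ComfyUI reads (NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS) from the NODE_CONFIG dict. A minimal sketch of that derivation, using a placeholder class; the helper name generate_node_mappings is an assumption, not a quote from __init__.py:

class DiffusionModelLoaderKJ:  # placeholder; the real class is added in nodes/nodes.py below
    pass

NODE_CONFIG = {
    "DiffusionModelLoaderKJ": {"class": DiffusionModelLoaderKJ, "name": "Diffusion Model Loader KJ"},
}

def generate_node_mappings(node_config):
    # One config dict fans out into the two dicts ComfyUI actually imports:
    # internal name -> node class, and internal name -> display label.
    node_class_mappings = {k: v["class"] for k, v in node_config.items()}
    node_display_name_mappings = {k: v.get("name", k) for k, v in node_config.items()}
    return node_class_mappings, node_display_name_mappings

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)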
diff --git a/nodes/nodes.py b/nodes/nodes.py
index df5e951..331b59d 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -2145,30 +2145,21 @@ class ModelSaveKJ:
 from comfy.ldm.modules import attention as comfy_attention
 orig_attention = comfy_attention.optimized_attention
 
-class CheckpointLoaderKJ:
+class BaseLoaderKJ:
     original_linear = None
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
-            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
-            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
-        },
-        }
-    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
-    FUNCTION = "patch"
-    OUTPUT_NODE = True
-    DESCRIPTION = "Exemplar node for patching torch.nn.Linear with CublasLinear: https://github.com/aredden/torch-cublas-hgemm"
+    cublas_patched = False
 
-    CATEGORY = "KJNodes/experimental"
-
-    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
+    def _patch_modules(self, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
-        from nodes import CheckpointLoaderSimple
-
+        import torch
+
+        global orig_attention
+        if 'orig_attention' not in globals():
+            orig_attention = comfy_attention.optimized_attention
+
         if sage_attention:
             from sageattention import sageattn
-
+            @torch.compiler.disable()
             def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
                 if skip_reshape:
@@ -2177,7 +2168,7 @@ class CheckpointLoaderKJ:
                 b, _, dim_head = q.shape
                 dim_head //= heads
                 if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
-                    return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+                    return orig_attention(q, k, v, heads, mask, attn_precision, skip_reshape)
                 if not skip_reshape:
                     q, k, v = map(
                         lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
@@ -2187,32 +2178,23 @@ class CheckpointLoaderKJ:
                 sageattn(q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0, smooth_k=True)
                 .transpose(1, 2)
                 .reshape(b, -1, heads * dim_head)
-                )
-
-            # class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
-            #     def reset_parameters(self):
-            #         return None
+            )
 
-            #     def forward_comfy_cast_weights(self, input):
-            #         weight, bias = cast_bias_weight(self, input)
-            #         return torch.nn.functional.linear(input, weight, bias)
+            comfy_attention.optimized_attention = attention_sage
+        else:
+            comfy_attention.optimized_attention = orig_attention
 
-            #     def forward(self, *args, **kwargs):
-            #         if self.comfy_cast_weights:
-            #             return self.forward_comfy_cast_weights(*args, **kwargs)
-            #         else:
-            #             return super().forward(*args, **kwargs)
-        cublas_patched = False
         if patch_cublaslinear:
-            if not cublas_patched:
-                original_linear = disable_weight_init.Linear
+            if not BaseLoaderKJ.cublas_patched:
+                BaseLoaderKJ.original_linear = disable_weight_init.Linear
                 try:
                     from cublas_ops import CublasLinear
                 except ImportError:
                     raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
+
                 class PatchedLinear(CublasLinear, CastWeightBiasOp):
                     def reset_parameters(self):
-                        return None
+                        pass
 
                     def forward_comfy_cast_weights(self, input):
                         weight, bias = cast_bias_weight(self, input)
@@ -2223,20 +2205,59 @@ class CheckpointLoaderKJ:
                             return self.forward_comfy_cast_weights(*args, **kwargs)
                         else:
                             return super().forward(*args, **kwargs)
+
                 disable_weight_init.Linear = PatchedLinear
-                cublas_patched = True
-        else:
-            disable_weight_init.Linear = original_linear
-            cublas_patched = False
-
-        if sage_attention:
-            comfy_attention.optimized_attention = attention_sage
+                BaseLoaderKJ.cublas_patched = True
         else:
-            comfy_attention.optimized_attention = orig_attention
+            if BaseLoaderKJ.cublas_patched:
+                disable_weight_init.Linear = BaseLoaderKJ.original_linear
+                BaseLoaderKJ.cublas_patched = False
 
+class CheckpointLoaderKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "patch"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
+        self._patch_modules(patch_cublaslinear, sage_attention)
+        from nodes import CheckpointLoaderSimple
         model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
         return model, clip, vae
+
+class DiffusionModelLoaderKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the diffusion model to load."}),
+            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch_and_load"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
+        self._patch_modules(patch_cublaslinear, sage_attention)
+        from nodes import UNETLoader
+        model, = UNETLoader.load_unet(self, ckpt_name, weight_dtype)
+        return (model,)
+
 import comfy.model_patcher
 import comfy.utils
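What the nodes.py change does, in brief: the loader-specific code moves into thin subclasses, while BaseLoaderKJ._patch_modules owns the two global seams being patched, comfy_attention.optimized_attention (swapped for a sageattn wrapper that falls back to the saved original for unsupported head dims) and comfy.ops.disable_weight_init.Linear (swapped for a CublasLinear subclass). The flag and the saved Linear class now live on BaseLoaderKJ itself, so the patch is applied and reverted at most once across all loader nodes. Below is a self-contained sketch of that patch/restore pattern that runs without ComfyUI, sageattention, or cublas_ops; fake_module, reference_attention, fast_attention, and set_patch are stand-ins invented for the sketch, not APIs of any of those packages:

import types
import torch

# Stand-in for the comfy_attention module whose attribute gets swapped.
fake_module = types.SimpleNamespace()

def reference_attention(q, k, v, heads):
    # Baseline attention over (batch, seq, heads * dim_head) tensors,
    # using the same reshape convention as attention_sage above.
    b, n, inner = q.shape
    dim_head = inner // heads
    q, k, v = (t.view(b, n, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, n, inner)

fake_module.optimized_attention = reference_attention
ORIG_ATTENTION = fake_module.optimized_attention  # saved once, like orig_attention

def fast_attention(q, k, v, heads):
    # Same dispatch guard as attention_sage: fast path only for supported
    # head dims, otherwise fall back to the saved original implementation.
    if q.shape[-1] // heads not in (64, 96, 128):
        return ORIG_ATTENTION(q, k, v, heads)
    return reference_attention(q, k, v, heads)  # a real patch would call sageattn here

class PatchState:
    patched = False  # class attribute, like BaseLoaderKJ.cublas_patched

def set_patch(enabled):
    # Idempotent toggle: patch at most once and restore at most once,
    # so repeated node executions never stack patches or lose the original.
    if enabled and not PatchState.patched:
        fake_module.optimized_attention = fast_attention
        PatchState.patched = True
    elif not enabled and PatchState.patched:
        fake_module.optimized_attention = ORIG_ATTENTION
        PatchState.patched = False

q = k = v = torch.randn(1, 16, 8 * 64)
set_patch(True)
out = fake_module.optimized_attention(q, k, v, heads=8)
set_patch(False)
assert fake_module.optimized_attention is ORIG_ATTENTION

The motivation is visible in the code being deleted: there, cublas_patched and original_linear were locals of patch(), reset on every call, so the patch was re-applied each run and the restore branch could reference a variable never assigned in that call. Hoisting both onto the class makes the toggle idempotent, and the tooltip's caveat ("won't take effect on already loaded models!") follows from the same pattern: modules instantiated before the swap keep the previous Linear class.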