From 7c1228a5a3861e0ad4c5bf68aa8360935022b587 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 18 Oct 2024 14:29:43 +0300
Subject: [PATCH] Swap out PatchCublasLinear in favor of checkpoint loader
 that patches it

Because the patch needs to happen before model loading
---
 __init__.py    |  2 +-
 nodes/nodes.py | 54 +++++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/__init__.py b/__init__.py
index 026cdb6..50b374e 100644
--- a/__init__.py
+++ b/__init__.py
@@ -150,7 +150,7 @@ NODE_CONFIG = {
     "FluxBlockLoraLoader": {"class": FluxBlockLoraLoader, "name": "Flux Block Lora Loader"},
     "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
     "CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
-    "PatchCublasLinear": {"class": PatchCublasLinear, "name": "Patch Cublas Linear"},
+    "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/nodes.py b/nodes/nodes.py
index 7a7ffde..fa0d5cf 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -2124,29 +2124,58 @@ class ModelSaveKJ:
             os.makedirs(full_output_folder)
         save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint))
         return {}
-    
-class PatchCublasLinear:
+
+
+from comfy.ldm.modules import attention as comfy_attention
+orig_attention = comfy_attention.optimized_attention
+
+class CheckpointLoaderKJ:
     original_linear = None
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-            "enabled": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             },
         }
 
-    RETURN_TYPES = ("MODEL",)
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "patch"
     OUTPUT_NODE = True
-    DESCRIPTION = "Highly experimental node that simply patches the Linear layer to use torch-cublas-hgemm, won't take effect on already loaded models!"
+    DESCRIPTION = "Exemplar node for patching torch.nn.Linear with CublasLinear: https://github.com/aredden/torch-cublas-hgemm"
     CATEGORY = "KJNodes/experimental"
 
-    def patch(self, model, enabled):
+    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
+        from nodes import CheckpointLoaderSimple
         try:
             from cublas_ops import CublasLinear
         except ImportError:
             raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
 
+        if sage_attention:
+            from sageattention import sageattn
+
+            def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+                if skip_reshape:
+                    b, _, _, dim_head = q.shape
+                else:
+                    b, _, dim_head = q.shape
+                    dim_head //= heads
+                if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
+                    return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+                if not skip_reshape:
+                    q, k, v = map(
+                        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                        (q, k, v),
+                    )
+                return (
+                    sageattn(q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0, smooth_k=True)
+                    .transpose(1, 2)
+                    .reshape(b, -1, heads * dim_head)
+                )
+
         class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
             def reset_parameters(self):
                 return None
@@ -2175,10 +2204,17 @@ class PatchCublasLinear:
                 else:
                     return super().forward(*args, **kwargs)
 
 
-        if enabled:
+        if patch_cublaslinear:
             disable_weight_init.Linear = PatchedLinear
         else:
             disable_weight_init.Linear = OriginalLinear
+        if sage_attention:
+            comfy_attention.optimized_attention = attention_sage
+        else:
+            comfy_attention.optimized_attention = orig_attention
+
+        model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
 
-        return model,
+
+        return model, clip, vae
\ No newline at end of file