Swap out PatchCublasLinear in favor of a checkpoint loader that patches it

Because the patch needs to happen before model loading
kijai 2024-10-18 14:29:43 +03:00
parent 2263b8cb41
commit 7c1228a5a3
2 changed files with 46 additions and 10 deletions
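
The ordering constraint in the commit message is the crux: ComfyUI looks up its layer classes from comfy.ops while a checkpoint is being loaded, so a model that already exists keeps whatever Linear class it was built with. A minimal sketch of that effect, using hypothetical stand-ins (Ops for comfy.ops.disable_weight_init, load_model for the checkpoint loader), not KJNodes code:

import torch

class Ops:
    Linear = torch.nn.Linear  # stand-in for comfy.ops.disable_weight_init.Linear

class PatchedLinear(torch.nn.Linear):  # stand-in for the CublasLinear wrapper
    pass

def load_model():
    return Ops.Linear(4, 4)   # layer class is resolved at load time

early = load_model()          # loaded before the patch: plain Linear
Ops.Linear = PatchedLinear    # the swap
late = load_model()           # only models loaded afterwards are patched
print(type(early).__name__, type(late).__name__)  # Linear PatchedLinear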


@@ -150,7 +150,7 @@ NODE_CONFIG = {
     "FluxBlockLoraLoader": {"class": FluxBlockLoraLoader, "name": "Flux Block Lora Loader"},
     "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
     "CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
-    "PatchCublasLinear": {"class": PatchCublasLinear, "name": "Patch Cublas Linear"},
+    "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},


@@ -2124,29 +2124,58 @@ class ModelSaveKJ:
             os.makedirs(full_output_folder)
         save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint))
         return {}
-class PatchCublasLinear:
+from comfy.ldm.modules import attention as comfy_attention
+orig_attention = comfy_attention.optimized_attention
+
+class CheckpointLoaderKJ:
     original_linear = None
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-            "enabled": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             },
         }
-    RETURN_TYPES = ("MODEL",)
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "patch"
     OUTPUT_NODE = True
-    DESCRIPTION = "Highly experimental node that simply patches the Linear layer to use torch-cublas-hgemm, won't take effect on already loaded models!"
+    DESCRIPTION = "Exemplar node for patching torch.nn.Linear with CublasLinear: https://github.com/aredden/torch-cublas-hgemm"
     CATEGORY = "KJNodes/experimental"
-    def patch(self, model, enabled):
+    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
+        from nodes import CheckpointLoaderSimple
         try:
             from cublas_ops import CublasLinear
         except ImportError:
             raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
+        if sage_attention:
+            from sageattention import sageattn
+            def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+                if skip_reshape:
+                    b, _, _, dim_head = q.shape
+                else:
+                    b, _, dim_head = q.shape
+                    dim_head //= heads
+                if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
+                    return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+                if not skip_reshape:
+                    q, k, v = map(
+                        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                        (q, k, v),
+                    )
+                return (
+                    sageattn(q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0, smooth_k=True)
+                    .transpose(1, 2)
+                    .reshape(b, -1, heads * dim_head)
+                )
         class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
             def reset_parameters(self):
                 return None
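
The attention_sage wrapper only rewrites the happy path: head dims outside (64, 96, 128) and non-self-attention shapes fall back to the original comfy attention. A standalone sketch of the reshape contract it follows, with torch.nn.functional.scaled_dot_product_attention standing in for sageattn (both consume the (batch, heads, seq, dim_head) layout; shapes here are illustrative):

import torch
import torch.nn.functional as F

def attention_like(q, k, v, heads):
    b, seq, inner = q.shape            # comfy passes (b, seq, heads * dim_head)
    dim_head = inner // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)  # stand-in for sageattn(...)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 77, 8 * 64)
print(attention_like(q, k, v, heads=8).shape)      # torch.Size([2, 77, 512])
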
@@ -2175,10 +2204,17 @@ class PatchCublasLinear:
             else:
                 return super().forward(*args, **kwargs)
-        if enabled:
+        if patch_cublaslinear:
             disable_weight_init.Linear = PatchedLinear
         else:
             disable_weight_init.Linear = OriginalLinear
+        if sage_attention:
+            comfy_attention.optimized_attention = attention_sage
+        else:
+            comfy_attention.optimized_attention = orig_attention
-        return model,
+        model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
+        return model, clip, vae
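
Note that the patches are applied by rebinding module-level names (disable_weight_init.Linear, comfy_attention.optimized_attention) before delegating to the stock loader, so they are process-global and also affect anything loaded later. A hypothetical direct call inside a running ComfyUI process, matching the signature above (the checkpoint filename is illustrative):

node = CheckpointLoaderKJ()
model, clip, vae = node.patch(
    "sd_xl_base_1.0.safetensors",  # illustrative name under models/checkpoints
    patch_cublaslinear=True,
    sage_attention=True,
)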