Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git, synced 2026-04-17 06:46:58 +08:00
Swap out PatchCublasLinear in favor of a checkpoint loader that patches it
Because the patch needs to happen before model loading, it is now applied as part of checkpoint loading.
parent 2263b8cb41
commit 7c1228a5a3
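The commit body is the key detail: ComfyUI instantiates a model's layers while the checkpoint is being loaded, so swapping out the Linear class afterwards never reaches a model that is already in memory. A minimal ordering sketch (assumes ComfyUI's Python environment; `PatchedLinear` stands in for the cublas-backed class defined in the diff below, and the checkpoint name is illustrative):

```python
# Minimal ordering sketch, not the node itself. Assumes ComfyUI is importable;
# PatchedLinear is a stand-in for the cublas-backed Linear defined further down.
from comfy.ops import disable_weight_init
from nodes import CheckpointLoaderSimple

disable_weight_init.Linear = PatchedLinear  # swap the class FIRST...
# ...so every Linear built during loading is already the patched type:
model, clip, vae = CheckpointLoaderSimple().load_checkpoint("model.safetensors")
# Swapping the class after this point would only affect models loaded later.
```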
@@ -150,7 +150,7 @@ NODE_CONFIG = {
     "FluxBlockLoraLoader": {"class": FluxBlockLoraLoader, "name": "Flux Block Lora Loader"},
     "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
     "CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
-    "PatchCublasLinear": {"class": PatchCublasLinear, "name": "Patch Cublas Linear"},
+    "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
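For context on the registry edited above: each NODE_CONFIG entry pairs a node class with an optional display name. The diff does not show how the dict is consumed, but a plausible sketch of expanding it into the NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS that ComfyUI expects at module level looks like this (the helper name is hypothetical):

```python
# Illustrative expansion of NODE_CONFIG into ComfyUI's expected mappings.
# generate_node_mappings is a hypothetical helper; ComfyUI only requires
# the two resulting dicts to exist at module level.
def generate_node_mappings(node_config):
    node_class_mappings = {}
    node_display_name_mappings = {}
    for node_name, node_info in node_config.items():
        node_class_mappings[node_name] = node_info["class"]
        node_display_name_mappings[node_name] = node_info.get("name", node_name)
    return node_class_mappings, node_display_name_mappings

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)
```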
@@ -2124,29 +2124,58 @@ class ModelSaveKJ:
             os.makedirs(full_output_folder)
         save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint))
         return {}
 
-class PatchCublasLinear:
-    original_linear = None
+from comfy.ldm.modules import attention as comfy_attention
+orig_attention = comfy_attention.optimized_attention
+
+class CheckpointLoaderKJ:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-            "enabled": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             },
         }
-    RETURN_TYPES = ("MODEL",)
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "patch"
     OUTPUT_NODE = True
-    DESCRIPTION = "Highly experimental node that simply patches the Linear layer to use torch-cublas-hgemm, won't take effect on already loaded models!"
+    DESCRIPTION = "Exemplar node for patching torch.nn.Linear with CublasLinear: https://github.com/aredden/torch-cublas-hgemm"
 
     CATEGORY = "KJNodes/experimental"
 
-    def patch(self, model, enabled):
+    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
+        from nodes import CheckpointLoaderSimple
         try:
             from cublas_ops import CublasLinear
         except ImportError:
             raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
 
+        if sage_attention:
+            from sageattention import sageattn
+
+            def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+                if skip_reshape:
+                    b, _, _, dim_head = q.shape
+                else:
+                    b, _, dim_head = q.shape
+                    dim_head //= heads
+                if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
+                    return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+                if not skip_reshape:
+                    q, k, v = map(
+                        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                        (q, k, v),
+                    )
+                return (
+                    sageattn(q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0, smooth_k=True)
+                    .transpose(1, 2)
+                    .reshape(b, -1, heads * dim_head)
+                )
 
         class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
             def reset_parameters(self):
                 return None
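The hunk above ends at OriginalLinear, but the PatchedLinear class that the next hunk assigns to disable_weight_init.Linear sits in the unchanged lines between the two hunks and is therefore not shown. A sketch of what it presumably looks like, mirroring the OriginalLinear pattern but inheriting from CublasLinear (a reconstruction inferred from the surrounding context, not part of the diff):

```python
# Reconstruction of the elided PatchedLinear; the real definition lives in
# the unchanged region between the two hunks.
class PatchedLinear(CublasLinear, CastWeightBiasOp):
    def reset_parameters(self):
        # Skip re-initialization: weights come from the checkpoint.
        return None

    def forward_comfy_cast_weights(self, input):
        # Cast weight/bias to the input's dtype/device, as comfy.ops does.
        weight, bias = cast_bias_weight(self, input)
        return torch.nn.functional.linear(input, weight, bias)

    def forward(self, *args, **kwargs):
        if self.comfy_cast_weights:
            return self.forward_comfy_cast_weights(*args, **kwargs)
        else:
            # The next hunk's context lines show exactly this fallback.
            return super().forward(*args, **kwargs)
```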
@@ -2175,10 +2204,17 @@ class PatchCublasLinear:
                 else:
                     return super().forward(*args, **kwargs)
 
-        if enabled:
+        if patch_cublaslinear:
             disable_weight_init.Linear = PatchedLinear
         else:
             disable_weight_init.Linear = OriginalLinear
+        if sage_attention:
+            comfy_attention.optimized_attention = attention_sage
+        else:
+            comfy_attention.optimized_attention = orig_attention
+
+        model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
 
-        return model,
+        return model, clip, vae
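Both changes are process-wide monkey-patches rather than per-model wrappers, which is why the node warns that it won't affect already loaded models. A quick way to sanity-check the sageattn path is to call the wrapper directly on dummy tensors. A hypothetical smoke test (assumes a CUDA device, fp16 inputs, and an installed sageattention package; sizes are illustrative, and dim_head must be 64, 96, or 128 or the wrapper silently falls back to the original attention):

```python
# Hypothetical smoke test for attention_sage above; not part of the commit.
import torch

b, seq, heads, dim_head = 2, 77, 8, 64  # dim_head in (64, 96, 128) takes the sage path
q = torch.randn(b, seq, heads * dim_head, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = attention_sage(q, k, v, heads)  # patched path
ref = orig_attention(q, k, v, heads)  # ComfyUI's default attention
assert out.shape == ref.shape == (b, seq, heads * dim_head)
```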