From f5ce7d017b56f11a601556801715b900d99ecdf5 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 22 Dec 2024 14:59:38 +0200
Subject: [PATCH] Fix up sageattention loader for hunyuan etc.

---
 nodes/model_optimization_nodes.py | 81 +++++++++++++++++--------------
 1 file changed, 45 insertions(+), 36 deletions(-)

diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 277bea5..33b5d52 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -1,11 +1,16 @@
 from comfy.ldm.modules import attention as comfy_attention
+
 import comfy.model_patcher
 import comfy.utils
 import comfy.sd
 import torch
 import folder_paths
-orig_attention = comfy_attention.optimized_attention
 import comfy.model_management as mm
+from comfy.cli_args import args
+
+orig_attention = comfy_attention.optimized_attention
+original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
+original_load_lora_for_models = comfy.sd.load_lora_for_models
 
 class BaseLoaderKJ:
     original_linear = None
@@ -14,35 +19,28 @@ class BaseLoaderKJ:
 
     def _patch_modules(self, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
-        global orig_attention
-        if 'orig_attention' not in globals():
-            orig_attention = comfy_attention.optimized_attention
-
         if sage_attention != "disabled":
+            print("Patching comfy attention to use sageattn")
             from sageattention import sageattn
             def set_sage_func(sage_attention):
                 if sage_attention == "auto":
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
                     return func
                 elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
                     from sageattention import sageattn_qk_int8_pv_fp16_cuda
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
                     return func
                 elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
                     from sageattention import sageattn_qk_int8_pv_fp16_triton
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
                     return func
                 elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
                     from sageattention import sageattn_qk_int8_pv_fp8_cuda
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
-                    return func
-                else:
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
                     return func
 
             sage_func = set_sage_func(sage_attention)
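A note on the tensor_layout plumbing added above: SageAttention accepts q/k/v in either of two memory layouts, and the wrapper now forwards that choice instead of leaving it implicit. A minimal standalone sketch of the convention as this patch uses it ("HND" for tensors already shaped (batch, heads, seq, head_dim); "NHD" for (batch, seq, heads, head_dim)); the sizes below are illustrative, not taken from the patch:

    import torch

    b, heads, seq, dim_head = 2, 8, 77, 64
    q_hnd = torch.randn(b, heads, seq, dim_head)   # "HND": (batch, heads, seq, head_dim)
    q_nhd = q_hnd.transpose(1, 2).contiguous()     # "NHD": (batch, seq, heads, head_dim)
    assert q_nhd.shape == (b, seq, heads, dim_head)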
@@ -51,25 +49,41 @@ class BaseLoaderKJ:
             def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
                 if skip_reshape:
                     b, _, _, dim_head = q.shape
+                    tensor_layout="HND"
                 else:
                     b, _, dim_head = q.shape
                     dim_head //= heads
-                if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
-                    return orig_attention(q, k, v, heads, mask, attn_precision, skip_reshape)
-                if not skip_reshape:
                     q, k, v = map(
-                        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                        lambda t: t.view(b, -1, heads, dim_head),
                         (q, k, v),
                     )
-                return (
-                    sage_func(q, k, v, is_causal=False, attn_mask=mask)
-                    .transpose(1, 2)
-                    .reshape(b, -1, heads * dim_head)
-                )
+                    tensor_layout="NHD"
+                if mask is not None:
+                    # add a batch dimension if there isn't already one
+                    if mask.ndim == 2:
+                        mask = mask.unsqueeze(0)
+                    # add a heads dimension if there isn't already one
+                    if mask.ndim == 3:
+                        mask = mask.unsqueeze(1)
+                out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+                if tensor_layout == "HND":
+                    out = (
+                        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+                    )
+                else:
+                    out = out.reshape(b, -1, heads * dim_head)
+                return out
 
             comfy_attention.optimized_attention = attention_sage
+            comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
+            comfy.ldm.flux.math.optimized_attention = attention_sage
+            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
+
         else:
             comfy_attention.optimized_attention = orig_attention
+            comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
+            comfy.ldm.flux.math.optimized_attention = orig_attention
+            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
 
         if patch_cublaslinear:
             if not BaseLoaderKJ.cublas_patched:
@@ -105,8 +119,8 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
     def INPUT_TYPES(s):
         return {"required": {
             "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
-            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
-            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
         }}
 
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
@@ -128,7 +142,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
         return {"required": {
             "ckpt_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
             "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
-            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
             "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
         }}
 
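The per-module assignments added above (hunyuan_video, flux, genmo) are what actually extends the patch to those models: each of those modules binds optimized_attention by name at import time, so rebinding comfy.ldm.modules.attention.optimized_attention alone never reaches them. A standalone sketch of that Python behaviour, using toy module names rather than ComfyUI's:

    import types

    lib = types.ModuleType("lib")
    lib.f = lambda: "original"

    consumer = types.ModuleType("consumer")
    consumer.f = lib.f            # same effect as `from lib import f` inside consumer

    lib.f = lambda: "patched"     # rebinding the source module...
    print(consumer.f())           # ...still prints "original"
    consumer.f = lib.f            # each consumer has to be repointed explicitly
    print(consumer.f())           # now prints "patched"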
@@ -139,17 +153,12 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
     EXPERIMENTAL = True
     CATEGORY = "KJNodes/experimental"
 
-    def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
-        self._patch_modules(patch_cublaslinear, sage_attention)
+    def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
         from nodes import UNETLoader
         model, = UNETLoader.load_unet(self, ckpt_name, weight_dtype)
+        self._patch_modules(patch_cublaslinear, sage_attention)
         return (model,)
 
-
-original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
-original_load_lora_for_models = comfy.sd.load_lora_for_models
-
 def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
 
     if lowvram_model_memory == 0:
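For reference, the mask handling introduced in attention_sage earlier in this patch promotes a 2-D (q_len, k_len) mask to 4-D so it can broadcast over batch and heads. A standalone sketch of just that step, with illustrative sizes rather than values from the patch:

    import torch

    mask = torch.ones(77, 77, dtype=torch.bool)   # (q_len, k_len)
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)                  # (1, q_len, k_len)
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)                  # (1, 1, q_len, k_len)
    assert mask.shape == (1, 1, 77, 77)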