Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2025-12-13 23:04:40 +08:00)
Fix up sageattention loader for hunyuan etc.

parent 973ceb6ca8
commit f5ce7d017b
@@ -1,11 +1,16 @@
 from comfy.ldm.modules import attention as comfy_attention
 
 import comfy.model_patcher
 import comfy.utils
 import comfy.sd
 import torch
 import folder_paths
-orig_attention = comfy_attention.optimized_attention
 import comfy.model_management as mm
+from comfy.cli_args import args
+
+orig_attention = comfy_attention.optimized_attention
+original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
+original_load_lora_for_models = comfy.sd.load_lora_for_models
+
 class BaseLoaderKJ:
     original_linear = None
@@ -14,35 +19,28 @@ class BaseLoaderKJ:
     def _patch_modules(self, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
 
-        global orig_attention
-        if 'orig_attention' not in globals():
-            orig_attention = comfy_attention.optimized_attention
-
         if sage_attention != "disabled":
+            print("Patching comfy attention to use sageattn")
             from sageattention import sageattn
             def set_sage_func(sage_attention):
                 if sage_attention == "auto":
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
                     return func
                 elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
                     from sageattention import sageattn_qk_int8_pv_fp16_cuda
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
                     return func
                 elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
                     from sageattention import sageattn_qk_int8_pv_fp16_triton
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
                     return func
                 elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
                     from sageattention import sageattn_qk_int8_pv_fp8_cuda
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
-                    return func
-                else:
-                    def func(q, k, v, is_causal=False, attn_mask=None):
-                        return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+                    def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
                     return func
 
             sage_func = set_sage_func(sage_attention)
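The hunk above gives every SageAttention wrapper the same (q, k, v, is_causal, attn_mask, tensor_layout) signature, so the closure returned by set_sage_func can be stored in sage_func and called uniformly by attention_sage in the next hunk. As a rough, self-contained sketch of that dispatch-to-closure pattern (not code from this repository), the snippet below uses torch.nn.functional.scaled_dot_product_attention as a stand-in for the sageattn* kernels so it runs without SageAttention installed; set_attention_func and its single "auto" branch are illustrative names.

# Hedged sketch: map a mode string to a closure with one uniform signature.
# SDPA stands in for the SageAttention kernels; it expects the "HND" layout
# (batch, heads, seq, dim), so "NHD" inputs are transposed around the call.
import torch
import torch.nn.functional as F

def set_attention_func(choice):
    if choice == "auto":
        def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            if tensor_layout == "NHD":
                q, k, v = (t.transpose(1, 2) for t in (q, k, v))
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)
            return out.transpose(1, 2) if tensor_layout == "NHD" else out
        return func
    raise ValueError(f"unknown attention mode: {choice}")

attn = set_attention_func("auto")
q = k = v = torch.randn(1, 64, 8, 32)   # NHD: (batch, seq_len, heads, head_dim)
print(attn(q, k, v).shape)              # torch.Size([1, 64, 8, 32])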
@@ -51,25 +49,41 @@ class BaseLoaderKJ:
         def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
             if skip_reshape:
                 b, _, _, dim_head = q.shape
+                tensor_layout="HND"
             else:
                 b, _, dim_head = q.shape
                 dim_head //= heads
-            if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
-                return orig_attention(q, k, v, heads, mask, attn_precision, skip_reshape)
-            if not skip_reshape:
                 q, k, v = map(
-                    lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                    lambda t: t.view(b, -1, heads, dim_head),
                     (q, k, v),
                 )
-            return (
-                sage_func(q, k, v, is_causal=False, attn_mask=mask)
-                .transpose(1, 2)
-                .reshape(b, -1, heads * dim_head)
-            )
+                tensor_layout="NHD"
+            if mask is not None:
+                # add a batch dimension if there isn't already one
+                if mask.ndim == 2:
+                    mask = mask.unsqueeze(0)
+                # add a heads dimension if there isn't already one
+                if mask.ndim == 3:
+                    mask = mask.unsqueeze(1)
+            out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+            if tensor_layout == "HND":
+                out = (
+                    out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+                )
+            else:
+                out = out.reshape(b, -1, heads * dim_head)
+            return out
 
         comfy_attention.optimized_attention = attention_sage
+        comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
+        comfy.ldm.flux.math.optimized_attention = attention_sage
+        comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
+
     else:
         comfy_attention.optimized_attention = orig_attention
+        comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
+        comfy.ldm.flux.math.optimized_attention = orig_attention
+        comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
 
     if patch_cublaslinear:
         if not BaseLoaderKJ.cublas_patched:
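The rewritten attention_sage records which layout it is handing to the kernel ("HND" when the caller already split heads via skip_reshape, "NHD" when it reshapes the tensors itself), pads an optional mask up to 4-D, and merges heads back into ComfyUI's (batch, seq, heads * dim) shape afterwards. The torch-only sketch below isolates those two steps; normalize_mask and merge_heads are illustrative helper names, not functions from the repository.

# Hedged, CPU-runnable illustration of the mask padding and the
# layout-dependent head merge performed by the new attention_sage.
import torch

def normalize_mask(mask):
    if mask is None:
        return None
    if mask.ndim == 2:          # (seq_q, seq_k) -> add a batch dimension
        mask = mask.unsqueeze(0)
    if mask.ndim == 3:          # (batch, seq_q, seq_k) -> add a heads dimension
        mask = mask.unsqueeze(1)
    return mask                 # (batch, heads, seq_q, seq_k), broadcastable

def merge_heads(out, b, heads, dim_head, tensor_layout):
    if tensor_layout == "HND":  # (b, heads, seq, dim) -> (b, seq, heads * dim)
        return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return out.reshape(b, -1, heads * dim_head)     # already (b, seq, heads, dim)

mask = normalize_mask(torch.ones(64, 64, dtype=torch.bool))
print(mask.shape)                                   # torch.Size([1, 1, 64, 64])
out_nhd = torch.randn(1, 64, 8, 32)                 # kernel output in NHD layout
print(merge_heads(out_nhd, 1, 8, 32, "NHD").shape)  # torch.Size([1, 64, 256])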
@@ -105,8 +119,8 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
     def INPUT_TYPES(s):
         return {"required": {
             "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
-            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
-            "sage_attention": ("BOOLEAN", {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
         }}
 
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
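The sage_attention input changes from a BOOLEAN toggle to a combo box; in ComfyUI a combo is declared by putting the list of option strings in the first slot of the input tuple. A minimal, hypothetical node (not part of this repository) showing the two widget styles side by side, with the combo default written as one of the listed strings:

# Hypothetical node sketch: BOOLEAN toggle vs. combo/dropdown input.
class ExampleOptionsNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "patch_cublaslinear": ("BOOLEAN", {"default": False}),
            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda"], {"default": "disabled"}),
        }}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "report"
    CATEGORY = "KJNodes/experimental"

    def report(self, patch_cublaslinear, sage_attention):
        return (f"cublas={patch_cublaslinear}, sage={sage_attention}",)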
@@ -128,7 +142,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
         return {"required": {
             "ckpt_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
             "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
-            "patch_cublaslinear": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
             "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
         }}
 
@@ -139,17 +153,12 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
     EXPERIMENTAL = True
     CATEGORY = "KJNodes/experimental"
 
     def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
-        self._patch_modules(patch_cublaslinear, sage_attention)
         from nodes import UNETLoader
         model, = UNETLoader.load_unet(self, ckpt_name, weight_dtype)
+        self._patch_modules(patch_cublaslinear, sage_attention)
         return (model,)
 
-
-
-original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
-original_load_lora_for_models = comfy.sd.load_lora_for_models
-
 def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
 
     if lowvram_model_memory == 0:
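In this last hunk the module-level originals move to the top of the file (first hunk), and _patch_modules now runs after UNETLoader.load_unet, so the attention and cuBLAS patching happens once the model has actually been loaded. patched_patch_model then replaces comfy.model_patcher.ModelPatcher.patch_model, with original_patch_model keeping a reference to ComfyUI's implementation. Below is a minimal sketch of that wrap-and-delegate pattern, assuming ComfyUI's patch_model accepts the same keywords as the replaced signature above; the body is a placeholder, not the repository's actual implementation.

# Hedged sketch of the wrap-and-delegate monkey patch around
# ModelPatcher.patch_model (only meaningful inside a ComfyUI environment).
import comfy.model_patcher

original_patch_model = comfy.model_patcher.ModelPatcher.patch_model

def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
    # custom pre-patching logic would go here (placeholder), then defer to ComfyUI
    return original_patch_model(self, device_to=device_to, lowvram_model_memory=lowvram_model_memory,
                                load_weights=load_weights, force_patch_weights=force_patch_weights)

comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model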