mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 05:15:05 +08:00
Add sageattn++ as selectable mode for easier testing
This commit is contained in:
parent
6363376a17
commit
ff49e1b01f
@ -8,6 +8,7 @@ import folder_paths
|
||||
import comfy.model_management as mm
|
||||
from comfy.cli_args import args
|
||||
|
||||
sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++"]
|
||||
|
||||
_initialized = False
|
||||
_original_functions = {}
|
||||
@ -49,6 +50,11 @@ class BaseLoaderKJ:
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
|
||||
return func
|
||||
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
|
||||
from sageattention import sageattn_qk_int8_pv_fp8_cuda
|
||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
|
||||
return func
|
||||
|
||||
sage_func = set_sage_func(sage_attention)
|
||||
|
||||
@ -136,7 +142,7 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": ("MODEL",),
|
||||
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
|
||||
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL", )
|
||||
@ -167,7 +173,7 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
|
||||
"compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
||||
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
|
||||
}}
|
||||
|
||||
@ -309,7 +315,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
|
||||
"compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
||||
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
|
||||
}}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user