mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-24 04:04:29 +08:00
sageattn3
This commit is contained in:
parent
9d7af919b9
commit
7ce88200c7
@ -17,7 +17,7 @@ except ImportError:
|
|||||||
v3_available = False
|
v3_available = False
|
||||||
logging.warning("ComfyUI v3 node API not available, please update ComfyUI to access latest v3 nodes.")
|
logging.warning("ComfyUI v3 node API not available, please update ComfyUI to access latest v3 nodes.")
|
||||||
|
|
||||||
sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++"]
|
sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++", "sageattn3", "sageattn3_per_block_mean"]
|
||||||
|
|
||||||
_initialized = False
|
_initialized = False
|
||||||
_original_functions = {}
|
_original_functions = {}
|
||||||
@ -118,6 +118,15 @@ class BaseLoaderKJ:
|
|||||||
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
|
||||||
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
|
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
|
||||||
return func
|
return func
|
||||||
|
elif "sageattn3" in sage_attention:
|
||||||
|
from sageattn3 import sageattn3_blackwell
|
||||||
|
if sage_attention == "sageattn3_per_block_mean":
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True)
|
||||||
|
else:
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False)
|
||||||
|
return func
|
||||||
|
|
||||||
sage_func = set_sage_func(sage_attention)
|
sage_func = set_sage_func(sage_attention)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user